mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
67 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b93591ec32 | ||
|
|
0abfd8da0d | ||
|
|
a9cc4e36c6 | ||
|
|
fe1c963eec | ||
|
|
111d80e88d | ||
|
|
6718862dbe | ||
|
|
0fe1bf8a61 | ||
|
|
10f326eda9 | ||
|
|
cd0d6c1a3d | ||
|
|
3205f2df97 | ||
|
|
5bdbcfcd8d | ||
|
|
a2e2052b30 | ||
|
|
0146ded4f4 | ||
|
|
dccf9dd8f8 | ||
|
|
7816b402bb | ||
|
|
cd4ce30f7c | ||
|
|
8c7e230898 | ||
|
|
42ba696518 | ||
|
|
3f84e60a1f | ||
|
|
baba8b5b73 | ||
|
|
77397c4f21 | ||
|
|
8678091d8f | ||
|
|
aa22170ab4 | ||
|
|
901ec37290 | ||
|
|
21f2ea8b17 | ||
|
|
8219e3d4e2 | ||
|
|
3ed71a61d5 | ||
|
|
18a88a8e8f | ||
|
|
318a72987c | ||
|
|
5ce202cc99 | ||
|
|
d09528bc26 | ||
|
|
42d2a41dbe | ||
|
|
82be1840b0 | ||
|
|
27352c5cb6 | ||
|
|
1ea6408d41 | ||
|
|
5e095af3aa | ||
|
|
ab3dceed92 | ||
|
|
3bf5126d84 | ||
|
|
ab2ab7b23a | ||
|
|
c9184d125b | ||
|
|
0c0fdb72b9 | ||
|
|
86378053d4 | ||
|
|
b1cbba0cf1 | ||
|
|
f31526042d | ||
|
|
3f8d5bc346 | ||
|
|
11d76e7d8c | ||
|
|
e76c0fbc63 | ||
|
|
fdc9956da3 | ||
|
|
f4addaa653 | ||
|
|
667964cc82 | ||
|
|
e1309e30b7 | ||
|
|
9403942ef7 | ||
|
|
84a75d9e70 | ||
|
|
c85ab66ae6 | ||
|
|
bf7f0f646b | ||
|
|
dcdf2a3d58 | ||
|
|
f8d8fc40a6 | ||
|
|
45d434a123 | ||
|
|
1834abe5bc | ||
|
|
d6321588f3 | ||
|
|
c17b10ff1d | ||
|
|
b125a56f86 | ||
|
|
c43ce3a17b | ||
|
|
b0b09616a8 | ||
|
|
ede5586ccc | ||
|
|
a1dcdffa53 | ||
|
|
35a11db58e |
63
README.md
63
README.md
@@ -1,6 +1,6 @@
|
||||
# Trains Server
|
||||
|
||||
## Auto-Magical Experiment Manager & Version Control for AI
|
||||
## Auto-Magical Experiment Manager & Version Control for AI - ε Devops Included!
|
||||
|
||||
[](https://img.shields.io/badge/license-SSPL-green.svg)
|
||||
[](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
|
||||
@@ -9,6 +9,14 @@
|
||||
|
||||
### Help improve Trains by filling our 2-min [user survey](https://allegro.ai/lp/trains-user-survey/)
|
||||
|
||||
## :rocket: Trains-Agent Services is now included, for more information see [services](https://github.com/allegroai/trains-server#services)
|
||||
|
||||
## v0.16 Upgrade Notice
|
||||
|
||||
In v0.16, the Elasticsearch subsystem of Trains Server has been upgraded from version 5.6 to version 7.6. This change necessitates the migration of the database contents to accommodate the change in index structure across the different versions.
|
||||
|
||||
Follow [this procedure](https://allegro.ai/docs/deploying_trains/trains_server_es7_migration/) to migrate existing data.
|
||||
|
||||
## Introduction
|
||||
|
||||
The **trains-server** is the backend service infrastructure for [Trains](https://github.com/allegroai/trains).
|
||||
@@ -62,15 +70,15 @@ For example, to see if port `8080` is in use:
|
||||
|
||||
Launch **trains-server** in any of the following formats:
|
||||
|
||||
- Pre-built [AWS EC2 AMI](https://github.com/allegroai/trains-server/blob/master/docs/install_aws.md)
|
||||
- Pre-built [GCP Custom Image](https://github.com/allegroai/trains-server/blob/master/docs/install_gcp.md)
|
||||
- Pre-built [AWS EC2 AMI](https://allegro.ai/docs/deploying_trains/trains_server_aws_ec2_ami/)
|
||||
- Pre-built [GCP Custom Image](https://allegro.ai/docs/deploying_trains/trains_server_gcp/)
|
||||
- Pre-built Docker Image
|
||||
- [Linux](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
|
||||
- [macOS](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
|
||||
- [Windows 10](https://github.com/allegroai/trains-server/blob/master/docs/install_win.md)
|
||||
- [Linux](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
|
||||
- [macOS](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
|
||||
- [Windows 10](https://allegro.ai/docs/deploying_trains/trains_server_win/)
|
||||
- Kubernetes
|
||||
- [Kubernetes Helm](https://github.com/allegroai/trains-server-helm#prerequisites)
|
||||
- Manual [Kubernetes installation](https://github.com/allegroai/trains-server-k8s#prerequisites)
|
||||
- [Kubernetes Helm](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes_helm/)
|
||||
- Manual [Kubernetes installation](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes/)
|
||||
|
||||
## Connecting Trains to your trains-server
|
||||
|
||||
@@ -98,12 +106,32 @@ you can [use](https://github.com/allegroai/trains#using-trains) **Trains** in yo
|
||||
for example http://localhost:8080.
|
||||
For more information about the Trains client, see [**Trains**](https://github.com/allegroai/trains).
|
||||
|
||||
## Trains-Agent Services <a name="services"></a>
|
||||
|
||||
As of version 0.15 of **trains-server**, dockerized deployment includes a **Trains-Agent Services** container running as
|
||||
part of the docker container collection.
|
||||
|
||||
Trains-Agent Services is an extension of Trains-Agent that provides the ability to launch long-lasting jobs
|
||||
that previously had to be executed on local / dedicated machines. It allows a single agent to
|
||||
launch multiple dockers (Tasks) for different use cases. To name a few use cases, auto-scaler service (spinning instances
|
||||
when the need arises and the budget allows), Controllers (Implementing pipelines and more sophisticated DevOps logic),
|
||||
Optimizer (such as Hyper-parameter Optimization or sweeping), and Application (such as interactive Bokeh apps for
|
||||
increased data transparency)
|
||||
|
||||
Trains-Agent Services container will spin **any** task enqueued into the dedicated `services` queue.
|
||||
Every task launched by Trains-Agent Services will be registered as a new node in the system,
|
||||
providing tracking and transparency capabilities.
|
||||
You can also run the Trains-Agent Services manually, see details in [trains-agent services mode](https://github.com/allegroai/trains-agent#trains-agent-services-mode-)
|
||||
|
||||
**Note**: It is the user's responsibility to make sure the proper tasks are pushed into the `services` queue.
|
||||
Do not enqueue training / inference tasks into the `services` queue, as it will put unnecessary load on the server.
|
||||
|
||||
## Advanced Functionality
|
||||
|
||||
**trains-server** provides a few additional useful features, which can be manually enabled:
|
||||
|
||||
* [Web login authentication](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#web-auth)
|
||||
* [Non-responsive experiments watchdog](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#watchdog-the-non-responsive-task-watchdog-settings)
|
||||
* [Web login authentication](https://allegro.ai/docs/faq/faq/#web-auth)
|
||||
* [Non-responsive experiments watchdog](https://allegro.ai/docs/faq/faq/#watchdog)
|
||||
|
||||
## Restarting trains-server
|
||||
|
||||
@@ -152,18 +180,29 @@ To upgrade your existing **trains-server** deployment:
|
||||
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
|
||||
```
|
||||
|
||||
1. Configure the Trains-Agent Services (not supported on Windows installation).
|
||||
If `TRAINS_HOST_IP` is not provided, Trains-Agent Services will use the external
|
||||
public address of the **trains-server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
|
||||
the Trains-Agent Services will not be able to access any private repositories for running service tasks.
|
||||
|
||||
```bash
|
||||
export TRAINS_HOST_IP=server_host_ip_here
|
||||
export TRAINS_AGENT_GIT_USER=git_username_here
|
||||
export TRAINS_AGENT_GIT_PASS=git_password_here
|
||||
```
|
||||
|
||||
1. Spin up the docker containers, it will automatically pull the latest **trains-server** build
|
||||
```bash
|
||||
docker-compose -f docker-compose.yml pull
|
||||
docker-compose -f docker-compose.yml up
|
||||
```
|
||||
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#common-docker-upgrade-errors).**
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/docs/faq/faq/#common-docker-upgrade-errors).**
|
||||
|
||||
|
||||
## Community & Support
|
||||
|
||||
If you have any questions, look to the Trains server [FAQ](https://github.com/allegroai/trains-server/blob/master/docs/faq.md), or
|
||||
If you have any questions, look to the Trains [FAQ](https://allegro.ai/docs/faq/faq/), or
|
||||
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/trains) with '**trains**' tag.
|
||||
|
||||
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/trains-server/issues).
|
||||
|
||||
@@ -15,6 +15,8 @@ services:
|
||||
volumes:
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
- /opt/trains/data/fileserver:/mnt/fileserver
|
||||
- /opt/trains/config:/opt/trains/config
|
||||
|
||||
depends_on:
|
||||
- redis
|
||||
- mongo
|
||||
@@ -38,15 +40,11 @@ services:
|
||||
cluster.name: trains
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
discovery.zen.minimum_master_nodes: "1"
|
||||
discovery.type: "single-node"
|
||||
http.compression_level: "7"
|
||||
node.ingest: "true"
|
||||
node.name: trains
|
||||
reindex.remote.whitelist: '*.*'
|
||||
script.inline: "true"
|
||||
script.painless.regex.enabled: "true"
|
||||
script.update: "true"
|
||||
thread_pool.bulk.queue_size: "2000"
|
||||
thread_pool.search.queue_size: "10000"
|
||||
xpack.monitoring.enabled: "false"
|
||||
xpack.security.enabled: "false"
|
||||
ulimits:
|
||||
@@ -56,10 +54,10 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/trains/data/elastic:/usr/share/elasticsearch/data
|
||||
- /opt/trains/data/elastic_7:/usr/share/elasticsearch/data
|
||||
ports:
|
||||
- "9200:9200"
|
||||
mongo:
|
||||
|
||||
@@ -22,6 +22,9 @@ services:
|
||||
TRAINS_MONGODB_SERVICE_PORT: 27017
|
||||
TRAINS_REDIS_SERVICE_HOST: redis
|
||||
TRAINS_REDIS_SERVICE_PORT: 6379
|
||||
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-win10}
|
||||
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
|
||||
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
|
||||
ports:
|
||||
- "8008:8008"
|
||||
networks:
|
||||
@@ -37,15 +40,11 @@ services:
|
||||
cluster.name: trains
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
discovery.zen.minimum_master_nodes: "1"
|
||||
discovery.type: "single-node"
|
||||
http.compression_level: "7"
|
||||
node.ingest: "true"
|
||||
node.name: trains
|
||||
reindex.remote.whitelist: '*.*'
|
||||
script.inline: "true"
|
||||
script.painless.regex.enabled: "true"
|
||||
script.update: "true"
|
||||
thread_pool.bulk.queue_size: "2000"
|
||||
thread_pool.search.queue_size: "10000"
|
||||
xpack.monitoring.enabled: "false"
|
||||
xpack.security.enabled: "false"
|
||||
ulimits:
|
||||
@@ -55,10 +54,10 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- c:/opt/trains/data/elastic:/usr/share/elasticsearch/data
|
||||
- c:/opt/trains/data/elastic_7:/usr/share/elasticsearch/data
|
||||
ports:
|
||||
- "9200:9200"
|
||||
|
||||
@@ -73,6 +72,8 @@ services:
|
||||
volumes:
|
||||
- c:/opt/trains/logs:/var/log/trains
|
||||
- c:/opt/trains/data/fileserver:/mnt/fileserver
|
||||
- c:/opt/trains/config:/opt/trains/config
|
||||
|
||||
ports:
|
||||
- "8081:8081"
|
||||
|
||||
@@ -84,7 +85,8 @@ services:
|
||||
restart: unless-stopped
|
||||
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
|
||||
volumes:
|
||||
- mongodata:/data
|
||||
- c:/opt/trains/data/mongo/db:/data/db
|
||||
- c:/opt/trains/data/mongo/configdb:/data/configdb
|
||||
ports:
|
||||
- "27017:27017"
|
||||
|
||||
@@ -115,6 +117,3 @@ services:
|
||||
networks:
|
||||
backend:
|
||||
driver: bridge
|
||||
|
||||
volumes:
|
||||
mongodata:
|
||||
|
||||
@@ -10,6 +10,7 @@ services:
|
||||
volumes:
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
- /opt/trains/config:/opt/trains/config
|
||||
- /opt/trains/data/fileserver:/mnt/fileserver
|
||||
depends_on:
|
||||
- redis
|
||||
- mongo
|
||||
@@ -22,8 +23,10 @@ services:
|
||||
TRAINS_MONGODB_SERVICE_PORT: 27017
|
||||
TRAINS_REDIS_SERVICE_HOST: redis
|
||||
TRAINS_REDIS_SERVICE_PORT: 6379
|
||||
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
|
||||
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
|
||||
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-linux}
|
||||
TRAINS__apiserver__pre_populate__enabled: "true"
|
||||
TRAINS__apiserver__pre_populate__zip_files: "/opt/trains/db-pre-populate"
|
||||
TRAINS__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
|
||||
ports:
|
||||
- "8008:8008"
|
||||
networks:
|
||||
@@ -39,15 +42,11 @@ services:
|
||||
cluster.name: trains
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
discovery.zen.minimum_master_nodes: "1"
|
||||
discovery.type: "single-node"
|
||||
http.compression_level: "7"
|
||||
node.ingest: "true"
|
||||
node.name: trains
|
||||
reindex.remote.whitelist: '*.*'
|
||||
script.inline: "true"
|
||||
script.painless.regex.enabled: "true"
|
||||
script.update: "true"
|
||||
thread_pool.bulk.queue_size: "2000"
|
||||
thread_pool.search.queue_size: "10000"
|
||||
xpack.monitoring.enabled: "false"
|
||||
xpack.security.enabled: "false"
|
||||
ulimits:
|
||||
@@ -57,10 +56,10 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/trains/data/elastic:/usr/share/elasticsearch/data
|
||||
- /opt/trains/data/elastic_7:/usr/share/elasticsearch/data
|
||||
ports:
|
||||
- "9200:9200"
|
||||
|
||||
@@ -75,6 +74,7 @@ services:
|
||||
volumes:
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
- /opt/trains/data/fileserver:/mnt/fileserver
|
||||
- /opt/trains/config:/opt/trains/config
|
||||
ports:
|
||||
- "8081:8081"
|
||||
|
||||
@@ -108,13 +108,43 @@ services:
|
||||
container_name: trains-webserver
|
||||
image: allegroai/trains:latest
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
depends_on:
|
||||
- apiserver
|
||||
ports:
|
||||
- "8080:80"
|
||||
|
||||
agent-services:
|
||||
networks:
|
||||
- backend
|
||||
container_name: trains-agent-services
|
||||
image: allegroai/trains-agent-services:latest
|
||||
restart: unless-stopped
|
||||
privileged: true
|
||||
environment:
|
||||
TRAINS_HOST_IP: ${TRAINS_HOST_IP}
|
||||
TRAINS_WEB_HOST: ${TRAINS_WEB_HOST:-}
|
||||
TRAINS_API_HOST: http://apiserver:8008
|
||||
TRAINS_FILES_HOST: ${TRAINS_FILES_HOST:-}
|
||||
TRAINS_API_ACCESS_KEY: ${TRAINS_API_ACCESS_KEY:-}
|
||||
TRAINS_API_SECRET_KEY: ${TRAINS_API_SECRET_KEY:-}
|
||||
TRAINS_AGENT_GIT_USER: ${TRAINS_AGENT_GIT_USER}
|
||||
TRAINS_AGENT_GIT_PASS: ${TRAINS_AGENT_GIT_PASS}
|
||||
TRAINS_AGENT_UPDATE_VERSION: ${TRAINS_AGENT_UPDATE_VERSION:->=0.15.0}
|
||||
TRAINS_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
|
||||
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
|
||||
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
|
||||
AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION:-}
|
||||
AZURE_STORAGE_ACCOUNT: ${AZURE_STORAGE_ACCOUNT:-}
|
||||
AZURE_STORAGE_KEY: ${AZURE_STORAGE_KEY:-}
|
||||
GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
|
||||
TRAINS_WORKER_ID: "trains-services"
|
||||
TRAINS_AGENT_DOCKER_HOST_MOUNT: "/opt/trains/agent:/root/.trains"
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- /opt/trains/agent:/root/.trains
|
||||
depends_on:
|
||||
- apiserver
|
||||
|
||||
networks:
|
||||
backend:
|
||||
driver: bridge
|
||||
|
||||
@@ -50,26 +50,64 @@ To upgrade the AMI:
|
||||
|
||||
The following sections contain lists of AMI Image IDs, per region, for each released **trains-server** version.
|
||||
|
||||
### Latest version AMI - v0.14.2 (auto update)<a name="autoupdate"></a>
|
||||
### Latest version AMI - v0.15.1 (auto update)<a name="autoupdate"></a>
|
||||
|
||||
For easier upgrades, the following AMIs automatically update to the latest release every reboot:
|
||||
|
||||
* **eu-north-1** : ami-095cc888970c06e09
|
||||
* **ap-south-1** : ami-07019e7b3febea37e
|
||||
* **eu-west-3** : ami-0433d76badf430c16
|
||||
* **eu-west-2** : ami-05794c2b23ff79990
|
||||
* **eu-west-1** : ami-03e3bcabd1863d666
|
||||
* **ap-northeast-2** : ami-00f14188b66a5803e
|
||||
* **ap-northeast-1** : ami-005c93e30c99dab0c
|
||||
* **sa-east-1** : ami-0d819231779e7d264
|
||||
* **ca-central-1** : ami-0eff2fd400939d960
|
||||
* **ap-southeast-1** : ami-049b21bfa0d35c21c
|
||||
* **ap-southeast-2** : ami-0318b96a72d5da068
|
||||
* **eu-central-1** : ami-0cdb9d794340b9704
|
||||
* **us-east-2** : ami-0d846a080fc5a9345
|
||||
* **us-west-1** : ami-0ef970342625159bf
|
||||
* **us-west-2** : ami-04f3d13b75c642506
|
||||
* **us-east-1** : ami-01bef4da91280a322
|
||||
* **eu-north-1** : ami-0f30c84b905d354b9
|
||||
* **ap-south-1** : ami-050e7acec52c8c74e
|
||||
* **eu-west-3** : ami-03911c5b5bc77ef75
|
||||
* **eu-west-2** : ami-0a5ed8aa2573ccc70
|
||||
* **eu-west-1** : ami-0a53c65e922ec0611
|
||||
* **ap-northeast-2** : ami-08cd017a37b8e8aab
|
||||
* **ap-northeast-1** : ami-056b3ca1ad5af9322
|
||||
* **sa-east-1** : ami-01ddc9325bafb400c
|
||||
* **ca-central-1** : ami-0fc3cbbd982b18b45
|
||||
* **ap-southeast-1** : ami-04c7a358df7002ef5
|
||||
* **ap-southeast-2** : ami-0eeaf54231b4ae22a
|
||||
* **eu-central-1** : ami-00b8e44041f8175fd
|
||||
* **us-east-2** : ami-0ac7deebb3f738f6d
|
||||
* **us-west-1** : ami-06bc07deb8b8c44d6
|
||||
* **us-west-2** : ami-01ba85ffe79a422f1
|
||||
* **us-east-1** : ami-04cf5a66cb4928ac3
|
||||
|
||||
### v0.15.1 (static update)
|
||||
|
||||
* **eu-north-1** : ami-0cd314e267426d1b7
|
||||
* **ap-south-1** : ami-086182cbe29151f96
|
||||
* **eu-west-3** : ami-0062366012182815b
|
||||
* **eu-west-2** : ami-022b8f2e32a9d18d0
|
||||
* **eu-west-1** : ami-0d8cf60446e09aa3d
|
||||
* **ap-northeast-2** : ami-0d4c168a815b56889
|
||||
* **ap-northeast-1** : ami-0daf7887db1053ae4
|
||||
* **sa-east-1** : ami-020a759a3ba4ff22b
|
||||
* **ca-central-1** : ami-0c10b5e04b707f3e3
|
||||
* **ap-southeast-1** : ami-0f61bb3529a165fcd
|
||||
* **ap-southeast-2** : ami-032dcdc82749c66c5
|
||||
* **eu-central-1** : ami-08f364f32d2eb3bae
|
||||
* **us-east-2** : ami-0b7efc3591803eba4
|
||||
* **us-west-1** : ami-08b2df27b0ada6faf
|
||||
* **us-west-2** : ami-0693029c4bad28816
|
||||
* **us-east-1** : ami-0200954fa9c2819ff
|
||||
|
||||
### v0.15.0 (static update)
|
||||
|
||||
* **eu-north-1** : ami-0bef15c03eab64c0c
|
||||
* **ap-south-1** : ami-06ac6248e583e2cd2
|
||||
* **eu-west-3** : ami-0541d86ef47a5714e
|
||||
* **eu-west-2** : ami-01381ef4c4ed22482
|
||||
* **eu-west-1** : ami-064626a0dd38b21f1
|
||||
* **ap-northeast-2** : ami-0a2490a7a3a8aa675
|
||||
* **ap-northeast-1** : ami-063f1de819a2524b8
|
||||
* **sa-east-1** : ami-07980486741b94987
|
||||
* **ca-central-1** : ami-0ced3b8b21ded839e
|
||||
* **ap-southeast-1** : ami-0c493c5093fde8741
|
||||
* **ap-southeast-2** : ami-0320a727eccb8dc6c
|
||||
* **eu-central-1** : ami-0aa85cfc78674c526
|
||||
* **us-east-2** : ami-01791485051e1880c
|
||||
* **us-west-1** : ami-0d8eade4d5888ea73
|
||||
* **us-west-2** : ami-02ceaef72cdf60f7e
|
||||
* **us-east-1** : ami-0fc3f9d1d0eba1d62
|
||||
|
||||
### v0.14.2 (static update)
|
||||
|
||||
|
||||
@@ -3,13 +3,16 @@
|
||||
To easily deploy Trains Server on GCP, use one of our pre-built GCP Custom Images.
|
||||
We provide Custom Images for each released version of Trains Server, see [Released versions](#released-versions) below.
|
||||
|
||||
Once your GCP instance is up and running using our Custom Image, [configure the Trains client](https://github.com/allegroai/trains/blob/master/README.md#configuration) to use your **trains-server**.
|
||||
Once your GCP instance is up and running using our Custom Image, [configure the Trains client](https://github.com/allegroai/trains/blob/master/README.md#configuration) to use your **trains-server**.
|
||||
|
||||
#### Default Trains Server Service ports
|
||||
The service port numbers on our Trains Server GCP Custom Image are:
|
||||
|
||||
- Web application: `8080`
|
||||
- API Server: `8008`
|
||||
- File Server: `8081`
|
||||
|
||||
#### Default Trains Server Storage paths
|
||||
The persistent storage configuration:
|
||||
|
||||
- MongoDB: `/opt/trains/data/mongo/`
|
||||
@@ -49,10 +52,25 @@ The minimum recommended requirements for Trains Server are:
|
||||
|
||||
To upgrade **trains-server** on an existing GCP instance based on one of these Custom Images, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
|
||||
|
||||
## Network and Security
|
||||
|
||||
Please make sure your instance is properly secured.
|
||||
|
||||
If not specifically set, a GCP instance will use default firewall rules that allow public access to various ports.
|
||||
If your instance is open for public access, we recommend you follow best practices for access management, including:
|
||||
- Allow access only to the specific ports used by Trains Server (see [Default Trains Server Service ports](#default-trains-server-service-ports)). Remember to allow access to port `443` if `https` access is configured for your instance.
|
||||
- Configure Trains Server to use fixed user names and passwords (see [Can I add web login authentication to trains-server?](./faq.md#web-auth))
|
||||
|
||||
## Released versions
|
||||
|
||||
The following sections contain lists of Custom Image URLs (exported in different formats) for each released **trains-server** version.
|
||||
|
||||
### Latest version image (v0.14.1)
|
||||
### Latest version image
|
||||
|
||||
- https://storage.googleapis.com/allegro-files/trains-server/trains-server.tar.gz
|
||||
|
||||
### All released images
|
||||
|
||||
- v0.15.1 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-15-1.tar.gz
|
||||
- v0.15.0 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-15-0.tar.gz
|
||||
- v0.14.1 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-14-1.tar.gz
|
||||
@@ -1,6 +1,6 @@
|
||||
# Launching the **trains-server** Docker in Linux or macOS
|
||||
|
||||
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
|
||||
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
|
||||
|
||||
For Linux users:
|
||||
|
||||
@@ -16,20 +16,20 @@ To launch **trains-server** on Linux or macOS:
|
||||
|
||||
1. Verify the Docker CE installation. Execute the command:
|
||||
|
||||
sudo docker run hello-world
|
||||
|
||||
docker run hello-world
|
||||
|
||||
The expected is output is:
|
||||
|
||||
Hello from Docker!
|
||||
This message shows that your installation appears to be working correctly.
|
||||
To generate this message, Docker took the following steps:
|
||||
|
||||
|
||||
1. The Docker client contacted the Docker daemon.
|
||||
2. The Docker daemon pulled the "hello-world" image from the Docker Hub. (amd64)
|
||||
3. The Docker daemon created a new container from that image which runs the executable that produces the output you are currently reading.
|
||||
4. The Docker daemon streamed that output to the Docker client, which sent it to your terminal.
|
||||
|
||||
1. For Linux only, install `docker-compose`. Execute the following commands (for more information, see [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
|
||||
1. For Linux only, install `docker-compose`. Execute the following commands (for more information, see [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
|
||||
|
||||
sudo curl -L "https://github.com/docker/compose/releases/download/1.24.1/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
|
||||
sudo chmod +x /usr/local/bin/docker-compose
|
||||
@@ -42,12 +42,12 @@ To launch **trains-server** on Linux or macOS:
|
||||
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
|
||||
sudo sysctl -w vm.max_map_count=262144
|
||||
sudo service docker restart
|
||||
|
||||
|
||||
macOS:
|
||||
|
||||
|
||||
screen ~/Library/Containers/com.docker.docker/Data/vms/0/tty
|
||||
sysctl -w vm.max_map_count=262144
|
||||
|
||||
|
||||
|
||||
1. Remove any previous installation of **trains-server**.
|
||||
|
||||
@@ -64,28 +64,28 @@ To launch **trains-server** on Linux or macOS:
|
||||
sudo mkdir -p /opt/trains/logs
|
||||
sudo mkdir -p /opt/trains/config
|
||||
sudo mkdir -p /opt/trains/data/fileserver
|
||||
|
||||
|
||||
1. For macOS only, open the Docker app, select **Preferences**, and then on the **File Sharing** tab, add `/opt/trains`.
|
||||
|
||||
|
||||
1. Grant access to the Dockers.
|
||||
|
||||
Linux:
|
||||
|
||||
sudo chown -R 1000:1000 /opt/trains
|
||||
|
||||
|
||||
macOS:
|
||||
|
||||
|
||||
sudo chown -R $(whoami):staff /opt/trains
|
||||
|
||||
1. Download the **trains-server** docker-compose YAML file.
|
||||
|
||||
cd /opt/trains
|
||||
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
|
||||
|
||||
|
||||
1. Run `docker-compose` with the downloaded configuration file.
|
||||
|
||||
sudo docker-compose -f docker-compose.yml up
|
||||
|
||||
docker-compose -f docker-compose.yml up
|
||||
|
||||
Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
|
||||
|
||||
* Web server on port `8080`
|
||||
@@ -94,4 +94,4 @@ To launch **trains-server** on Linux or macOS:
|
||||
|
||||
## Next Step
|
||||
|
||||
Configure the [Trains client for trains-server](https://github.com/allegroai/trains/blob/master/README.md#configuration).
|
||||
Configure the [Trains client for trains-server](https://github.com/allegroai/trains/blob/master/README.md#configuration).
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
@@ -16,6 +17,9 @@ DEFAULT_EXTRA_CONFIG_PATH = "/opt/trains/config"
|
||||
EXTRA_CONFIG_PATH_ENV_KEY = "TRAINS_CONFIG_DIR"
|
||||
EXTRA_CONFIG_PATH_SEP = ":"
|
||||
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_SEP = "__"
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX = f"TRAINS{EXTRA_CONFIG_VALUES_ENV_KEY_SEP}"
|
||||
|
||||
|
||||
class BasicConfig:
|
||||
NotSet = object()
|
||||
@@ -46,7 +50,23 @@ class BasicConfig:
|
||||
path = ".".join((self.prefix, Path(name).stem))
|
||||
return logging.getLogger(path)
|
||||
|
||||
def _read_env_paths(self, key):
|
||||
@staticmethod
|
||||
def _read_extra_env_config_values():
|
||||
""" Loads extra configuration from environment-injected values """
|
||||
result = ConfigTree()
|
||||
prefix = EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX
|
||||
|
||||
keys = sorted(k for k in os.environ if k.startswith(prefix))
|
||||
for key in keys:
|
||||
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".").lower()
|
||||
result = ConfigTree.merge_configs(
|
||||
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _read_env_paths(key):
|
||||
value = getenv(EXTRA_CONFIG_PATH_ENV_KEY, DEFAULT_EXTRA_CONFIG_PATH)
|
||||
if value is None:
|
||||
return
|
||||
@@ -64,12 +84,17 @@ class BasicConfig:
|
||||
|
||||
def _load(self, verbose=True):
|
||||
extra_config_paths = self._read_env_paths(EXTRA_CONFIG_PATH_ENV_KEY) or []
|
||||
extra_config_values = self._read_extra_env_config_values()
|
||||
configs = [
|
||||
self._read_recursive(path, verbose=verbose)
|
||||
for path in [self.folder] + extra_config_paths
|
||||
]
|
||||
|
||||
self._config = reduce(
|
||||
lambda config, path: ConfigTree.merge_configs(
|
||||
config, self._read_recursive(path, verbose=verbose), copy_trees=True
|
||||
lambda last, config: ConfigTree.merge_configs(
|
||||
last, config, copy_trees=True
|
||||
),
|
||||
[self.folder] + extra_config_paths,
|
||||
configs + [extra_config_values],
|
||||
ConfigTree(),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
download {
|
||||
# Add response headers requesting no caching for served files
|
||||
disable_browser_caching: false
|
||||
|
||||
# Cache timeout to be set for downloaded files
|
||||
cache_timeout_sec: 300
|
||||
}
|
||||
|
||||
cors {
|
||||
|
||||
@@ -10,12 +10,14 @@ from flask_cors import CORS
|
||||
|
||||
from config import config
|
||||
|
||||
DEFAULT_UPLOAD_FOLDER = "/mnt/fileserver"
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app, **config.get("fileserver.cors"))
|
||||
Compress(app)
|
||||
|
||||
if os.environ.get("TRAINS_UPLOAD_FOLDER"):
|
||||
app.config["UPLOAD_FOLDER"] = os.environ.get("TRAINS_UPLOAD_FOLDER")
|
||||
app.config["UPLOAD_FOLDER"] = os.environ.get("TRAINS_UPLOAD_FOLDER") or DEFAULT_UPLOAD_FOLDER
|
||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = config.get("fileserver.download.cache_timeout_sec", 5 * 60)
|
||||
|
||||
|
||||
@app.route("/", methods=["POST"])
|
||||
@@ -57,12 +59,13 @@ def main():
|
||||
parser.add_argument(
|
||||
"--upload-folder",
|
||||
"-u",
|
||||
default="/mnt/fileserver",
|
||||
default=DEFAULT_UPLOAD_FOLDER,
|
||||
help="Upload folder (default %(default)s)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
app.config["UPLOAD_FOLDER"] = args.upload_folder
|
||||
if app.config.get("UPLOAD_FOLDER") is None:
|
||||
app.config["UPLOAD_FOLDER"] = args.upload_folder
|
||||
|
||||
app.run(debug=args.debug, host=args.ip, port=args.port, threaded=True)
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "2.7.0"
|
||||
__version__ = "2.9.0"
|
||||
|
||||
@@ -47,6 +47,7 @@ _error_codes = {
|
||||
128: ('invalid_task_output', 'invalid task output'),
|
||||
129: ('task_publish_in_progress', 'Task publish in progress'),
|
||||
130: ('task_not_found', 'task not found'),
|
||||
131: ('events_not_added', 'events not added'),
|
||||
|
||||
# Models
|
||||
200: ('model_error', 'general task error'),
|
||||
|
||||
@@ -13,6 +13,7 @@ from luqum.parser import parser, ParseError
|
||||
from validators import email as email_validator, domain as domain_validator
|
||||
|
||||
from apierrors import errors
|
||||
from utilities.json import loads, dumps
|
||||
|
||||
|
||||
def make_default(field_cls, default_value):
|
||||
@@ -206,10 +207,10 @@ class DomainField(fields.StringField):
|
||||
raise errors.bad_request.InvalidDomainName()
|
||||
|
||||
|
||||
class StringEnum(Enum):
|
||||
def __str__(self):
|
||||
return self.value
|
||||
class JsonSerializableMixin:
|
||||
def to_json(self: ModelBase):
|
||||
return dumps(self.to_struct())
|
||||
|
||||
# noinspection PyMethodParameters
|
||||
def _generate_next_value_(name, start, count, last_values):
|
||||
return name
|
||||
@classmethod
|
||||
def from_json(cls: Type[ModelBase], s):
|
||||
return cls(**loads(s))
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from jsonmodels import models, fields
|
||||
from jsonmodels.validators import Length
|
||||
from mongoengine.base import BaseDocument
|
||||
|
||||
from apimodels import DictField
|
||||
from apimodels import DictField, ListField
|
||||
|
||||
|
||||
class MongoengineFieldsDict(DictField):
|
||||
@@ -12,14 +13,14 @@ class MongoengineFieldsDict(DictField):
|
||||
"""
|
||||
|
||||
mongoengine_update_operators = (
|
||||
'inc',
|
||||
'dec',
|
||||
'push',
|
||||
'push_all',
|
||||
'pop',
|
||||
'pull',
|
||||
'pull_all',
|
||||
'add_to_set',
|
||||
"inc",
|
||||
"dec",
|
||||
"push",
|
||||
"push_all",
|
||||
"pop",
|
||||
"pull",
|
||||
"pull_all",
|
||||
"add_to_set",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -30,16 +31,16 @@ class MongoengineFieldsDict(DictField):
|
||||
|
||||
@classmethod
|
||||
def _normalize_mongo_field_path(cls, path, value):
|
||||
parts = path.split('__')
|
||||
parts = path.split("__")
|
||||
if len(parts) > 1:
|
||||
if parts[0] == 'set':
|
||||
if parts[0] == "set":
|
||||
parts = parts[1:]
|
||||
elif parts[0] == 'unset':
|
||||
elif parts[0] == "unset":
|
||||
parts = parts[1:]
|
||||
value = None
|
||||
elif parts[0] in cls.mongoengine_update_operators:
|
||||
return None, None
|
||||
return '.'.join(parts), cls._normalize_mongo_value(value)
|
||||
return ".".join(parts), cls._normalize_mongo_value(value)
|
||||
|
||||
def parse_value(self, value):
|
||||
value = super(MongoengineFieldsDict, self).parse_value(value)
|
||||
@@ -62,3 +63,7 @@ class PagedRequest(models.Base):
|
||||
|
||||
class IdResponse(models.Base):
|
||||
id = fields.StringField(required=True)
|
||||
|
||||
|
||||
class MakePublicRequest(models.Base):
|
||||
ids = ListField(items_types=str, validators=[Length(minimum_value=1)])
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
from typing import Sequence
|
||||
from enum import auto
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length
|
||||
from jsonmodels.validators import Length, Min, Max
|
||||
|
||||
from apimodels import ListField, IntField, ActualEnumField
|
||||
from bll.event.event_metrics import EventType
|
||||
from bll.event.scalar_key import ScalarKeyEnum
|
||||
from config import config
|
||||
from utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class HistogramRequestBase(Base):
|
||||
samples: int = IntField(default=10000)
|
||||
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
|
||||
|
||||
@@ -21,7 +24,15 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
items_types=str,
|
||||
validators=[
|
||||
Length(
|
||||
minimum_value=1,
|
||||
maximum_value=config.get(
|
||||
"services.tasks.multi_task_histogram_limit", 10
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -40,6 +51,19 @@ class DebugImagesRequest(Base):
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class LogOrderEnum(StringEnum):
|
||||
asc = auto()
|
||||
desc = auto()
|
||||
|
||||
|
||||
class LogEventsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
batch_size: int = IntField(default=500)
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
from_timestamp: Optional[int] = IntField()
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum)
|
||||
|
||||
|
||||
class IterationEvents(Base):
|
||||
iter: int = IntField()
|
||||
events: Sequence[dict] = ListField(items_types=dict)
|
||||
|
||||
@@ -6,6 +6,10 @@ from apimodels.base import UpdateResponse
|
||||
from apimodels.tasks import PublishResponse as TaskPublishResponse
|
||||
|
||||
|
||||
class GetFrameworksRequest(models.Base):
|
||||
projects = fields.ListField(items_types=[str])
|
||||
|
||||
|
||||
class CreateModelRequest(models.Base):
|
||||
name = fields.StringField(required=True)
|
||||
uri = fields.StringField(required=True)
|
||||
|
||||
11
server/apimodels/organization.py
Normal file
11
server/apimodels/organization.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from jsonmodels import fields, models
|
||||
|
||||
|
||||
class Filter(models.Base):
|
||||
tags = fields.ListField([str])
|
||||
system_tags = fields.ListField([str])
|
||||
|
||||
|
||||
class TagsRequest(models.Base):
|
||||
include_system = fields.BoolField(default=False)
|
||||
filter = fields.EmbeddedField(Filter)
|
||||
@@ -1,5 +1,8 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apimodels import ListField
|
||||
from apimodels.organization import TagsRequest
|
||||
|
||||
|
||||
class ProjectReq(models.Base):
|
||||
project = fields.StringField()
|
||||
@@ -10,7 +13,5 @@ class GetHyperParamReq(ProjectReq):
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
|
||||
class GetHyperParamResp(models.Base):
|
||||
parameters = fields.ListField(str)
|
||||
remaining = fields.IntField()
|
||||
total = fields.IntField()
|
||||
class ProjectTagsRequest(TagsRequest):
|
||||
projects = ListField(str)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from jsonmodels import models
|
||||
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
|
||||
from jsonmodels.validators import Enum
|
||||
from jsonmodels.validators import Enum, Length
|
||||
|
||||
from apimodels import DictField, ListField
|
||||
from apimodels.base import UpdateResponse
|
||||
@@ -92,6 +94,10 @@ class PingRequest(TaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class GetTypesRequest(models.Base):
|
||||
projects = ListField(items_types=[str])
|
||||
|
||||
|
||||
class CloneRequest(TaskRequest):
|
||||
new_task_name = StringField()
|
||||
new_task_comment = StringField()
|
||||
@@ -99,7 +105,10 @@ class CloneRequest(TaskRequest):
|
||||
new_task_system_tags = ListField([str])
|
||||
new_task_parent = StringField()
|
||||
new_task_project = StringField()
|
||||
new_hyperparams = DictField()
|
||||
new_configuration = DictField()
|
||||
execution_overrides = DictField()
|
||||
validate_references = BoolField(default=False)
|
||||
|
||||
|
||||
class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||
@@ -109,3 +118,76 @@ class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||
class AddOrUpdateArtifactsResponse(models.Base):
|
||||
added = ListField([str])
|
||||
updated = ListField([str])
|
||||
|
||||
|
||||
class ResetRequest(UpdateRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
|
||||
|
||||
class MultiTaskRequest(models.Base):
|
||||
tasks = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class GetHyperParamsRequest(MultiTaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class HyperParamItem(models.Base):
|
||||
section = StringField(required=True, validators=Length(minimum_value=1))
|
||||
name = StringField(required=True, validators=Length(minimum_value=1))
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class ReplaceHyperparams(object):
|
||||
none = "none"
|
||||
section = "section"
|
||||
all = "all"
|
||||
|
||||
|
||||
class EditHyperParamsRequest(TaskRequest):
|
||||
hyperparams: Sequence[HyperParamItem] = ListField(
|
||||
[HyperParamItem], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_hyperparams = StringField(
|
||||
validators=Enum(*get_options(ReplaceHyperparams)),
|
||||
default=ReplaceHyperparams.none,
|
||||
)
|
||||
|
||||
|
||||
class HyperParamKey(models.Base):
|
||||
section = StringField(required=True, validators=Length(minimum_value=1))
|
||||
name = StringField(nullable=True)
|
||||
|
||||
|
||||
class DeleteHyperParamsRequest(TaskRequest):
|
||||
hyperparams: Sequence[HyperParamKey] = ListField(
|
||||
[HyperParamKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
|
||||
|
||||
class GetConfigurationsRequest(MultiTaskRequest):
|
||||
names = ListField([str])
|
||||
|
||||
|
||||
class GetConfigurationNamesRequest(MultiTaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class Configuration(models.Base):
|
||||
name = StringField(required=True, validators=Length(minimum_value=1))
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class EditConfigurationRequest(TaskRequest):
|
||||
configuration: Sequence[Configuration] = ListField(
|
||||
[Configuration], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_configuration = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteConfigurationRequest(TaskRequest):
|
||||
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
import six
|
||||
@@ -13,13 +12,14 @@ from jsonmodels.fields import (
|
||||
)
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apimodels import make_default, ListField, EnumField
|
||||
from apimodels import make_default, ListField, EnumField, JsonSerializableMixin
|
||||
|
||||
DEFAULT_TIMEOUT = 10 * 60
|
||||
|
||||
|
||||
class WorkerRequest(Base):
|
||||
worker = StringField(required=True)
|
||||
tags = ListField(str)
|
||||
|
||||
|
||||
class RegisterRequest(WorkerRequest):
|
||||
@@ -61,26 +61,21 @@ class IdNameEntry(Base):
|
||||
name = StringField()
|
||||
|
||||
|
||||
class WorkerEntry(Base):
|
||||
class WorkerEntry(Base, JsonSerializableMixin):
|
||||
key = StringField() # not required due to migration issues
|
||||
id = StringField(required=True)
|
||||
user = EmbeddedField(IdNameEntry)
|
||||
company = EmbeddedField(IdNameEntry)
|
||||
ip = StringField()
|
||||
task = EmbeddedField(IdNameEntry)
|
||||
project = EmbeddedField(IdNameEntry)
|
||||
queue = StringField() # queue from which current task was taken
|
||||
queues = ListField(str) # list of queues this worker listens to
|
||||
register_time = DateTimeField(required=True)
|
||||
register_timeout = IntField(required=True)
|
||||
last_activity_time = DateTimeField(required=True)
|
||||
last_report_time = DateTimeField()
|
||||
|
||||
def to_json(self):
|
||||
return json.dumps(self.to_struct())
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, s):
|
||||
return cls(**json.loads(s))
|
||||
tags = ListField(str)
|
||||
|
||||
|
||||
class CurrentTaskEntry(IdNameEntry):
|
||||
|
||||
@@ -3,27 +3,25 @@ from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from operator import attrgetter, itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import database
|
||||
from apierrors import errors
|
||||
from bll.redis_cache_manager import RedisCacheManager
|
||||
from apimodels import JsonSerializableMixin
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
from bll.redis_cache_manager import RedisCacheManager
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.fields import StringField, ListField, IntField
|
||||
|
||||
from database.model.task.metrics import MetricEventStats
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
from utilities.json import loads, dumps
|
||||
|
||||
|
||||
class VariantScrollState(Base):
|
||||
@@ -45,17 +43,10 @@ class MetricScrollState(Base):
|
||||
self.last_min_iter = self.last_max_iter = None
|
||||
|
||||
|
||||
class DebugImageEventsScrollState(Base):
|
||||
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
|
||||
|
||||
def to_json(self):
|
||||
return dumps(self.to_struct())
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, s):
|
||||
return cls(**loads(s))
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class DebugImagesResult(object):
|
||||
@@ -65,7 +56,12 @@ class DebugImagesResult(object):
|
||||
|
||||
class DebugImagesIterator:
|
||||
EVENT_TYPE = "training_debug_image"
|
||||
STATE_EXPIRATION_SECONDS = 3600
|
||||
|
||||
@property
|
||||
def state_expiration_sec(self):
|
||||
return config.get(
|
||||
f"services.events.events_retrieval.state_expiration_sec", 3600
|
||||
)
|
||||
|
||||
@property
|
||||
def _max_workers(self):
|
||||
@@ -76,7 +72,7 @@ class DebugImagesIterator:
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugImageEventsScrollState,
|
||||
redis=redis,
|
||||
expiration_interval=self.STATE_EXPIRATION_SECONDS,
|
||||
expiration_interval=self.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_task_events(
|
||||
@@ -92,27 +88,31 @@ class DebugImagesIterator:
|
||||
if not self.es.indices.exists(es_index):
|
||||
return DebugImagesResult()
|
||||
|
||||
unique_metrics = set(metrics)
|
||||
state = self.cache_manager.get_state(state_id) if state_id else None
|
||||
if not state:
|
||||
state = DebugImageEventsScrollState(
|
||||
id=database.utils.id(),
|
||||
metrics=self._init_metric_states(es_index, list(unique_metrics)),
|
||||
)
|
||||
else:
|
||||
state_metrics = set((m.task, m.name) for m in state.metrics)
|
||||
if state_metrics != unique_metrics:
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"while getting debug images events", scroll_id=state_id
|
||||
)
|
||||
def init_state(state_: DebugImageEventsScrollState):
|
||||
unique_metrics = set(metrics)
|
||||
state_.metrics = self._init_metric_states(es_index, list(unique_metrics))
|
||||
|
||||
def validate_state(state_: DebugImageEventsScrollState):
|
||||
"""
|
||||
Validate that the metrics stored in the state are the same
|
||||
as requested in the current call.
|
||||
Refresh the state if requested
|
||||
"""
|
||||
state_metrics = set((m.task, m.name) for m in state_.metrics)
|
||||
if state_metrics != set(metrics):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task metrics stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reinit_outdated_metric_states(company_id, es_index, state)
|
||||
for metric_state in state.metrics:
|
||||
self._reinit_outdated_metric_states(company_id, es_index, state_)
|
||||
for metric_state in state_.metrics:
|
||||
metric_state.reset()
|
||||
|
||||
res = DebugImagesResult(next_scroll_id=state.id)
|
||||
try:
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
) as state:
|
||||
res = DebugImagesResult(next_scroll_id=state.id)
|
||||
with ThreadPoolExecutor(self._max_workers) as pool:
|
||||
res.metric_events = list(
|
||||
pool.map(
|
||||
@@ -125,10 +125,8 @@ class DebugImagesIterator:
|
||||
state.metrics,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
self.cache_manager.set_state(state)
|
||||
|
||||
return res
|
||||
return res
|
||||
|
||||
def _reinit_outdated_metric_states(
|
||||
self, company_id, es_index, state: DebugImageEventsScrollState
|
||||
@@ -210,7 +208,11 @@ class DebugImagesIterator:
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}]
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"terms": {"metric": metrics}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"aggs": {
|
||||
@@ -253,7 +255,7 @@ class DebugImagesIterator:
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@@ -300,6 +302,7 @@ class DebugImagesIterator:
|
||||
must_conditions = [
|
||||
{"term": {"task": metric.task}},
|
||||
{"term": {"metric": metric.name}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
must_not_conditions = []
|
||||
|
||||
@@ -370,7 +373,7 @@ class DebugImagesIterator:
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iter_count,
|
||||
"order": {"_term": "desc" if navigate_earlier else "asc"},
|
||||
"order": {"_key": "desc" if navigate_earlier else "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
@@ -389,7 +392,7 @@ class DebugImagesIterator:
|
||||
},
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=metric.task)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return metric.task, metric.name, []
|
||||
|
||||
|
||||
@@ -3,9 +3,8 @@ from collections import defaultdict
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence
|
||||
from typing import Sequence, Set, Tuple, Optional
|
||||
|
||||
import attr
|
||||
import six
|
||||
from elasticsearch import helpers
|
||||
from mongoengine import Q
|
||||
@@ -16,12 +15,14 @@ import es_factory
|
||||
from apierrors import errors
|
||||
from bll.event.debug_images_iterator import DebugImagesIterator
|
||||
from bll.event.event_metrics import EventMetrics, EventType
|
||||
from bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
|
||||
from bll.task import TaskBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.task import Task, TaskStatus
|
||||
from redis_manager import redman
|
||||
from timing_context import TimingContext
|
||||
from tools import safe_get
|
||||
from utilities.dicts import flatten_nested_items
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
@@ -29,15 +30,9 @@ EVENT_TYPES = set(map(attrgetter("value"), EventType))
|
||||
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult(object):
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.ib(factory=list)
|
||||
|
||||
|
||||
class EventBLL(object):
|
||||
id_fields = ("task", "iter", "metric", "variant", "key")
|
||||
empty_scroll = "FFFF"
|
||||
|
||||
def __init__(self, events_es=None, redis=None):
|
||||
self.es = events_es or es_factory.connect("events")
|
||||
@@ -47,12 +42,28 @@ class EventBLL(object):
|
||||
)
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
|
||||
self.log_events_iterator = LogEventsIterator(es=self.es)
|
||||
|
||||
@property
|
||||
def metrics(self) -> EventMetrics:
|
||||
return self._metrics
|
||||
|
||||
def add_events(self, company_id, events, worker, allow_locked_tasks=False):
|
||||
@staticmethod
|
||||
def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
|
||||
"""Verify that task exists and can be updated"""
|
||||
if not task_ids:
|
||||
return set()
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
|
||||
query = Q(id__in=task_ids, company=company_id)
|
||||
if not allow_locked_tasks:
|
||||
query &= Q(status__nin=LOCKED_TASK_STATUSES)
|
||||
res = Task.objects(query).only("id")
|
||||
return {r.id for r in res}
|
||||
|
||||
def add_events(
|
||||
self, company_id, events, worker, allow_locked_tasks=False
|
||||
) -> Tuple[int, int, dict]:
|
||||
actions = []
|
||||
task_ids = set()
|
||||
task_iteration = defaultdict(lambda: 0)
|
||||
@@ -62,19 +73,34 @@ class EventBLL(object):
|
||||
task_last_events = nested_dict(
|
||||
3, dict
|
||||
) # task_id -> metric_hash -> event_type -> MetricEvent
|
||||
|
||||
errors_per_type = defaultdict(int)
|
||||
valid_tasks = self._get_valid_tasks(
|
||||
company_id,
|
||||
task_ids={
|
||||
event["task"] for event in events if event.get("task") is not None
|
||||
},
|
||||
allow_locked_tasks=allow_locked_tasks,
|
||||
)
|
||||
for event in events:
|
||||
# remove spaces from event type
|
||||
if "type" not in event:
|
||||
raise errors.BadRequest("Event must have a 'type' field", event=event)
|
||||
event_type = event.get("type")
|
||||
if event_type is None:
|
||||
errors_per_type["Event must have a 'type' field"] += 1
|
||||
continue
|
||||
|
||||
event_type = event["type"].replace(" ", "_")
|
||||
event_type = event_type.replace(" ", "_")
|
||||
if event_type not in EVENT_TYPES:
|
||||
raise errors.BadRequest(
|
||||
"Invalid event type {}".format(event_type),
|
||||
event=event,
|
||||
types=EVENT_TYPES,
|
||||
)
|
||||
errors_per_type[f"Invalid event type {event_type}"] += 1
|
||||
continue
|
||||
|
||||
task_id = event.get("task")
|
||||
if task_id is None:
|
||||
errors_per_type["Event must have a 'task' field"] += 1
|
||||
continue
|
||||
|
||||
if task_id not in valid_tasks:
|
||||
errors_per_type["Invalid task id"] += 1
|
||||
continue
|
||||
|
||||
event["type"] = event_type
|
||||
|
||||
@@ -110,7 +136,6 @@ class EventBLL(object):
|
||||
es_action = {
|
||||
"_op_type": "index", # overwrite if exists with same ID
|
||||
"_index": index_name,
|
||||
"_type": "event",
|
||||
"_source": event,
|
||||
}
|
||||
|
||||
@@ -120,89 +145,74 @@ class EventBLL(object):
|
||||
else:
|
||||
es_action["_id"] = dbutils.id()
|
||||
|
||||
task_id = event.get("task")
|
||||
if task_id is not None:
|
||||
es_action["_routing"] = task_id
|
||||
task_ids.add(task_id)
|
||||
if (
|
||||
iter is not None
|
||||
and event.get("metric") not in self._skip_iteration_for_metric
|
||||
):
|
||||
task_iteration[task_id] = max(iter, task_iteration[task_id])
|
||||
task_ids.add(task_id)
|
||||
if (
|
||||
iter is not None
|
||||
and event.get("metric") not in self._skip_iteration_for_metric
|
||||
):
|
||||
task_iteration[task_id] = max(iter, task_iteration[task_id])
|
||||
|
||||
self._update_last_metric_events_for_task(
|
||||
last_events=task_last_events[task_id], event=event,
|
||||
self._update_last_metric_events_for_task(
|
||||
last_events=task_last_events[task_id], event=event,
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_id], event=event
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_id], event=event
|
||||
)
|
||||
else:
|
||||
es_action["_routing"] = task_id
|
||||
|
||||
actions.append(es_action)
|
||||
|
||||
if task_ids:
|
||||
# verify task_ids
|
||||
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
|
||||
extra_msg = None
|
||||
query = Q(id__in=task_ids, company=company_id)
|
||||
if not allow_locked_tasks:
|
||||
query &= Q(status__nin=LOCKED_TASK_STATUSES)
|
||||
extra_msg = "or task published"
|
||||
res = Task.objects(query).only("id")
|
||||
if len(res) < len(task_ids):
|
||||
invalid_task_ids = tuple(set(task_ids) - set(r.id for r in res))
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
extra_msg, company=company_id, ids=invalid_task_ids
|
||||
added = 0
|
||||
if actions:
|
||||
chunk_size = 500
|
||||
with translate_errors_context(), TimingContext("es", "events_add_batch"):
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
# thread_count=8,
|
||||
refresh=True,
|
||||
)
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += chunk_size
|
||||
else:
|
||||
errors_per_type["Error when indexing events batch"] += 1
|
||||
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in task_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update
|
||||
# all of them and not only those who's events were successful
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
now=now,
|
||||
iter_max=task_iteration.get(task_id),
|
||||
last_scalar_events=task_last_scalar_events.get(task_id),
|
||||
last_events=task_last_events.get(task_id),
|
||||
)
|
||||
|
||||
errors_in_bulk = []
|
||||
added = 0
|
||||
chunk_size = 500
|
||||
with translate_errors_context(), TimingContext("es", "events_add_batch"):
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
# thread_count=8,
|
||||
refresh=True,
|
||||
)
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += chunk_size
|
||||
else:
|
||||
errors_in_bulk.append(info)
|
||||
if not updated:
|
||||
remaining_tasks.add(task_id)
|
||||
continue
|
||||
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in task_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update all of them and not only those
|
||||
# who's events were successful
|
||||
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
now=now,
|
||||
iter_max=task_iteration.get(task_id),
|
||||
last_scalar_events=task_last_scalar_events.get(task_id),
|
||||
last_events=task_last_events.get(task_id),
|
||||
)
|
||||
|
||||
if not updated:
|
||||
remaining_tasks.add(task_id)
|
||||
continue
|
||||
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(
|
||||
remaining_tasks, company_id, last_update=now
|
||||
)
|
||||
|
||||
# Compensate for always adding chunk_size on success (last chunk is probably smaller)
|
||||
added = min(added, len(actions))
|
||||
|
||||
return added, errors_in_bulk
|
||||
if not added:
|
||||
raise errors.bad_request.EventsNotAdded(**errors_per_type)
|
||||
|
||||
errors_count = sum(errors_per_type.values())
|
||||
return added, errors_count, errors_per_type
|
||||
|
||||
def _update_last_scalar_events_for_task(self, last_events, event):
|
||||
"""
|
||||
@@ -220,9 +230,25 @@ class EventBLL(object):
|
||||
metric_hash = dbutils.hash_field_name(metric)
|
||||
variant_hash = dbutils.hash_field_name(variant)
|
||||
|
||||
timestamp = last_events[metric_hash][variant_hash].get("timestamp", None)
|
||||
if timestamp is None or timestamp < event["timestamp"]:
|
||||
last_events[metric_hash][variant_hash] = event
|
||||
last_event = last_events[metric_hash][variant_hash]
|
||||
event_iter = event.get("iter", 0)
|
||||
event_timestamp = event.get("timestamp", 0)
|
||||
value = event.get("value")
|
||||
if value is not None and (
|
||||
(event_iter, event_timestamp)
|
||||
>= (
|
||||
last_event.get("iter", event_iter),
|
||||
last_event.get("timestamp", event_timestamp),
|
||||
)
|
||||
):
|
||||
event_data = {
|
||||
k: event[k]
|
||||
for k in ("value", "metric", "variant", "iter", "timestamp")
|
||||
if k in event
|
||||
}
|
||||
event_data["min_value"] = min(value, last_event.get("min_value", value))
|
||||
event_data["max_value"] = max(value, last_event.get("max_value", value))
|
||||
last_events[metric_hash][variant_hash] = event_data
|
||||
|
||||
def _update_last_metric_events_for_task(self, last_events, event):
|
||||
"""
|
||||
@@ -265,7 +291,13 @@ class EventBLL(object):
|
||||
flatten_nested_items(
|
||||
last_scalar_events,
|
||||
nesting=2,
|
||||
include_leaves=["value", "metric", "variant"],
|
||||
include_leaves=[
|
||||
"value",
|
||||
"min_value",
|
||||
"max_value",
|
||||
"metric",
|
||||
"variant",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -290,6 +322,9 @@ class EventBLL(object):
|
||||
batch_size=10000,
|
||||
scroll_id=None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "task_log_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
@@ -310,14 +345,9 @@ class EventBLL(object):
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, scroll="1h", routing=task_id
|
||||
)
|
||||
|
||||
events = [hit["_source"] for hit in es_res["hits"]["hits"]]
|
||||
next_scroll_id = es_res["_scroll_id"]
|
||||
total_events = es_res["hits"]["total"]
|
||||
es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return events, next_scroll_id, total_events
|
||||
|
||||
def get_last_iterations_per_event_metric_variant(
|
||||
@@ -345,7 +375,7 @@ class EventBLL(object):
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": num_last_iterations,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -361,7 +391,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "task_last_iter_metric_variant"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@@ -381,6 +411,9 @@ class EventBLL(object):
|
||||
size: int = 500,
|
||||
scroll_id: str = None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
@@ -390,13 +423,11 @@ class EventBLL(object):
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
query = {"bool": defaultdict(list)}
|
||||
|
||||
must = []
|
||||
if last_iterations_per_plot is None:
|
||||
must = query["bool"]["must"]
|
||||
must.append({"terms": {"task": tasks}})
|
||||
else:
|
||||
should = query["bool"]["should"]
|
||||
should = []
|
||||
for i, task_id in enumerate(tasks):
|
||||
last_iters = self.get_last_iterations_per_event_metric_variant(
|
||||
es_index, task_id, last_iterations_per_plot, event_type
|
||||
@@ -419,32 +450,41 @@ class EventBLL(object):
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
|
||||
|
||||
routing = ",".join(tasks)
|
||||
es_req = {
|
||||
"sort": sort,
|
||||
"size": min(size, 10000),
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_plots"):
|
||||
es_res = self.es.search(
|
||||
index=es_index,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
routing=routing,
|
||||
scroll="1h",
|
||||
index=es_index, body=es_req, ignore=404, scroll="1h",
|
||||
)
|
||||
|
||||
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
|
||||
# scroll id may be missing when queering a totally empty DB
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
total_events = es_res["hits"]["total"]
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
|
||||
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
|
||||
"""
|
||||
Return events and next scroll id from the scrolled query
|
||||
Release the scroll once it is exhausted
|
||||
"""
|
||||
total_events = safe_get(es_res, "hits/total/value", default=0)
|
||||
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
if next_scroll_id and not events:
|
||||
self.es.clear_scroll(scroll_id=next_scroll_id)
|
||||
next_scroll_id = self.empty_scroll
|
||||
|
||||
return events, total_events, next_scroll_id
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id,
|
||||
@@ -457,6 +497,8 @@ class EventBLL(object):
|
||||
size=500,
|
||||
scroll_id=None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
@@ -470,20 +512,16 @@ class EventBLL(object):
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
query = {"bool": defaultdict(list)}
|
||||
|
||||
if metric or variant:
|
||||
must = query["bool"]["must"]
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
must = []
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
|
||||
if last_iter_count is None:
|
||||
must = query["bool"]["must"]
|
||||
must.append({"terms": {"task": task_ids}})
|
||||
else:
|
||||
should = query["bool"]["should"]
|
||||
should = []
|
||||
for i, task_id in enumerate(task_ids):
|
||||
last_iters = self.get_last_iters(
|
||||
es_index, task_id, event_type, last_iter_count
|
||||
@@ -502,27 +540,23 @@ class EventBLL(object):
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
|
||||
|
||||
routing = ",".join(task_ids)
|
||||
es_req = {
|
||||
"sort": sort,
|
||||
"size": min(size, 10000),
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.search(
|
||||
index=es_index,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
routing=routing,
|
||||
scroll="1h",
|
||||
index=es_index, body=es_req, ignore=404, scroll="1h",
|
||||
)
|
||||
|
||||
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
|
||||
next_scroll_id = es_res["_scroll_id"]
|
||||
total_events = es_res["hits"]["total"]
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
@@ -558,7 +592,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
metrics = {}
|
||||
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
|
||||
@@ -590,14 +624,14 @@ class EventBLL(object):
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
"order": {"_term": "asc"},
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
"order": {"_term": "asc"},
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_value": {
|
||||
@@ -627,7 +661,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
metrics = []
|
||||
max_timestamp = 0
|
||||
@@ -674,7 +708,7 @@ class EventBLL(object):
|
||||
"sort": ["iter"],
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
vectors = []
|
||||
iterations = []
|
||||
@@ -695,7 +729,7 @@ class EventBLL(object):
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iters,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -705,7 +739,7 @@ class EventBLL(object):
|
||||
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_last_iter"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@@ -727,8 +761,6 @@ class EventBLL(object):
|
||||
es_index = EventMetrics.get_index_name(company_id, "*")
|
||||
es_req = {"query": {"term": {"task": task_id}}}
|
||||
with translate_errors_context(), TimingContext("es", "delete_task_events"):
|
||||
es_res = self.es.delete_by_query(
|
||||
index=es_index, body=es_req, routing=task_id, refresh=True
|
||||
)
|
||||
es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)
|
||||
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import itertools
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Callable, Iterable
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
@@ -16,7 +16,7 @@ from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
from utilities import safe_get
|
||||
from tools import safe_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -30,14 +30,18 @@ class EventType(Enum):
|
||||
|
||||
|
||||
class EventMetrics:
|
||||
MAX_TASKS_COUNT = 50
|
||||
MAX_METRICS_COUNT = 200
|
||||
MAX_VARIANTS_COUNT = 500
|
||||
MAX_METRICS_COUNT = 100
|
||||
MAX_VARIANTS_COUNT = 100
|
||||
MAX_AGGS_ELEMENTS_COUNT = 50
|
||||
MAX_SAMPLE_BUCKETS = 6000
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
@property
|
||||
def _max_concurrency(self):
|
||||
return config.get("services.events.max_metrics_concurrency", 4)
|
||||
|
||||
@staticmethod
|
||||
def get_index_name(company_id, event_type):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
@@ -51,15 +55,48 @@ class EventMetrics:
|
||||
The amount of points in each histogram should not exceed
|
||||
the requested samples
|
||||
"""
|
||||
es_index = self.get_index_name(company_id, "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
return self._run_get_scalar_metrics_as_parallel(
|
||||
company_id,
|
||||
task_ids=[task_id],
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
get_func=self._get_scalar_average,
|
||||
return self._get_scalar_average_per_iter_core(
|
||||
task_id, es_index, samples, ScalarKey.resolve(key)
|
||||
)
|
||||
|
||||
def _get_scalar_average_per_iter_core(
|
||||
self,
|
||||
task_id: str,
|
||||
es_index: str,
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
run_parallel: bool = True,
|
||||
) -> dict:
|
||||
intervals = self._get_task_metric_intervals(
|
||||
es_index=es_index, task_id=task_id, samples=samples, field=key.field
|
||||
)
|
||||
if not intervals:
|
||||
return {}
|
||||
interval_groups = self._group_task_metric_intervals(intervals)
|
||||
|
||||
get_scalar_average = partial(
|
||||
self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
|
||||
)
|
||||
if run_parallel:
|
||||
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
pool.map(get_scalar_average, interval_groups)
|
||||
)
|
||||
else:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
get_scalar_average(group) for group in interval_groups
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
for metric_key, metric_values in metrics:
|
||||
ret[metric_key].update(metric_values)
|
||||
|
||||
return ret
|
||||
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self,
|
||||
company_id,
|
||||
@@ -72,159 +109,115 @@ class EventMetrics:
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
if len(task_ids) > self.MAX_TASKS_COUNT:
|
||||
raise errors.BadRequest(
|
||||
f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
|
||||
len(task_ids),
|
||||
)
|
||||
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id__in=task_ids),
|
||||
allow_public=allow_public,
|
||||
override_projection=("id", "name"),
|
||||
override_projection=("id", "name", "company"),
|
||||
return_dicts=False,
|
||||
)
|
||||
if len(task_objs) < len(task_ids):
|
||||
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
|
||||
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
|
||||
|
||||
task_name_by_id = {t.id: t.name for t in task_objs}
|
||||
|
||||
ret = self._run_get_scalar_metrics_as_parallel(
|
||||
company_id,
|
||||
task_ids=task_ids,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
get_func=self._get_scalar_average_per_task,
|
||||
)
|
||||
companies = {t.company for t in task_objs}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
for metric_data in ret.values():
|
||||
for variant_data in metric_data.values():
|
||||
for task_id, task_data in variant_data.items():
|
||||
task_data["name"] = task_name_by_id[task_id]
|
||||
|
||||
return ret
|
||||
|
||||
TaskMetric = Tuple[str, str, str]
|
||||
|
||||
MetricInterval = Tuple[int, Sequence[TaskMetric]]
|
||||
MetricData = Tuple[str, dict]
|
||||
|
||||
def _split_metrics_by_max_aggs_count(
|
||||
self, task_metrics: Sequence[TaskMetric]
|
||||
) -> Iterable[Sequence[TaskMetric]]:
|
||||
"""
|
||||
Return task metrics in groups where amount of task metrics in each group
|
||||
is roughly limited by MAX_AGGS_ELEMENTS_COUNT. The split is done on metrics and
|
||||
variants while always preserving all their tasks in the same group
|
||||
"""
|
||||
if len(task_metrics) < self.MAX_AGGS_ELEMENTS_COUNT:
|
||||
yield task_metrics
|
||||
return
|
||||
|
||||
tm_grouped = bucketize(task_metrics, key=itemgetter(1, 2))
|
||||
groups = []
|
||||
for group in tm_grouped.values():
|
||||
groups.append(group)
|
||||
if sum(map(len, groups)) >= self.MAX_AGGS_ELEMENTS_COUNT:
|
||||
yield list(itertools.chain(*groups))
|
||||
groups = []
|
||||
|
||||
if groups:
|
||||
yield list(itertools.chain(*groups))
|
||||
|
||||
return
|
||||
|
||||
def _run_get_scalar_metrics_as_parallel(
|
||||
self,
|
||||
company_id: str,
|
||||
task_ids: Sequence[str],
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
get_func: Callable[
|
||||
[MetricInterval, Sequence[str], str, ScalarKey], Sequence[MetricData]
|
||||
],
|
||||
) -> dict:
|
||||
"""
|
||||
Group metrics per interval length and execute get_func for each group in parallel
|
||||
:param company_id: id of the company
|
||||
:params task_ids: ids of the tasks to collect data for
|
||||
:param samples: maximum number of samples per metric
|
||||
:param get_func: callable that given metric names for the same interval
|
||||
performs histogram aggregation for the metrics and return the aggregated data
|
||||
"""
|
||||
es_index = self.get_index_name(company_id, "training_stats_scalar")
|
||||
es_index = self.get_index_name(next(iter(companies)), "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
intervals = self._get_metric_intervals(
|
||||
es_index=es_index, task_ids=task_ids, samples=samples, field=key.field
|
||||
get_scalar_average_per_iter = partial(
|
||||
self._get_scalar_average_per_iter_core,
|
||||
es_index=es_index,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
run_parallel=False,
|
||||
)
|
||||
|
||||
if not intervals:
|
||||
return {}
|
||||
|
||||
intervals = list(
|
||||
itertools.chain.from_iterable(
|
||||
zip(itertools.repeat(i), self._split_metrics_by_max_aggs_count(tms))
|
||||
for i, tms in intervals
|
||||
)
|
||||
)
|
||||
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
pool.map(
|
||||
partial(get_func, task_ids=task_ids, es_index=es_index, key=key),
|
||||
intervals,
|
||||
)
|
||||
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
|
||||
task_metrics = zip(
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
for metric_key, metric_values in metrics:
|
||||
ret[metric_key].update(metric_values)
|
||||
res = defaultdict(lambda: defaultdict(dict))
|
||||
for task_id, task_data in task_metrics:
|
||||
task_name = task_name_by_id[task_id]
|
||||
for metric_key, metric_data in task_data.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
variant_data["name"] = task_name
|
||||
res[metric_key][variant_key][task_id] = variant_data
|
||||
|
||||
return ret
|
||||
return res
|
||||
|
||||
def _get_metric_intervals(
|
||||
self, es_index, task_ids: Sequence[str], samples: int, field: str = "iter"
|
||||
MetricInterval = Tuple[str, str, int, int]
|
||||
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
|
||||
|
||||
@classmethod
|
||||
def _group_task_metric_intervals(
|
||||
cls, intervals: Sequence[MetricInterval]
|
||||
) -> Sequence[MetricIntervalGroup]:
|
||||
"""
|
||||
Group task metric intervals so that the following conditions are meat:
|
||||
- All the metrics in the same group have the same interval (with 10% rounding)
|
||||
- The amount of metrics in the group does not exceed MAX_AGGS_ELEMENTS_COUNT
|
||||
- The total count of samples in the group does not exceed MAX_SAMPLE_BUCKETS
|
||||
"""
|
||||
metric_interval_groups = []
|
||||
interval_group = []
|
||||
group_interval_upper_bound = 0
|
||||
group_max_interval = 0
|
||||
group_samples = 0
|
||||
for metric, variant, interval, size in sorted(intervals, key=itemgetter(2)):
|
||||
if (
|
||||
interval > group_interval_upper_bound
|
||||
or (group_samples + size) > cls.MAX_SAMPLE_BUCKETS
|
||||
or len(interval_group) >= cls.MAX_AGGS_ELEMENTS_COUNT
|
||||
):
|
||||
if interval_group:
|
||||
metric_interval_groups.append((group_max_interval, interval_group))
|
||||
interval_group = []
|
||||
group_max_interval = interval
|
||||
group_interval_upper_bound = interval + int(interval * 0.1)
|
||||
group_samples = 0
|
||||
interval_group.append((metric, variant))
|
||||
group_samples += size
|
||||
group_max_interval = max(group_max_interval, interval)
|
||||
if interval_group:
|
||||
metric_interval_groups.append((group_max_interval, interval_group))
|
||||
|
||||
return metric_interval_groups
|
||||
|
||||
def _get_task_metric_intervals(
|
||||
self, es_index, task_id: str, samples: int, field: str = "iter"
|
||||
) -> Sequence[MetricInterval]:
|
||||
"""
|
||||
Calculate interval per task metric variant so that the resulting
|
||||
amount of points does not exceed sample.
|
||||
Return metric variants grouped by interval value with 10% rounding
|
||||
For samples==0 return empty list
|
||||
Return the list og metric variant intervals as the following tuple:
|
||||
(metric, variant, interval, samples)
|
||||
"""
|
||||
default_intervals = [(1, [])]
|
||||
if not samples:
|
||||
return default_intervals
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"terms": {"task": task_ids}},
|
||||
"query": {"term": {"task": task_id}},
|
||||
"aggs": {
|
||||
"tasks": {
|
||||
"terms": {"field": "task", "size": self.MAX_TASKS_COUNT},
|
||||
"metrics": {
|
||||
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": self.MAX_METRICS_COUNT,
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"count": {"value_count": {"field": field}},
|
||||
"min_index": {"min": {"field": field}},
|
||||
"max_index": {"max": {"field": field}},
|
||||
},
|
||||
}
|
||||
"count": {"value_count": {"field": field}},
|
||||
"min_index": {"min": {"field": field}},
|
||||
"max_index": {"max": {"field": field}},
|
||||
},
|
||||
}
|
||||
},
|
||||
@@ -233,88 +226,78 @@ class EventMetrics:
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, routing=",".join(task_ids)
|
||||
)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return default_intervals
|
||||
return []
|
||||
|
||||
intervals = [
|
||||
(
|
||||
task["key"],
|
||||
metric["key"],
|
||||
variant["key"],
|
||||
self._calculate_metric_interval(variant, samples),
|
||||
)
|
||||
for task in aggs_result["tasks"]["buckets"]
|
||||
for metric in task["metrics"]["buckets"]
|
||||
return [
|
||||
self._build_metric_interval(metric["key"], variant["key"], variant, samples)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
for variant in metric["variants"]["buckets"]
|
||||
]
|
||||
|
||||
metric_intervals = []
|
||||
upper_border = 0
|
||||
interval_metrics = None
|
||||
for task, metric, variant, interval in sorted(intervals, key=itemgetter(3)):
|
||||
if not interval_metrics or interval > upper_border:
|
||||
interval_metrics = []
|
||||
metric_intervals.append((interval, interval_metrics))
|
||||
upper_border = interval + int(interval * 0.1)
|
||||
interval_metrics.append((task, metric, variant))
|
||||
|
||||
return metric_intervals
|
||||
|
||||
@staticmethod
|
||||
def _calculate_metric_interval(metric_variant: dict, samples: int) -> int:
|
||||
def _build_metric_interval(
|
||||
metric: str, variant: str, data: dict, samples: int
|
||||
) -> Tuple[str, str, int, int]:
|
||||
"""
|
||||
Calculate index interval per metric_variant variant so that the
|
||||
total amount of intervals does not exceeds the samples
|
||||
Return the interval and resulting amount of intervals
|
||||
"""
|
||||
count = safe_get(metric_variant, "count/value")
|
||||
if not count or count < samples:
|
||||
return 1
|
||||
count = safe_get(data, "count/value", default=0)
|
||||
if count < samples:
|
||||
return metric, variant, 1, count
|
||||
|
||||
min_index = safe_get(metric_variant, "min_index/value", default=0)
|
||||
max_index = safe_get(metric_variant, "max_index/value", default=min_index)
|
||||
return max(1, int(max_index - min_index + 1) // samples)
|
||||
min_index = safe_get(data, "min_index/value", default=0)
|
||||
max_index = safe_get(data, "max_index/value", default=min_index)
|
||||
index_range = max_index - min_index + 1
|
||||
interval = max(1, math.ceil(float(index_range) / samples))
|
||||
max_samples = math.ceil(float(index_range) / interval)
|
||||
return (
|
||||
metric,
|
||||
variant,
|
||||
interval,
|
||||
max_samples,
|
||||
)
|
||||
|
||||
MetricData = Tuple[str, dict]
|
||||
|
||||
def _get_scalar_average(
|
||||
self,
|
||||
metrics_interval: MetricInterval,
|
||||
task_ids: Sequence[str],
|
||||
metrics_interval: MetricIntervalGroup,
|
||||
task_id: str,
|
||||
es_index: str,
|
||||
key: ScalarKey,
|
||||
) -> Sequence[MetricData]:
|
||||
"""
|
||||
Retrieve scalar histograms per several metric variants that share the same interval
|
||||
Note: the function works with a single task only
|
||||
"""
|
||||
|
||||
assert len(task_ids) == 1
|
||||
interval, task_metrics = metrics_interval
|
||||
interval, metrics = metrics_interval
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": self.MAX_METRICS_COUNT,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
aggs_result = self._query_aggregation_for_metrics_and_tasks(
|
||||
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
|
||||
aggs_result = self._query_aggregation_for_task_metrics(
|
||||
es_index, aggs=aggs, task_id=task_id, metrics=metrics
|
||||
)
|
||||
|
||||
if not aggs_result:
|
||||
@@ -335,61 +318,6 @@ class EventMetrics:
|
||||
]
|
||||
return metrics
|
||||
|
||||
def _get_scalar_average_per_task(
|
||||
self,
|
||||
metrics_interval: MetricInterval,
|
||||
task_ids: Sequence[str],
|
||||
es_index: str,
|
||||
key: ScalarKey,
|
||||
) -> Sequence[MetricData]:
|
||||
"""
|
||||
Retrieve scalar histograms per several metric variants that share the same interval
|
||||
"""
|
||||
interval, task_metrics = metrics_interval
|
||||
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
|
||||
"aggs": {
|
||||
"tasks": {
|
||||
"terms": {
|
||||
"field": "task",
|
||||
"size": self.MAX_TASKS_COUNT,
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
aggs_result = self._query_aggregation_for_metrics_and_tasks(
|
||||
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
|
||||
)
|
||||
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
metrics = [
|
||||
(
|
||||
metric["key"],
|
||||
{
|
||||
variant["key"]: {
|
||||
task["key"]: key.get_iterations_data(task)
|
||||
for task in variant["tasks"]["buckets"]
|
||||
}
|
||||
for variant in metric["variants"]["buckets"]
|
||||
},
|
||||
)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
]
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
def _add_aggregation_average(aggregation):
|
||||
average_agg = {"avg_val": {"avg": {"field": "value"}}}
|
||||
@@ -398,69 +326,55 @@ class EventMetrics:
|
||||
for key, value in aggregation.items()
|
||||
}
|
||||
|
||||
def _query_aggregation_for_metrics_and_tasks(
|
||||
def _query_aggregation_for_task_metrics(
|
||||
self,
|
||||
es_index: str,
|
||||
aggs: dict,
|
||||
task_ids: Sequence[str],
|
||||
task_metrics: Sequence[TaskMetric],
|
||||
task_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
) -> dict:
|
||||
"""
|
||||
Return the result of elastic search query for the given aggregation filtered
|
||||
by the given task_ids and metrics
|
||||
"""
|
||||
if task_metrics:
|
||||
condition = {
|
||||
"should": [
|
||||
self._build_metric_terms(task, metric, variant)
|
||||
for task, metric, variant in task_metrics
|
||||
]
|
||||
}
|
||||
else:
|
||||
condition = {"must": [{"terms": {"task": task_ids}}]}
|
||||
must = [{"term": {"task": task_id}}]
|
||||
if metrics:
|
||||
should = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for metric, variant in metrics
|
||||
]
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"_source": {"excludes": []},
|
||||
"query": {"bool": condition},
|
||||
"query": {"bool": {"must": must}},
|
||||
"aggs": aggs,
|
||||
"version": True,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, routing=",".join(task_ids)
|
||||
)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
return es_res.get("aggregations")
|
||||
|
||||
@staticmethod
|
||||
def _build_metric_terms(task: str, metric: str, variant: str) -> dict:
|
||||
"""
|
||||
Build query term for a metric + variant
|
||||
"""
|
||||
return {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
def get_tasks_metrics(
|
||||
self, company_id, task_ids: Sequence, event_type: EventType
|
||||
) -> Sequence[Tuple]:
|
||||
) -> Sequence:
|
||||
"""
|
||||
For the requested tasks return all the metrics that
|
||||
reported events of the requested types
|
||||
"""
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type.value)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return [(tid, []) for tid in task_ids]
|
||||
return {}
|
||||
|
||||
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
|
||||
with ThreadPoolExecutor(max_concurrency) as pool:
|
||||
with ThreadPoolExecutor(self._max_concurrency) as pool:
|
||||
res = pool.map(
|
||||
partial(
|
||||
self._get_task_metrics, es_index=es_index, event_type=event_type,
|
||||
@@ -488,7 +402,7 @@ class EventMetrics:
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
return [
|
||||
metric["key"]
|
||||
|
||||
114
server/bll/event/log_events_iterator.py
Normal file
114
server/bll/event/log_events_iterator.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from typing import Optional, Tuple, Sequence
|
||||
|
||||
import attr
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
from database.errors import translate_errors_context
|
||||
from timing_context import TimingContext
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.Factory(list)
|
||||
|
||||
|
||||
class LogEventsIterator:
|
||||
EVENT_TYPE = "log"
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
from_timestamp: Optional[int] = None,
|
||||
) -> TaskEventsResult:
|
||||
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
res = TaskEventsResult()
|
||||
res.events, res.total_events = self._get_events(
|
||||
es_index=es_index,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_timestamp=from_timestamp,
|
||||
)
|
||||
return res
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
es_index,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
from_timestamp: Optional[int],
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous timestamp either in the
|
||||
direction of earlier events (navigate_earlier=True) or in the direction of later events.
|
||||
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
|
||||
For the last timestamp all the events are brought (even if the resulting size
|
||||
exceeds batch_size) so that this timestamp events will not be lost between the calls.
|
||||
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
|
||||
"""
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": {"term": {"task": task_id}},
|
||||
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if from_timestamp:
|
||||
es_req["search_after"] = [from_timestamp]
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_result = self.es.search(index=es_index, body=es_req)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"timestamp": events[-1]["timestamp"]}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = self.es.search(index=es_index, body=es_req)
|
||||
last_second_hits = es_result["hits"]["hits"]
|
||||
if not last_second_hits or len(last_second_hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
already_present_ids = set(hit["_id"] for hit in hits)
|
||||
last_second_events = [
|
||||
hit["_source"]
|
||||
for hit in last_second_hits
|
||||
if hit["_id"] not in already_present_ids
|
||||
]
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[*events, *last_second_events],
|
||||
hits_total,
|
||||
)
|
||||
@@ -4,7 +4,7 @@ Module for polymorphism over different types of X axes in scalar aggregations
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import auto
|
||||
|
||||
from apimodels import StringEnum
|
||||
from utilities.stringenum import StringEnum
|
||||
from bll.util import extract_properties_to_lists
|
||||
from config import config
|
||||
|
||||
@@ -111,7 +111,7 @@ class TimestampKey(ScalarKey):
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{interval}ms",
|
||||
"fixed_interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
}
|
||||
}
|
||||
@@ -150,7 +150,7 @@ class ISOTimeKey(ScalarKey):
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{interval}ms",
|
||||
"fixed_interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
"format": "strict_date_time",
|
||||
}
|
||||
|
||||
18
server/bll/model/__init__.py
Normal file
18
server/bll/model/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from database.model.model import Model
|
||||
from database.utils import get_company_or_none_constraint
|
||||
|
||||
|
||||
class ModelBLL:
|
||||
def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
|
||||
"""
|
||||
Return the list of unique frameworks used by company and public models
|
||||
If project ids passed then only models from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
return Model.objects(query).distinct(field="framework")
|
||||
193
server/bll/organization/__init__.py
Normal file
193
server/bll/organization/__init__.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from typing import Sequence, Union, Type, Dict
|
||||
|
||||
from mongoengine import Q
|
||||
from redis import Redis
|
||||
|
||||
from config import config
|
||||
from database.model.base import GetMixin
|
||||
from database.model.model import Model
|
||||
from database.model.task.task import Task
|
||||
from redis_manager import redman
|
||||
from utilities import json
|
||||
|
||||
log = config.logger(__file__)
|
||||
_settings_prefix = "services.organization"
|
||||
|
||||
|
||||
class _TagsCache:
|
||||
_tags_field = "tags"
|
||||
_system_tags_field = "system_tags"
|
||||
|
||||
def __init__(self, db_cls: Union[Type[Model], Type[Task]], redis: Redis):
|
||||
self.db_cls = db_cls
|
||||
self.redis = redis
|
||||
|
||||
@property
|
||||
def _tags_cache_expiration_seconds(self):
|
||||
return config.get(f"{_settings_prefix}.tags_cache.expiration_seconds", 3600)
|
||||
|
||||
def _get_tags_from_db(
|
||||
self,
|
||||
company: str,
|
||||
field: str,
|
||||
project: str = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
) -> set:
|
||||
query = Q(company=company)
|
||||
if filter_:
|
||||
for name, vals in filter_.items():
|
||||
if vals:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
if project:
|
||||
query &= Q(project=project)
|
||||
|
||||
return self.db_cls.objects(query).distinct(field)
|
||||
|
||||
def _get_tags_cache_key(
|
||||
self,
|
||||
company: str,
|
||||
field: str,
|
||||
project: str = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
):
|
||||
"""
|
||||
Project None means 'from all company projects'
|
||||
The key is built in the way that scanning company keys for 'all company projects'
|
||||
will not return the keys related to the particular company projects and vice versa.
|
||||
So that we can have a fine grain control on what redis keys to invalidate
|
||||
"""
|
||||
filter_str = None
|
||||
if filter_:
|
||||
filter_str = "_".join(
|
||||
["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())]
|
||||
)
|
||||
key_parts = [company, project, self.db_cls.__name__, field, filter_str]
|
||||
return "_".join(filter(None, key_parts))
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company: str,
|
||||
include_system: bool = False,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
project: str = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Get tags and optionally system tags for the company
|
||||
Return the dictionary of tags per tags field name
|
||||
The function retrieves both cached values from Redis in one call
|
||||
and re calculates any of them if missing in Redis
|
||||
"""
|
||||
fields = [self._tags_field]
|
||||
if include_system:
|
||||
fields.append(self._system_tags_field)
|
||||
redis_keys = [
|
||||
self._get_tags_cache_key(company, field=f, project=project, filter_=filter_)
|
||||
for f in fields
|
||||
]
|
||||
cached = self.redis.mget(redis_keys)
|
||||
ret = {}
|
||||
for field, tag_data, key in zip(fields, cached, redis_keys):
|
||||
if tag_data is not None:
|
||||
tags = json.loads(tag_data)
|
||||
else:
|
||||
tags = list(self._get_tags_from_db(company, field, project, filter_))
|
||||
self.redis.setex(
|
||||
key,
|
||||
time=self._tags_cache_expiration_seconds,
|
||||
value=json.dumps(tags),
|
||||
)
|
||||
ret[field] = set(tags)
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(self, company: str, project: str, tags=None, system_tags=None):
|
||||
"""
|
||||
Updates tags. If reset is set then both tags and system_tags
|
||||
are recalculated. Otherwise only those that are not 'None'
|
||||
"""
|
||||
fields = [
|
||||
field
|
||||
for field, update in (
|
||||
(self._tags_field, tags),
|
||||
(self._system_tags_field, system_tags),
|
||||
)
|
||||
if update is not None
|
||||
]
|
||||
if not fields:
|
||||
return
|
||||
|
||||
self._delete_redis_keys(company, projects=[project], fields=fields)
|
||||
|
||||
def reset_tags(self, company: str, projects: Sequence[str]):
|
||||
self._delete_redis_keys(
|
||||
company,
|
||||
projects=projects,
|
||||
fields=(self._tags_field, self._system_tags_field),
|
||||
)
|
||||
|
||||
def _delete_redis_keys(
|
||||
self, company: str, projects: [Sequence[str]], fields: Sequence[str]
|
||||
):
|
||||
redis_keys = list(
|
||||
chain.from_iterable(
|
||||
self.redis.keys(
|
||||
self._get_tags_cache_key(company, field=f, project=p) + "*"
|
||||
)
|
||||
for f in fields
|
||||
for p in set(projects) | {None}
|
||||
)
|
||||
)
|
||||
if redis_keys:
|
||||
self.redis.delete(*redis_keys)
|
||||
|
||||
|
||||
class Tags(Enum):
|
||||
Task = "task"
|
||||
Model = "model"
|
||||
|
||||
|
||||
class OrgBLL:
|
||||
def __init__(self, redis=None):
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self._task_tags = _TagsCache(Task, self.redis)
|
||||
self._model_tags = _TagsCache(Model, self.redis)
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company: str,
|
||||
entity: Tags,
|
||||
include_system: bool = False,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
projects: Sequence[str] = None,
|
||||
) -> dict:
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
if not projects:
|
||||
return tags_cache.get_tags(
|
||||
company, include_system=include_system, filter_=filter_
|
||||
)
|
||||
|
||||
ret = defaultdict(set)
|
||||
for project in projects:
|
||||
project_tags = tags_cache.get_tags(
|
||||
company, include_system=include_system, filter_=filter_, project=project
|
||||
)
|
||||
for field, tags in project_tags.items():
|
||||
ret[field] |= tags
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(
|
||||
self, company: str, entity: Tags, project: str, tags=None, system_tags=None,
|
||||
):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.update_tags(company, project, tags, system_tags)
|
||||
|
||||
def reset_tags(self, company: str, entity: Tags, projects: Sequence[str]):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.reset_tags(company, projects=projects)
|
||||
|
||||
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
|
||||
return self._task_tags if entity == Tags.Task else self._model_tags
|
||||
1
server/bll/project/__init__.py
Normal file
1
server/bll/project/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .project_bll import ProjectBLL
|
||||
33
server/bll/project/project_bll.py
Normal file
33
server/bll/project/project_bll.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from config import config
|
||||
from database.model.model import Model
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ProjectBLL:
|
||||
@classmethod
|
||||
def get_active_users(
|
||||
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
|
||||
) -> set:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
"""
|
||||
with TimingContext("mongo", "active_users_in_projects"):
|
||||
res = set()
|
||||
query = Q(company=company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
|
||||
return res
|
||||
@@ -18,7 +18,6 @@ log = config.logger(__file__)
|
||||
|
||||
class QueueMetrics:
|
||||
class EsKeys:
|
||||
DOC_TYPE = "metrics"
|
||||
WAITING_TIME_FIELD = "average_waiting_time"
|
||||
QUEUE_LENGTH_FIELD = "queue_length"
|
||||
TIMESTAMP_FIELD = "timestamp"
|
||||
@@ -66,7 +65,6 @@ class QueueMetrics:
|
||||
entries = [e for e in queue.entries if e.added]
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_type=self.EsKeys.DOC_TYPE,
|
||||
_source={
|
||||
self.EsKeys.TIMESTAMP_FIELD: timestamp,
|
||||
self.EsKeys.QUEUE_FIELD: queue.id,
|
||||
@@ -93,7 +91,6 @@ class QueueMetrics:
|
||||
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
|
||||
doc_type=self.EsKeys.DOC_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@@ -109,7 +106,7 @@ class QueueMetrics:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": cls.EsKeys.TIMESTAMP_FIELD,
|
||||
"interval": f"{interval}s",
|
||||
"fixed_interval": f"{interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
from typing import Optional, TypeVar, Generic, Type
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, TypeVar, Generic, Type, Callable
|
||||
|
||||
from redis import StrictRedis
|
||||
|
||||
import database
|
||||
from timing_context import TimingContext
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _do_nothing(_: T):
|
||||
return
|
||||
|
||||
|
||||
class RedisCacheManager(Generic[T]):
|
||||
"""
|
||||
Class for store/retreive of state objects from redis
|
||||
Class for store/retrieve of state objects from redis
|
||||
|
||||
self.state_class - class of the state
|
||||
self.redis - instance of redis
|
||||
@@ -42,3 +48,32 @@ class RedisCacheManager(Generic[T]):
|
||||
|
||||
def _get_redis_key(self, state_id):
|
||||
return f"{self.state_class}/{state_id}"
|
||||
|
||||
@contextmanager
|
||||
def get_or_create_state(
|
||||
self,
|
||||
state_id=None,
|
||||
init_state: Callable[[T], None] = _do_nothing,
|
||||
validate_state: Callable[[T], None] = _do_nothing,
|
||||
):
|
||||
"""
|
||||
Try to retrieve state with the given id from the Redis cache if yes then validates it
|
||||
If no then create a new one with randomly generated id
|
||||
Yield the state and write it back to redis once the user code block exits
|
||||
:param state_id: id of the state to retrieve
|
||||
:param init_state: user callback to init the newly created state
|
||||
If not passed then no init except for the id generation is done
|
||||
:param validate_state: user callback to validate the state if retrieved from cache
|
||||
Should throw an exception if the state is not valid. If not passed then no validation is done
|
||||
"""
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
|
||||
try:
|
||||
yield state
|
||||
finally:
|
||||
self.set_state(state)
|
||||
|
||||
@@ -19,7 +19,7 @@ from config.info import get_deployment_type
|
||||
from database.model import Company, User
|
||||
from database.model.queue import Queue
|
||||
from database.model.task.task import Task
|
||||
from utilities import safe_get
|
||||
from tools import safe_get
|
||||
from utilities.json import dumps
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from version import __version__ as current_version
|
||||
@@ -237,7 +237,6 @@ class StatisticsReporter:
|
||||
def _run_worker_stats_query(cls, company_id, es_req) -> dict:
|
||||
return worker_bll.es_client.search(
|
||||
index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*",
|
||||
doc_type="stat",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@@ -280,7 +279,7 @@ class StatisticsReporter:
|
||||
]
|
||||
return {
|
||||
group["_id"]: {k: v for k, v in group.items() if k != "_id"}
|
||||
for group in Task.aggregate(*pipeline)
|
||||
for group in Task.aggregate(pipeline)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -4,5 +4,4 @@ from .utils import (
|
||||
update_project_time,
|
||||
validate_status_change,
|
||||
split_by,
|
||||
ParameterKeyEscaper,
|
||||
)
|
||||
|
||||
229
server/bll/task/hyperparams.py
Normal file
229
server/bll/task/hyperparams.py
Normal file
@@ -0,0 +1,229 @@
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Dict
|
||||
|
||||
from boltons import iterutils
|
||||
|
||||
from apierrors import errors
|
||||
from apimodels.tasks import (
|
||||
HyperParamKey,
|
||||
HyperParamItem,
|
||||
ReplaceHyperparams,
|
||||
Configuration,
|
||||
)
|
||||
from bll.task import TaskBLL
|
||||
from config import config
|
||||
from database.model.task.task import ParamsItem, Task, ConfigurationItem, TaskStatus
|
||||
from utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
log = config.logger(__file__)
|
||||
task_bll = TaskBLL()
|
||||
|
||||
|
||||
class HyperParams:
|
||||
_properties_section = "properties"
|
||||
|
||||
@classmethod
|
||||
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
only = ("id", "hyperparams")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
task.id: {"hyperparams": cls._get_params_list(items=task.hyperparams)}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_params_list(
|
||||
cls, items: Dict[str, Dict[str, ParamsItem]]
|
||||
) -> Sequence[dict]:
|
||||
ret = list(chain.from_iterable(v.values() for v in items.values()))
|
||||
return [
|
||||
p.to_proper_dict() for p in sorted(ret, key=attrgetter("section", "name"))
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _normalize_params(cls, params: Sequence) -> bool:
|
||||
"""
|
||||
Lower case properties section and return True if it is the only section
|
||||
"""
|
||||
for p in params:
|
||||
if p.section.lower() == cls._properties_section:
|
||||
p.section = cls._properties_section
|
||||
|
||||
return all(p.section == cls._properties_section for p in params)
|
||||
|
||||
@classmethod
|
||||
def delete_params(
|
||||
cls, company_id: str, task_id: str, hyperparams=Sequence[HyperParamKey]
|
||||
) -> int:
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = cls._get_task_for_update(
|
||||
company=company_id, id=task_id, allow_all_statuses=properties_only
|
||||
)
|
||||
|
||||
with_param, without_param = iterutils.partition(
|
||||
hyperparams, key=lambda p: bool(p.name)
|
||||
)
|
||||
sections_to_delete = {p.section for p in without_param}
|
||||
delete_cmds = {
|
||||
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
|
||||
for section in sections_to_delete
|
||||
}
|
||||
|
||||
for item in with_param:
|
||||
section = ParameterKeyEscaper.escape(item.section)
|
||||
if item.section in sections_to_delete:
|
||||
raise errors.bad_request.FieldsConflict(
|
||||
"Cannot delete section field if the whole section was scheduled for deletion"
|
||||
)
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def edit_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamItem],
|
||||
replace_hyperparams: str,
|
||||
) -> int:
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = cls._get_task_for_update(
|
||||
company=company_id, id=task_id, allow_all_statuses=properties_only
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
hyperparams = cls._db_dicts_from_list(hyperparams)
|
||||
if replace_hyperparams == ReplaceHyperparams.all:
|
||||
update_cmds["set__hyperparams"] = hyperparams
|
||||
elif replace_hyperparams == ReplaceHyperparams.section:
|
||||
for section, value in hyperparams.items():
|
||||
update_cmds[f"set__hyperparams__{section}"] = value
|
||||
else:
|
||||
for section, section_params in hyperparams.items():
|
||||
for name, value in section_params.items():
|
||||
update_cmds[f"set__hyperparams__{section}__{name}"] = value
|
||||
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
|
||||
sections = iterutils.bucketize(items, key=attrgetter("section"))
|
||||
return {
|
||||
ParameterKeyEscaper.escape(section): {
|
||||
ParameterKeyEscaper.escape(param.name): ParamsItem(**param.to_struct())
|
||||
for param in params
|
||||
}
|
||||
for section, params in sections.items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_configurations(
|
||||
cls, company_id: str, task_ids: Sequence[str], names: Sequence[str]
|
||||
) -> Dict[str, dict]:
|
||||
only = ["id"]
|
||||
if names:
|
||||
only.extend(
|
||||
f"configuration.{ParameterKeyEscaper.escape(name)}" for name in names
|
||||
)
|
||||
else:
|
||||
only.append("configuration")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
task.id: {
|
||||
"configuration": [
|
||||
c.to_proper_dict()
|
||||
for c in sorted(task.configuration.values(), key=attrgetter("name"))
|
||||
]
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_configuration_names(
|
||||
cls, company_id: str, task_ids: Sequence[str]
|
||||
) -> Dict[str, list]:
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"_id": {"$in": task_ids},
|
||||
}
|
||||
},
|
||||
{"$project": {"items": {"$objectToArray": "$configuration"}}},
|
||||
{"$unwind": "$items"},
|
||||
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
|
||||
]
|
||||
|
||||
tasks = Task.aggregate(pipeline)
|
||||
|
||||
return {
|
||||
task["_id"]: {
|
||||
"names": sorted(
|
||||
ParameterKeyEscaper.unescape(name) for name in task["names"]
|
||||
)
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_configuration(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
configuration: Sequence[Configuration],
|
||||
replace_configuration: bool,
|
||||
) -> int:
|
||||
task = cls._get_task_for_update(company=company_id, id=task_id)
|
||||
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
|
||||
for c in configuration
|
||||
}
|
||||
if replace_configuration:
|
||||
update_cmds["set__configuration"] = configuration
|
||||
else:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{name}"] = value
|
||||
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def delete_configuration(
|
||||
cls, company_id: str, task_id: str, configuration=Sequence[str]
|
||||
) -> int:
|
||||
task = cls._get_task_for_update(company=company_id, id=task_id)
|
||||
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@staticmethod
|
||||
def _get_task_for_update(
|
||||
company: str, id: str, allow_all_statuses: bool = False
|
||||
) -> Task:
|
||||
task = Task.get_for_writing(company=company, id=id, _only=("id", "status"))
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=id)
|
||||
|
||||
if allow_all_statuses:
|
||||
return task
|
||||
|
||||
if task.status != TaskStatus.created:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
expected=TaskStatus.created, status=task.status
|
||||
)
|
||||
return task
|
||||
89
server/bll/task/non_responsive_tasks_watchdog.py
Normal file
89
server/bll/task/non_responsive_tasks_watchdog.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from datetime import timedelta, datetime
|
||||
from time import sleep
|
||||
|
||||
from apierrors import errors
|
||||
from bll.task import ChangeStatusRequest
|
||||
from config import config
|
||||
from database.model.task.task import TaskStatus, Task
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class NonResponsiveTasksWatchdog:
|
||||
threads = ThreadsManager()
|
||||
|
||||
class _Settings:
|
||||
"""
|
||||
Retrieves watchdog settings from the config file
|
||||
The properties are not cached so that the updates in
|
||||
the config file are reflected
|
||||
"""
|
||||
|
||||
_prefix = "services.tasks.non_responsive_tasks_watchdog"
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return config.get(f"{self._prefix}.enabled", True)
|
||||
|
||||
@property
|
||||
def watch_interval_sec(self):
|
||||
return config.get(f"{self._prefix}.watch_interval_sec", 900)
|
||||
|
||||
@property
|
||||
def threshold_sec(self):
|
||||
return config.get(f"{self._prefix}.threshold_sec", 7200)
|
||||
|
||||
settings = _Settings()
|
||||
|
||||
@classmethod
|
||||
@threads.register("non_responsive_tasks_watchdog", daemon=True)
|
||||
def start(cls):
|
||||
sleep(cls.settings.watch_interval_sec)
|
||||
while not ThreadsManager.terminating:
|
||||
watch_interval = cls.settings.watch_interval_sec
|
||||
if cls.settings.enabled:
|
||||
try:
|
||||
stopped = cls.cleanup_tasks(
|
||||
threshold_sec=cls.settings.threshold_sec
|
||||
)
|
||||
log.info(f"{stopped} non-responsive tasks stopped")
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed stopping tasks: {str(ex)}")
|
||||
sleep(watch_interval)
|
||||
|
||||
@classmethod
|
||||
def cleanup_tasks(cls, threshold_sec):
|
||||
relevant_status = (TaskStatus.in_progress,)
|
||||
threshold = timedelta(seconds=threshold_sec)
|
||||
ref_time = datetime.utcnow() - threshold
|
||||
log.info(
|
||||
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
|
||||
)
|
||||
|
||||
tasks = list(
|
||||
Task.objects(status__in=relevant_status, last_update__lt=ref_time).only(
|
||||
"id", "name", "status", "project", "last_update"
|
||||
)
|
||||
)
|
||||
log.info(f"{len(tasks)} non-responsive tasks found")
|
||||
if not tasks:
|
||||
return 0
|
||||
|
||||
err_count = 0
|
||||
for task in tasks:
|
||||
log.info(
|
||||
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
|
||||
)
|
||||
try:
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.stopped,
|
||||
status_reason="Forced stop (non-responsive)",
|
||||
status_message="Forced stop (non-responsive)",
|
||||
force=True,
|
||||
).execute()
|
||||
except errors.bad_request.FailedChangingTaskStatus:
|
||||
err_count += 1
|
||||
|
||||
return len(tasks) - err_count
|
||||
201
server/bll/task/param_utils.py
Normal file
201
server/bll/task/param_utils.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import itertools
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import dpath
|
||||
|
||||
from apierrors import errors
|
||||
from database.model.task.task import Task
|
||||
from tools import safe_get
|
||||
from utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
|
||||
hyperparams_default_section = "Args"
|
||||
hyperparams_legacy_type = "legacy"
|
||||
tf_define_section = "TF_DEFINE"
|
||||
|
||||
|
||||
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Return parameter section and name. The section is either TF_DEFINE or the default one
|
||||
"""
|
||||
if default_section is None:
|
||||
return None, full_name
|
||||
|
||||
section, _, name = full_name.partition("/")
|
||||
if section != tf_define_section:
|
||||
return default_section, full_name
|
||||
|
||||
if not name:
|
||||
raise errors.bad_request.ValidationError("Parameter name cannot be empty")
|
||||
return section, name
|
||||
|
||||
|
||||
def _get_full_param_name(param: dict) -> str:
|
||||
section = param.get("section")
|
||||
if section != tf_define_section:
|
||||
return param["name"]
|
||||
|
||||
return "/".join((section, param["name"]))
|
||||
|
||||
|
||||
def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
"""
|
||||
removed = 0
|
||||
if not data:
|
||||
return removed
|
||||
|
||||
if with_sections:
|
||||
for section, section_data in list(data.items()):
|
||||
removed += _remove_legacy_params(section_data)
|
||||
if not section_data:
|
||||
"""If section is empty after removing legacy params then delete it"""
|
||||
del data[section]
|
||||
else:
|
||||
for key, param in list(data.items()):
|
||||
if param.get("type") == hyperparams_legacy_type:
|
||||
removed += 1
|
||||
del data[key]
|
||||
|
||||
return removed
|
||||
|
||||
|
||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
"""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
if with_sections:
|
||||
return itertools.chain.from_iterable(
|
||||
_get_legacy_params(section_data) for section_data in data.values()
|
||||
)
|
||||
|
||||
return [
|
||||
param for param in data.values() if param.get("type") == hyperparams_legacy_type
|
||||
]
|
||||
|
||||
|
||||
def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
"""
|
||||
If legacy hyper params or configuration is passed then replace the corresponding section in the new structure
|
||||
Escape all the section and param names for hyper params and configuration to make it mongo sage
|
||||
"""
|
||||
for old_params_field, new_params_field, default_section in (
|
||||
("execution/parameters", "hyperparams", hyperparams_default_section),
|
||||
("execution/model_desc", "configuration", None),
|
||||
):
|
||||
legacy_params = safe_get(fields, old_params_field)
|
||||
if legacy_params is None:
|
||||
continue
|
||||
|
||||
if (
|
||||
not safe_get(fields, new_params_field)
|
||||
and previous_task
|
||||
and previous_task[new_params_field]
|
||||
):
|
||||
previous_data = previous_task.to_proper_dict().get(new_params_field)
|
||||
removed = _remove_legacy_params(
|
||||
previous_data, with_sections=default_section is not None
|
||||
)
|
||||
if not legacy_params and not removed:
|
||||
# if we only need to delete legacy fields from the db
|
||||
# but they are not there then there is no point to proceed
|
||||
continue
|
||||
|
||||
fields_update = {new_params_field: previous_data}
|
||||
params_unprepare_from_saved(fields_update)
|
||||
fields.update(fields_update)
|
||||
|
||||
for full_name, value in legacy_params.items():
|
||||
section, name = split_param_name(full_name, default_section)
|
||||
new_path = list(filter(None, (new_params_field, section, name)))
|
||||
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
dpath.new(fields, new_path, new_param)
|
||||
dpath.delete(fields, old_params_field)
|
||||
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(key): {
|
||||
ParameterKeyEscaper.escape(k): v for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, escaped_params)
|
||||
|
||||
|
||||
def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
"""
|
||||
Unescape all section and param names for hyper params and configuration
|
||||
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
|
||||
"""
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
if params:
|
||||
unescaped_params = {
|
||||
ParameterKeyEscaper.unescape(key): {
|
||||
ParameterKeyEscaper.unescape(k): v for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, unescaped_params)
|
||||
|
||||
if copy_to_legacy:
|
||||
for new_params_field, old_params_field, use_sections in (
|
||||
(f"hyperparams", "execution/parameters", True),
|
||||
(f"configuration", "execution/model_desc", False),
|
||||
):
|
||||
legacy_params = _get_legacy_params(
|
||||
safe_get(fields, new_params_field), with_sections=use_sections
|
||||
)
|
||||
if legacy_params:
|
||||
dpath.new(
|
||||
fields,
|
||||
old_params_field,
|
||||
{_get_full_param_name(p): p["value"] for p in legacy_params},
|
||||
)
|
||||
|
||||
|
||||
def _process_path(path: str):
|
||||
"""
|
||||
Frontend does a partial escaping on the path so the all '.' in section and key names are escaped
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
raise errors.bad_request.ValidationError("invalid task field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
)
|
||||
|
||||
|
||||
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
|
||||
for old_prefix, new_prefix in (
|
||||
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
|
||||
("execution.model_desc", f"configuration"),
|
||||
):
|
||||
path: str
|
||||
paths = [path.replace(old_prefix, new_prefix) for path in paths]
|
||||
|
||||
for prefix in (
|
||||
"hyperparams.",
|
||||
"-hyperparams.",
|
||||
"configuration.",
|
||||
"-configuration.",
|
||||
):
|
||||
paths = [
|
||||
_process_path(path) if path.startswith(prefix) else path for path in paths
|
||||
]
|
||||
return paths
|
||||
@@ -1,10 +1,11 @@
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from random import random
|
||||
from time import sleep
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
|
||||
|
||||
import dpath
|
||||
import pymongo.results
|
||||
import six
|
||||
from mongoengine import Q
|
||||
@@ -14,6 +15,7 @@ import database.utils as dbutils
|
||||
import es_factory
|
||||
from apierrors import errors
|
||||
from apimodels.tasks import Artifact as ApiArtifact
|
||||
from bll.organization import OrgBLL, Tags
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.model import Model
|
||||
@@ -27,25 +29,38 @@ from database.model.task.task import (
|
||||
TaskSystemTags,
|
||||
ArtifactModes,
|
||||
Artifact,
|
||||
external_task_types,
|
||||
)
|
||||
from database.utils import get_company_or_none_constraint, id as create_id
|
||||
from service_repo import APICall
|
||||
from timing_context import TimingContext
|
||||
from utilities.dicts import deep_merge
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
|
||||
from utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import ChangeStatusRequest, validate_status_change
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
|
||||
|
||||
class TaskBLL(object):
|
||||
threads = ThreadsManager("TaskBLL")
|
||||
|
||||
def __init__(self, events_es=None):
|
||||
self.events_es = (
|
||||
events_es if events_es is not None else es_factory.connect("events")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
|
||||
"""
|
||||
Return the list of unique task types used by company and public tasks
|
||||
If project ids passed then only tasks from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
res = Task.objects(query).distinct(field="type")
|
||||
return set(res).intersection(external_task_types)
|
||||
|
||||
@staticmethod
|
||||
def get_task_with_access(
|
||||
task_id, company_id, only=None, allow_public=False, requires_write_access=False
|
||||
@@ -70,25 +85,24 @@ class TaskBLL(object):
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(
|
||||
company_id,
|
||||
task_id,
|
||||
required_status=None,
|
||||
required_dataset=None,
|
||||
only_fields=None,
|
||||
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
|
||||
):
|
||||
if only_fields:
|
||||
if isinstance(only_fields, string_types):
|
||||
only_fields = [only_fields]
|
||||
else:
|
||||
only_fields = list(only_fields)
|
||||
only_fields = only_fields + ["status"]
|
||||
|
||||
with TimingContext("mongo", "task_by_id_all"):
|
||||
qs = Task.objects(id=task_id, company=company_id)
|
||||
if only_fields:
|
||||
qs = (
|
||||
qs.only(only_fields)
|
||||
if isinstance(only_fields, string_types)
|
||||
else qs.only(*only_fields)
|
||||
)
|
||||
qs = qs.only(
|
||||
"status", "input"
|
||||
) # make sure all fields we rely on here are also returned
|
||||
task = qs.first()
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id=task_id),
|
||||
allow_public=allow_public,
|
||||
override_projection=only_fields,
|
||||
return_dicts=False,
|
||||
)
|
||||
task = None if not tasks else tasks[0]
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
@@ -96,17 +110,12 @@ class TaskBLL(object):
|
||||
if required_status and not task.status == required_status:
|
||||
raise errors.bad_request.InvalidTaskStatus(expected=required_status)
|
||||
|
||||
if required_dataset and required_dataset not in (
|
||||
entry.dataset for entry in task.input.view.entries
|
||||
):
|
||||
raise errors.bad_request.InvalidId(
|
||||
"not in input view", dataset=required_dataset
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def assert_exists(company_id, task_ids, only=None, allow_public=False):
|
||||
def assert_exists(
|
||||
company_id, task_ids, only=None, allow_public=False, return_tasks=True
|
||||
) -> Optional[Sequence[Task]]:
|
||||
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
|
||||
with translate_errors_context(), TimingContext("mongo", "task_exists"):
|
||||
ids = set(task_ids)
|
||||
@@ -117,14 +126,13 @@ class TaskBLL(object):
|
||||
return_dicts=False,
|
||||
)
|
||||
if only:
|
||||
res = q.only(*only)
|
||||
count = len(res)
|
||||
else:
|
||||
count = q.count()
|
||||
res = q.first()
|
||||
if count != len(ids):
|
||||
q = q.only(*only)
|
||||
|
||||
if q.count() != len(ids):
|
||||
raise errors.bad_request.InvalidTaskId(ids=task_ids)
|
||||
return res
|
||||
|
||||
if return_tasks:
|
||||
return list(q)
|
||||
|
||||
@staticmethod
|
||||
def create(call: APICall, fields: dict):
|
||||
@@ -166,17 +174,32 @@ class TaskBLL(object):
|
||||
project: Optional[str] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
hyperparams: Optional[dict] = None,
|
||||
configuration: Optional[dict] = None,
|
||||
execution_overrides: Optional[dict] = None,
|
||||
validate_references: bool = False,
|
||||
) -> Task:
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id)
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
execution_model_overriden = False
|
||||
params_dict = {
|
||||
field: value
|
||||
for field, value in (
|
||||
("hyperparams", hyperparams),
|
||||
("configuration", configuration),
|
||||
)
|
||||
if value is not None
|
||||
}
|
||||
if execution_overrides:
|
||||
parameters = execution_overrides.get("parameters")
|
||||
if parameters is not None:
|
||||
execution_overrides["parameters"] = {
|
||||
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
|
||||
}
|
||||
params_dict["execution"] = {}
|
||||
for legacy_param in ("parameters", "configuration"):
|
||||
legacy_value = execution_overrides.pop(legacy_param, None)
|
||||
if legacy_value is not None:
|
||||
params_dict["execution"] = legacy_value
|
||||
execution_dict = deep_merge(execution_dict, execution_overrides)
|
||||
execution_model_overriden = execution_overrides.get("model") is not None
|
||||
params_prepare_for_save(params_dict, previous_task=task)
|
||||
|
||||
artifacts = execution_dict.get("artifacts")
|
||||
if artifacts:
|
||||
execution_dict["artifacts"] = [
|
||||
@@ -203,27 +226,59 @@ class TaskBLL(object):
|
||||
if task.output
|
||||
else None,
|
||||
execution=execution_dict,
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_model=validate_references or execution_model_overriden,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
cls.validate(new_task)
|
||||
new_task.save()
|
||||
|
||||
if task.project == new_task.project:
|
||||
updated_tags = tags
|
||||
updated_system_tags = system_tags
|
||||
else:
|
||||
updated_tags = new_task.tags
|
||||
updated_system_tags = new_task.system_tags
|
||||
org_bll.update_tags(
|
||||
company_id,
|
||||
Tags.Task,
|
||||
project=new_task.project,
|
||||
tags=updated_tags,
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
|
||||
return new_task
|
||||
|
||||
@classmethod
|
||||
def validate(cls, task: Task):
|
||||
assert isinstance(task, Task)
|
||||
|
||||
if task.parent and not Task.get(
|
||||
company=task.company, id=task.parent, _only=("id",), include_public=True
|
||||
def validate(
|
||||
cls,
|
||||
task: Task,
|
||||
validate_model=True,
|
||||
validate_parent=True,
|
||||
validate_project=True,
|
||||
):
|
||||
if (
|
||||
validate_parent
|
||||
and task.parent
|
||||
and not Task.get(
|
||||
company=task.company, id=task.parent, _only=("id",), include_public=True
|
||||
)
|
||||
):
|
||||
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
|
||||
|
||||
if task.project and not Project.get_for_writing(
|
||||
company=task.company, id=task.project
|
||||
if (
|
||||
validate_project
|
||||
and task.project
|
||||
and not Project.get_for_writing(company=task.company, id=task.project)
|
||||
):
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
|
||||
cls.validate_execution_model(task)
|
||||
if validate_model:
|
||||
cls.validate_execution_model(task)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(company_id, project_ids=None):
|
||||
@@ -263,7 +318,7 @@ class TaskBLL(object):
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = Task.aggregate(*pipeline)
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@staticmethod
|
||||
@@ -312,10 +367,12 @@ class TaskBLL(object):
|
||||
return "__".join((op, "last_metrics") + path)
|
||||
|
||||
for path, value in last_scalar_values:
|
||||
extra_updates[op_path("set", *path)] = value
|
||||
if path[-1] == "value":
|
||||
if path[-1] == "min_value":
|
||||
extra_updates[op_path("min", *path[:-1], "min_value")] = value
|
||||
elif path[-1] == "max_value":
|
||||
extra_updates[op_path("max", *path[:-1], "max_value")] = value
|
||||
else:
|
||||
extra_updates[op_path("set", *path)] = value
|
||||
|
||||
if last_events is not None:
|
||||
|
||||
@@ -327,7 +384,7 @@ class TaskBLL(object):
|
||||
|
||||
metric_stats = {
|
||||
dbutils.hash_field_name(metric_key): MetricEventStats(
|
||||
metric=metric_key, event_stats_by_type=events_per_type(metric_data),
|
||||
metric=metric_key, event_stats_by_type=events_per_type(metric_data)
|
||||
)
|
||||
for metric_key, metric_data in last_events.items()
|
||||
}
|
||||
@@ -575,81 +632,35 @@ class TaskBLL(object):
|
||||
|
||||
return [a.key for a in added], [a.key for a in updated]
|
||||
|
||||
@classmethod
|
||||
@threads.register("non_responsive_tasks_watchdog", daemon=True)
|
||||
def start_non_responsive_tasks_watchdog(cls):
|
||||
log = config.logger("non_responsive_tasks_watchdog")
|
||||
relevant_status = (TaskStatus.in_progress,)
|
||||
threshold = timedelta(
|
||||
seconds=config.get(
|
||||
"services.tasks.non_responsive_tasks_watchdog.threshold_sec", 7200
|
||||
)
|
||||
)
|
||||
watch_interval = config.get(
|
||||
"services.tasks.non_responsive_tasks_watchdog.watch_interval_sec", 900
|
||||
)
|
||||
sleep(watch_interval)
|
||||
while not ThreadsManager.terminating:
|
||||
try:
|
||||
|
||||
ref_time = datetime.utcnow() - threshold
|
||||
|
||||
log.info(
|
||||
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
|
||||
)
|
||||
|
||||
tasks = list(
|
||||
Task.objects(
|
||||
status__in=relevant_status, last_update__lt=ref_time
|
||||
).only("id", "name", "status", "project", "last_update")
|
||||
)
|
||||
|
||||
if tasks:
|
||||
|
||||
log.info(f"Stopping {len(tasks)} non-responsive tasks")
|
||||
|
||||
for task in tasks:
|
||||
log.info(
|
||||
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
|
||||
)
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.stopped,
|
||||
status_reason="Forced stop (non-responsive)",
|
||||
status_message="Forced stop (non-responsive)",
|
||||
force=True,
|
||||
).execute()
|
||||
|
||||
log.info(f"Done")
|
||||
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed stopping tasks: {str(ex)}")
|
||||
|
||||
sleep(watch_interval)
|
||||
|
||||
@staticmethod
|
||||
def get_aggregated_project_execution_parameters(
|
||||
def get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids: Sequence[str] = None,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[str]]:
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": company_id,
|
||||
"execution.parameters": {"$exists": True, "$gt": {}},
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
**({"project": {"$in": project_ids}} if project_ids else {}),
|
||||
}
|
||||
},
|
||||
{"$project": {"parameters": {"$objectToArray": "$execution.parameters"}}},
|
||||
{"$unwind": "$parameters"},
|
||||
{"$group": {"_id": "$parameters.k"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
@@ -666,7 +677,7 @@ class TaskBLL(object):
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = next(Task.aggregate(*pipeline), None)
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
@@ -675,7 +686,12 @@ class TaskBLL(object):
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
ParameterKeyEscaper.unescape(r["_id"])
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
dpath.get(r, "_id/section")
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import TypeVar, Callable, Tuple, Sequence
|
||||
|
||||
import attr
|
||||
import six
|
||||
from boltons.dictutils import OneToOne
|
||||
|
||||
from apierrors import errors
|
||||
from database.errors import translate_errors_context
|
||||
@@ -172,26 +171,3 @@ def split_by(
|
||||
[item for cond, item in applied if cond],
|
||||
[item for cond, item in applied if not cond],
|
||||
)
|
||||
|
||||
|
||||
class ParameterKeyEscaper:
|
||||
_mapping = OneToOne({".": "%2E", "$": "%24"})
|
||||
|
||||
@classmethod
|
||||
def escape(cls, value):
|
||||
""" Quote a parameter key """
|
||||
value = value.strip().replace("%", "%%")
|
||||
for c, r in cls._mapping.items():
|
||||
value = value.replace(c, r)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _unescape(cls, value):
|
||||
for c, r in cls._mapping.inv.items():
|
||||
value = value.replace(c, r)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def unescape(cls, value):
|
||||
""" Unquote a quoted parameter key """
|
||||
return "%".join(map(cls._unescape, value.split("%%")))
|
||||
|
||||
@@ -35,14 +35,21 @@ class SetFieldsResolver:
|
||||
SET_MODIFIERS = ("min", "max")
|
||||
|
||||
def __init__(self, set_fields: Dict[str, Any]):
|
||||
self.orig_fields = set_fields
|
||||
self.fields = {
|
||||
f: fname
|
||||
for f, modifier, dunder, fname in (
|
||||
(f,) + f.partition("__") for f in set_fields.keys()
|
||||
)
|
||||
if dunder and modifier in self.SET_MODIFIERS
|
||||
}
|
||||
self.orig_fields = {}
|
||||
self.fields = {}
|
||||
self.add_fields(**set_fields)
|
||||
|
||||
def add_fields(self, **set_fields: Any):
|
||||
self.orig_fields.update(set_fields)
|
||||
self.fields.update(
|
||||
{
|
||||
f: fname
|
||||
for f, modifier, dunder, fname in (
|
||||
(f,) + f.partition("__") for f in set_fields.keys()
|
||||
)
|
||||
if dunder and modifier in self.SET_MODIFIERS
|
||||
}
|
||||
)
|
||||
|
||||
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
|
||||
if name in self.fields and doc.get_field_value(self.fields[name]) is None:
|
||||
|
||||
@@ -21,6 +21,7 @@ from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.auth import User
|
||||
from database.model.company import Company
|
||||
from database.model.project import Project
|
||||
from database.model.queue import Queue
|
||||
from database.model.task.task import Task
|
||||
from redis_manager import redman
|
||||
@@ -33,8 +34,8 @@ log = config.logger(__file__)
|
||||
|
||||
class WorkerBLL:
|
||||
def __init__(self, es=None, redis=None):
|
||||
self.es_client = es if es is not None else es_factory.connect("workers")
|
||||
self.redis = redis if redis is not None else redman.connection("workers")
|
||||
self.es_client = es or es_factory.connect("workers")
|
||||
self.redis = redis or redman.connection("workers")
|
||||
self._stats = WorkerStats(self.es_client)
|
||||
|
||||
@property
|
||||
@@ -49,6 +50,7 @@ class WorkerBLL:
|
||||
ip: str = "",
|
||||
queues: Sequence[str] = None,
|
||||
timeout: int = 0,
|
||||
tags: Sequence[str] = None,
|
||||
) -> WorkerEntry:
|
||||
"""
|
||||
Register a worker
|
||||
@@ -58,6 +60,7 @@ class WorkerBLL:
|
||||
:param ip: the real ip of the worker
|
||||
:param queues: queues reported as being monitored by the worker
|
||||
:param timeout: registration expiration timeout in seconds
|
||||
:param tags: a list of tags for this worker
|
||||
:raise bad_request.InvalidUserId: in case the calling user or company does not exist
|
||||
:return: worker entry instance
|
||||
"""
|
||||
@@ -91,6 +94,7 @@ class WorkerBLL:
|
||||
register_time=now,
|
||||
register_timeout=timeout,
|
||||
last_activity_time=now,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
|
||||
@@ -113,12 +117,15 @@ class WorkerBLL:
|
||||
raise bad_request.WorkerNotRegistered(worker=worker)
|
||||
|
||||
def status_report(
|
||||
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest
|
||||
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Write worker status report
|
||||
:param company_id: worker's company ID
|
||||
:param user_id: user_id ID under which this worker is running
|
||||
:param ip: worker IP
|
||||
:param report: the report itself
|
||||
:param tags: tags for this worker
|
||||
:raise bad_request.InvalidTaskId: the reported task was not found
|
||||
:return: worker entry instance
|
||||
"""
|
||||
@@ -129,6 +136,9 @@ class WorkerBLL:
|
||||
now = datetime.utcnow()
|
||||
entry.last_activity_time = now
|
||||
|
||||
if tags is not None:
|
||||
entry.tags = tags
|
||||
|
||||
if report.machine_stats:
|
||||
self._log_stats_to_es(
|
||||
company_id=company_id,
|
||||
@@ -146,6 +156,7 @@ class WorkerBLL:
|
||||
|
||||
if not report.task:
|
||||
entry.task = None
|
||||
entry.project = None
|
||||
else:
|
||||
with translate_errors_context():
|
||||
query = dict(id=report.task, company=company_id)
|
||||
@@ -160,6 +171,12 @@ class WorkerBLL:
|
||||
raise bad_request.InvalidTaskId(**query)
|
||||
entry.task = IdNameEntry(id=task.id, name=task.name)
|
||||
|
||||
entry.project = None
|
||||
if task.project:
|
||||
project = Project.objects(id=task.project).only("name").first()
|
||||
if project:
|
||||
entry.project = IdNameEntry(id=project.id, name=project.name)
|
||||
|
||||
entry.last_report_time = now
|
||||
except APIError:
|
||||
raise
|
||||
@@ -223,7 +240,7 @@ class WorkerBLL:
|
||||
},
|
||||
]
|
||||
queues_info = {
|
||||
res["_id"]: res for res in Queue.objects.aggregate(*projection)
|
||||
res["_id"]: res for res in Queue.objects.aggregate(projection)
|
||||
}
|
||||
task_ids = task_ids.union(
|
||||
filter(
|
||||
@@ -369,7 +386,6 @@ class WorkerBLL:
|
||||
def make_doc(category, metric, variant, value) -> dict:
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_type="stat",
|
||||
_source=dict(
|
||||
timestamp=timestamp,
|
||||
worker=worker,
|
||||
|
||||
@@ -25,7 +25,6 @@ class WorkerStats:
|
||||
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self.worker_stats_prefix_for_company(company_id)}*",
|
||||
doc_type="stat",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@@ -53,7 +52,7 @@ class WorkerStats:
|
||||
|
||||
res = self._search_company_stats(company_id, es_req)
|
||||
|
||||
if not res["hits"]["total"]:
|
||||
if not res["hits"]["total"]["value"]:
|
||||
raise bad_request.WorkerStatsNotFound(
|
||||
f"No statistic metrics found for the company {company_id} and workers {worker_ids}"
|
||||
)
|
||||
@@ -87,7 +86,7 @@ class WorkerStats:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{request.interval}s",
|
||||
"fixed_interval": f"{request.interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
@@ -216,7 +215,7 @@ class WorkerStats:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{interval}s",
|
||||
"fixed_interval": f"{interval}s",
|
||||
},
|
||||
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
@@ -15,7 +16,7 @@ from pyparsing import (
|
||||
|
||||
DEFAULT_EXTRA_CONFIG_PATH = "/opt/trains/config"
|
||||
EXTRA_CONFIG_PATH_ENV_KEY = "TRAINS_CONFIG_DIR"
|
||||
EXTRA_CONFIG_PATH_SEP = ":"
|
||||
EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ';'
|
||||
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_SEP = "__"
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX = f"TRAINS{EXTRA_CONFIG_VALUES_ENV_KEY_SEP}"
|
||||
@@ -57,7 +58,7 @@ class BasicConfig:
|
||||
|
||||
keys = sorted(k for k in os.environ if k.startswith(prefix))
|
||||
for key in keys:
|
||||
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".")
|
||||
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".").lower()
|
||||
result = ConfigTree.merge_configs(
|
||||
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
|
||||
)
|
||||
@@ -77,7 +78,7 @@ class BasicConfig:
|
||||
if not path.is_dir() and str(path) != DEFAULT_EXTRA_CONFIG_PATH
|
||||
]
|
||||
if invalid:
|
||||
print(f"WARNING: Invalid paths in {key} env var: {' '.join(invalid)}")
|
||||
print(f"WARNING: Invalid paths in {key} env var: {' '.join(map(str, invalid))}")
|
||||
return [path for path in paths if path.is_dir()]
|
||||
|
||||
def _load(self, verbose=True):
|
||||
|
||||
@@ -26,6 +26,17 @@
|
||||
check_max_version: false
|
||||
}
|
||||
|
||||
pre_populate {
|
||||
enabled: false
|
||||
zip_files: ["/path/to/export.zip"]
|
||||
fail_on_error: false
|
||||
# artifacts_path: "/mnt/fileserver"
|
||||
}
|
||||
|
||||
# time in seconds to take an exclusive lock to init es and mongodb
|
||||
# not including the pre_populate
|
||||
db_init_timout: 120
|
||||
|
||||
mongo {
|
||||
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
|
||||
# but not declared in a data model
|
||||
@@ -34,11 +45,16 @@
|
||||
aggregate {
|
||||
allow_disk_use: true
|
||||
}
|
||||
}
|
||||
|
||||
pre_populate {
|
||||
enabled: false
|
||||
zip_file: "/path/to/export.zip"
|
||||
fail_on_error: false
|
||||
elastic {
|
||||
probing {
|
||||
# settings for inital probing of elastic connection
|
||||
max_retries: 4
|
||||
timeout: 30
|
||||
}
|
||||
upgrade_monitoring {
|
||||
v16_migration_verification: true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
elastic {
|
||||
events {
|
||||
hosts: [{host: "127.0.0.1", port: 9200}]
|
||||
hosts: [{host: "127.0.0.1", port: 9211}]
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
max_retries: 5
|
||||
max_retries: 3
|
||||
retry_on_timeout: true
|
||||
}
|
||||
index_version: "1"
|
||||
}
|
||||
|
||||
workers {
|
||||
hosts: [{host:"127.0.0.1", port:9200}]
|
||||
hosts: [{host:"127.0.0.1", port:9211}]
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
max_retries: 5
|
||||
max_retries: 3
|
||||
retry_on_timeout: true
|
||||
}
|
||||
index_version: "1"
|
||||
|
||||
@@ -13,17 +13,21 @@
|
||||
credentials {
|
||||
# system credentials as they appear in the auth DB, used for intra-service communications
|
||||
apiserver {
|
||||
role: "system"
|
||||
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
|
||||
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
|
||||
}
|
||||
webserver {
|
||||
role: "system"
|
||||
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
|
||||
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
|
||||
revoke_in_fixed_mode: true
|
||||
}
|
||||
tests {
|
||||
role: "user"
|
||||
display_name: "Default User"
|
||||
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
|
||||
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
16
server/config/default/services/auth.conf
Normal file
16
server/config/default/services/auth.conf
Normal file
@@ -0,0 +1,16 @@
|
||||
fixed_users {
|
||||
guest {
|
||||
enabled: false
|
||||
|
||||
default_company: "025315a9321f49f8be07f5ac48fbcf92"
|
||||
|
||||
name: "Guest"
|
||||
username: "guest"
|
||||
password: "guest"
|
||||
|
||||
# Allow access only to the following endpoints when using user/pass credentials
|
||||
allow_endpoints: [
|
||||
"auth.login"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -6,4 +6,8 @@ ignore_iteration {
|
||||
|
||||
# max number of concurrent queries to ES when calculating events metrics
|
||||
# should not exceed the amount of concurrent connections set in the ES driver
|
||||
max_metrics_concurrency: 4
|
||||
max_metrics_concurrency: 4
|
||||
|
||||
events_retrieval {
|
||||
state_expiration_sec: 3600
|
||||
}
|
||||
|
||||
3
server/config/default/services/organization.conf
Normal file
3
server/config/default/services/organization.conf
Normal file
@@ -0,0 +1,3 @@
|
||||
tags_cache {
|
||||
expiration_seconds: 3600
|
||||
}
|
||||
8
server/config/default/services/projects.conf
Normal file
8
server/config/default/services/projects.conf
Normal file
@@ -0,0 +1,8 @@
|
||||
# Order of featured projects, by name or ID
|
||||
featured_order: [
|
||||
# {id: "<project-id>"}
|
||||
# OR
|
||||
# {name: "<project-name>"}
|
||||
# OR
|
||||
# {name_regex: "<python-regex>"}
|
||||
]
|
||||
@@ -1,4 +1,6 @@
|
||||
non_responsive_tasks_watchdog {
|
||||
enabled: true
|
||||
|
||||
# In-progress tasks older than this value in seconds will be stopped by the watchdog
|
||||
threshold_sec: 7200
|
||||
|
||||
@@ -9,4 +11,6 @@ non_responsive_tasks_watchdog {
|
||||
artifacts {
|
||||
update_attempts: 10
|
||||
update_retry_msec: 500
|
||||
}
|
||||
}
|
||||
|
||||
multi_task_histogram_limit: 100
|
||||
|
||||
@@ -41,3 +41,7 @@ def get_deployment_type() -> str:
|
||||
|
||||
def get_default_company():
|
||||
return config.get("apiserver.default_company")
|
||||
|
||||
|
||||
missed_es_upgrade = False
|
||||
es_connection_error = False
|
||||
|
||||
@@ -79,6 +79,10 @@ def get_entries():
|
||||
return _entries
|
||||
|
||||
|
||||
def get_hosts():
|
||||
return [entry.host for entry in get_entries()]
|
||||
|
||||
|
||||
def get_aliases():
|
||||
return [entry.alias for entry in get_entries()]
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ from mongoengine import (
|
||||
DictField,
|
||||
DynamicField,
|
||||
)
|
||||
from mongoengine.fields import key_not_string, key_starts_with_dollar
|
||||
|
||||
NoneType = type(None)
|
||||
|
||||
|
||||
class LengthRangeListField(ListField):
|
||||
@@ -125,17 +128,39 @@ def contains_empty_key(d):
|
||||
return True
|
||||
|
||||
|
||||
class SafeMapField(MapField):
|
||||
class DictValidationMixin:
|
||||
"""
|
||||
DictField validation in MongoEngine requires default alias and permissions to access DB version:
|
||||
https://github.com/MongoEngine/mongoengine/issues/2239
|
||||
This is a stripped down implementation that does not require any of the above and implies Mongo ver 3.6+
|
||||
"""
|
||||
|
||||
def _safe_validate(self: DictField, value):
|
||||
if not isinstance(value, dict):
|
||||
self.error("Only dictionaries may be used in a DictField")
|
||||
|
||||
if key_not_string(value):
|
||||
msg = "Invalid dictionary key - documents must have only string keys"
|
||||
self.error(msg)
|
||||
|
||||
if key_starts_with_dollar(value):
|
||||
self.error(
|
||||
'Invalid dictionary key name - keys may not startswith "$" characters'
|
||||
)
|
||||
super(DictField, self).validate(value)
|
||||
|
||||
|
||||
class SafeMapField(MapField, DictValidationMixin):
|
||||
def validate(self, value):
|
||||
super(SafeMapField, self).validate(value)
|
||||
self._safe_validate(value)
|
||||
|
||||
if contains_empty_key(value):
|
||||
self.error("Empty keys are not allowed in a MapField")
|
||||
|
||||
|
||||
class SafeDictField(DictField):
|
||||
class SafeDictField(DictField, DictValidationMixin):
|
||||
def validate(self, value):
|
||||
super(SafeDictField, self).validate(value)
|
||||
self._safe_validate(value)
|
||||
|
||||
if contains_empty_key(value):
|
||||
self.error("Empty keys are not allowed in a DictField")
|
||||
@@ -146,6 +171,7 @@ class SafeSortedListField(SortedListField):
|
||||
SortedListField that does not raise an error in case items are not comparable
|
||||
(in which case they will be sorted by their string representation)
|
||||
"""
|
||||
|
||||
def to_mongo(self, *args, **kwargs):
|
||||
try:
|
||||
return super(SafeSortedListField, self).to_mongo(*args, **kwargs)
|
||||
@@ -155,7 +181,10 @@ class SafeSortedListField(SortedListField):
|
||||
def _safe_to_mongo(self, value, use_db_field=True, fields=None):
|
||||
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
|
||||
if self._ordering is not None:
|
||||
def key(v): return str(itemgetter(self._ordering)(v))
|
||||
|
||||
def key(v):
|
||||
return str(itemgetter(self._ordering)(v))
|
||||
|
||||
else:
|
||||
key = str
|
||||
return sorted(value, key=key, reverse=self._order_reverse)
|
||||
|
||||
@@ -32,6 +32,8 @@ class Role(object):
|
||||
""" Company user """
|
||||
annotator = "annotator"
|
||||
""" Annotator with limited access"""
|
||||
guest = "guest"
|
||||
""" Guest user. Read Only."""
|
||||
|
||||
@classmethod
|
||||
def get_system_roles(cls) -> set:
|
||||
@@ -43,6 +45,7 @@ class Role(object):
|
||||
|
||||
|
||||
class Credentials(EmbeddedDocument):
|
||||
meta = {"strict": False}
|
||||
key = StringField(required=True)
|
||||
secret = StringField(required=True)
|
||||
last_used = DateTimeField()
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple
|
||||
|
||||
from boltons.iterutils import first
|
||||
from boltons.iterutils import first, bucketize, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apierrors import errors
|
||||
from apierrors.base import BaseError
|
||||
from config import config
|
||||
from database.errors import MakeGetAllQueryError
|
||||
from database.projection import project_dict, ProjectionHelper
|
||||
@@ -34,7 +35,12 @@ class AuthDocument(Document):
|
||||
|
||||
|
||||
class ProperDictMixin(object):
|
||||
def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict:
|
||||
def to_proper_dict(
|
||||
self: Union["ProperDictMixin", Document],
|
||||
strip_private=True,
|
||||
only=None,
|
||||
extra_dict=None,
|
||||
) -> dict:
|
||||
return self.properize_dict(
|
||||
self.to_mongo(use_db_field=False).to_dict(),
|
||||
strip_private=strip_private,
|
||||
@@ -71,6 +77,8 @@ class GetMixin(PropsMixin):
|
||||
}
|
||||
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
|
||||
|
||||
_field_collation_overrides = {}
|
||||
|
||||
class QueryParameterOptions(object):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -91,11 +99,48 @@ class GetMixin(PropsMixin):
|
||||
self.list_fields = list_fields
|
||||
self.pattern_fields = pattern_fields
|
||||
|
||||
class ListFieldBucketHelper:
|
||||
op_prefix = "__$"
|
||||
legacy_exclude_prefix = "-"
|
||||
|
||||
_default = "in"
|
||||
_ops = {"not": "nin"}
|
||||
_next = _default
|
||||
|
||||
def __init__(self, legacy=False):
|
||||
self._legacy = legacy
|
||||
|
||||
def key(self, v):
|
||||
if v is None:
|
||||
self._next = self._default
|
||||
return self._default
|
||||
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
|
||||
self._next = self._default
|
||||
return self._ops["not"]
|
||||
elif v.startswith(self.op_prefix):
|
||||
self._next = self._ops.get(v[len(self.op_prefix) :], self._default)
|
||||
return None
|
||||
|
||||
next_ = self._next
|
||||
self._next = self._default
|
||||
return next_
|
||||
|
||||
def value_transform(self, v):
|
||||
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
|
||||
return v[len(self.legacy_exclude_prefix) :]
|
||||
return v
|
||||
|
||||
get_all_query_options = QueryParameterOptions()
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
cls, company, id, *, _only=None, include_public=False, **kwargs
|
||||
cls: Union["GetMixin", Document],
|
||||
company,
|
||||
id,
|
||||
*,
|
||||
_only=None,
|
||||
include_public=False,
|
||||
**kwargs,
|
||||
) -> "GetMixin":
|
||||
q = cls.objects(
|
||||
cls._prepare_perm_query(company, allow_public=include_public)
|
||||
@@ -162,17 +207,7 @@ class GetMixin(PropsMixin):
|
||||
for field in tuple(opts.list_fields or ()):
|
||||
data = parameters.pop(field, None)
|
||||
if data:
|
||||
if not isinstance(data, (list, tuple)):
|
||||
raise MakeGetAllQueryError("expected list", field)
|
||||
exclude = [t for t in data if t.startswith("-")]
|
||||
include = list(set(data).difference(exclude))
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
if include:
|
||||
dict_query[f"{mongoengine_field}__in"] = include
|
||||
if exclude:
|
||||
dict_query[f"{mongoengine_field}__nin"] = [
|
||||
t[1:] for t in exclude
|
||||
]
|
||||
query &= cls.get_list_field_query(field, data)
|
||||
|
||||
for field in opts.fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
@@ -216,6 +251,47 @@ class GetMixin(PropsMixin):
|
||||
|
||||
return query & RegexQ(**dict_query)
|
||||
|
||||
@classmethod
|
||||
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
||||
"""
|
||||
Get a proper mongoengine Q object that represents an "or" query for the provided values
|
||||
with respect to the given list field, with support for "none of empty" in case a None value
|
||||
is included.
|
||||
|
||||
- Exclusion can be specified by a leading "-" for each value (API versions <2.8)
|
||||
or by a preceding "__$not" value (operator)
|
||||
"""
|
||||
if not isinstance(data, (list, tuple)):
|
||||
raise MakeGetAllQueryError("expected list", field)
|
||||
|
||||
# TODO: backwards compatibility only for older API versions
|
||||
helper = cls.ListFieldBucketHelper(legacy=True)
|
||||
actions = bucketize(
|
||||
data, key=helper.key, value_transform=helper.value_transform
|
||||
)
|
||||
|
||||
allow_empty = None in actions.get("in", {})
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
|
||||
q = RegexQ()
|
||||
for action in filter(None, actions):
|
||||
q &= RegexQ(
|
||||
**{
|
||||
f"{mongoengine_field}__{action}": list(
|
||||
set(filter(None, actions[action]))
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
if not allow_empty:
|
||||
return q
|
||||
|
||||
return (
|
||||
q
|
||||
| Q(**{f"{mongoengine_field}__exists": False})
|
||||
| Q(**{mongoengine_field: []})
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _prepare_perm_query(cls, company, allow_public=False):
|
||||
if allow_public:
|
||||
@@ -272,6 +348,20 @@ class GetMixin(PropsMixin):
|
||||
return []
|
||||
return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
|
||||
|
||||
@classmethod
|
||||
def split_projection(
|
||||
cls, projection: Sequence[str]
|
||||
) -> Tuple[Collection[str], Collection[str]]:
|
||||
"""Return include and exclude lists based on passed projection and class definition"""
|
||||
if projection:
|
||||
include, exclude = partition(
|
||||
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
|
||||
)
|
||||
else:
|
||||
include, exclude = [], []
|
||||
exclude = {x.lstrip(ProjectionHelper.exclusion_prefix) for x in exclude}
|
||||
return include, set(cls.get_exclude_fields()).union(exclude).difference(include)
|
||||
|
||||
@classmethod
|
||||
def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
|
||||
parameters.pop("only_fields", None)
|
||||
@@ -409,7 +499,27 @@ class GetMixin(PropsMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_many_no_company(cls, query, parameters=None, override_projection=None):
|
||||
def get_many_public(
|
||||
cls, query: Q = None, projection: Collection[str] = None,
|
||||
):
|
||||
"""
|
||||
Fetch all public documents matching a provided query.
|
||||
:param query: Optional query object (mongoengine.Q).
|
||||
:param projection: A list of projection fields.
|
||||
:return: A list of documents matching the query.
|
||||
"""
|
||||
q = get_company_or_none_constraint()
|
||||
_query = (q & query) if query else q
|
||||
|
||||
return cls._get_many_no_company(query=_query, override_projection=projection)
|
||||
|
||||
@classmethod
|
||||
def _get_many_no_company(
|
||||
cls: Union["GetMixin", Document],
|
||||
query: Q,
|
||||
parameters=None,
|
||||
override_projection=None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query.
|
||||
This is a company-less version for internal uses. We assume the caller has either added any necessary
|
||||
@@ -429,7 +539,9 @@ class GetMixin(PropsMixin):
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
only = cls.get_projection(parameters, override_projection)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
qs = cls.objects(query)
|
||||
if search_text:
|
||||
@@ -437,13 +549,14 @@ class GetMixin(PropsMixin):
|
||||
if order_by:
|
||||
# add ordering
|
||||
qs = qs.order_by(*order_by)
|
||||
if only:
|
||||
|
||||
if include:
|
||||
# add projection
|
||||
qs = qs.only(*only)
|
||||
else:
|
||||
exclude = set(cls.get_exclude_fields()).difference(only)
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
qs = qs.only(*include)
|
||||
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
|
||||
if page is not None and page_size:
|
||||
# add paging
|
||||
qs = qs.skip(page * page_size).limit(page_size)
|
||||
@@ -460,6 +573,8 @@ class GetMixin(PropsMixin):
|
||||
"""
|
||||
Fetch all documents matching a provided query. For the first order by field
|
||||
the None values are sorted in the end regardless of the sorting order.
|
||||
If the first order field is a user defined parameter (either from execution.parameters,
|
||||
or from last_metrics) then the collation is set that sorts strings in numeric order where possible.
|
||||
This is a company-less version for internal uses. We assume the caller has either added any necessary
|
||||
constraints to the query or that no constraints are required.
|
||||
|
||||
@@ -477,7 +592,9 @@ class GetMixin(PropsMixin):
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
only = cls.get_projection(parameters, override_projection)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
query_sets = [cls.objects(query)]
|
||||
if order_by:
|
||||
@@ -500,20 +617,29 @@ class GetMixin(PropsMixin):
|
||||
query_sets = [cls.objects(non_empty), cls.objects(empty)]
|
||||
|
||||
query_sets = [qs.order_by(*order_by) for qs in query_sets]
|
||||
if order_field:
|
||||
collation_override = first(
|
||||
v
|
||||
for k, v in cls._field_collation_overrides.items()
|
||||
if order_field.startswith(k)
|
||||
)
|
||||
if collation_override:
|
||||
query_sets = [
|
||||
qs.collation(collation=collation_override) for qs in query_sets
|
||||
]
|
||||
|
||||
if search_text:
|
||||
query_sets = [qs.search_text(search_text) for qs in query_sets]
|
||||
|
||||
if only:
|
||||
if include:
|
||||
# add projection
|
||||
query_sets = [qs.only(*only) for qs in query_sets]
|
||||
else:
|
||||
exclude = set(cls.get_exclude_fields())
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
query_sets = [qs.only(*include) for qs in query_sets]
|
||||
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
|
||||
if page is None or not page_size:
|
||||
return [obj.to_proper_dict(only=only) for qs in query_sets for obj in qs]
|
||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||
|
||||
# add paging
|
||||
ret = []
|
||||
@@ -524,7 +650,8 @@ class GetMixin(PropsMixin):
|
||||
start -= qs_size
|
||||
continue
|
||||
ret.extend(
|
||||
obj.to_proper_dict(only=only) for obj in qs.skip(start).limit(page_size)
|
||||
obj.to_proper_dict(only=include)
|
||||
for obj in qs.skip(start).limit(page_size)
|
||||
)
|
||||
if len(ret) >= page_size:
|
||||
break
|
||||
@@ -593,7 +720,13 @@ class UpdateMixin(object):
|
||||
return update_dict
|
||||
|
||||
@classmethod
|
||||
def safe_update(cls, company_id, id, partial_update_dict, injected_update=None):
|
||||
def safe_update(
|
||||
cls: Union["UpdateMixin", Document],
|
||||
company_id,
|
||||
id,
|
||||
partial_update_dict,
|
||||
injected_update=None,
|
||||
):
|
||||
update_dict = cls.get_safe_update_dict(partial_update_dict)
|
||||
if not update_dict:
|
||||
return 0, {}
|
||||
@@ -610,7 +743,10 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
|
||||
@classmethod
|
||||
def aggregate(
|
||||
cls: Document, *pipeline: dict, allow_disk_use=None, **kwargs
|
||||
cls: Union["DbModelMixin", Document],
|
||||
pipeline: Sequence[dict],
|
||||
allow_disk_use=None,
|
||||
**kwargs,
|
||||
) -> CommandCursor:
|
||||
"""
|
||||
Aggregate objects of this document class according to the provided pipeline.
|
||||
@@ -625,7 +761,32 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
if allow_disk_use is not None
|
||||
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
|
||||
)
|
||||
return cls.objects.aggregate(*pipeline, **kwargs)
|
||||
return cls.objects.aggregate(pipeline, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def set_public(
|
||||
cls: Type[Document],
|
||||
company_id: str,
|
||||
ids: Sequence[str],
|
||||
invalid_cls: Type[BaseError],
|
||||
enabled: bool = True,
|
||||
):
|
||||
if enabled:
|
||||
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
|
||||
update = dict(set__company_origin=company_id, unset__company=1)
|
||||
else:
|
||||
items = list(
|
||||
cls.objects(
|
||||
id__in=ids, company__in=(None, ""), company_origin=company_id
|
||||
).only("id")
|
||||
)
|
||||
update = dict(set__company=company_id, unset__company_origin=1)
|
||||
|
||||
if len(items) < len(ids):
|
||||
missing = tuple(set(ids).difference(i.id for i in items))
|
||||
raise invalid_cls(ids=missing)
|
||||
|
||||
return {"updated": cls.objects(id__in=ids).update(**update)}
|
||||
|
||||
|
||||
def validate_id(cls, company, **kwargs):
|
||||
@@ -647,5 +808,5 @@ def validate_id(cls, company, **kwargs):
|
||||
id_to_name.setdefault(obj_id, []).append(name)
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Invalid {} ids".format(cls.__name__.lower()),
|
||||
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
|
||||
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]},
|
||||
)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField
|
||||
from mongoengine import Document, StringField, DateTimeField, BooleanField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeDictField
|
||||
from database.fields import StrippedStringField, SafeDictField, SafeSortedListField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import GetMixin
|
||||
from database.model.model_labels import ModelLabels
|
||||
from database.model.company import Company
|
||||
from database.model.project import Project
|
||||
@@ -18,7 +19,9 @@ class Model(DbModelMixin, Document):
|
||||
"parent",
|
||||
"project",
|
||||
"task",
|
||||
("company", "framework"),
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
{
|
||||
"name": "%s.model.main_text_index" % Database.backend,
|
||||
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
|
||||
@@ -34,6 +37,21 @@ class Model(DbModelMixin, Document):
|
||||
},
|
||||
],
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "comment"),
|
||||
fields=("ready",),
|
||||
list_fields=(
|
||||
"tags",
|
||||
"system_tags",
|
||||
"framework",
|
||||
"uri",
|
||||
"id",
|
||||
"user",
|
||||
"project",
|
||||
"task",
|
||||
"parent",
|
||||
),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(user_set_allowed=True, min_length=3)
|
||||
@@ -44,8 +62,8 @@ class Model(DbModelMixin, Document):
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
task = StringField(reference_field=Task)
|
||||
comment = StringField(user_set_allowed=True)
|
||||
tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
uri = StrippedStringField(default="", user_set_allowed=True)
|
||||
framework = StringField()
|
||||
design = SafeDictField()
|
||||
@@ -54,3 +72,4 @@ class Model(DbModelMixin, Document):
|
||||
ui_cache = SafeDictField(
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from mongoengine import MapField, IntField
|
||||
from database.fields import NoneType, UnionField, SafeMapField
|
||||
|
||||
|
||||
class ModelLabels(MapField):
|
||||
class ModelLabels(SafeMapField):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs)
|
||||
super(ModelLabels, self).__init__(
|
||||
field=UnionField(types=(int, NoneType)), *args, **kwargs
|
||||
)
|
||||
|
||||
def validate(self, value):
|
||||
super(ModelLabels, self).validate(value)
|
||||
if value and (len(set(value.values())) < len(value)):
|
||||
non_empty_values = list(filter(None, value.values()))
|
||||
if non_empty_values and len(set(non_empty_values)) < len(non_empty_values):
|
||||
self.error("Same label id appears more than once in model labels")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from mongoengine import StringField, DateTimeField, ListField
|
||||
from mongoengine import StringField, DateTimeField, IntField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
from database.model import AttributedDocument
|
||||
from database.model.base import GetMixin
|
||||
|
||||
@@ -36,7 +36,11 @@ class Project(AttributedDocument):
|
||||
)
|
||||
description = StringField(required=True)
|
||||
created = DateTimeField(required=True)
|
||||
tags = ListField(StringField(required=True))
|
||||
system_tags = ListField(StringField(required=True))
|
||||
tags = SafeSortedListField(StringField(required=True))
|
||||
system_tags = SafeSortedListField(StringField(required=True))
|
||||
default_output_destination = StrippedStringField()
|
||||
last_update = DateTimeField()
|
||||
featured = IntField(default=9999)
|
||||
logo_url = StringField()
|
||||
logo_blob = StringField(exclude_by_default=True)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
|
||||
@@ -4,11 +4,10 @@ from mongoengine import (
|
||||
StringField,
|
||||
DateTimeField,
|
||||
EmbeddedDocumentListField,
|
||||
ListField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import ProperDictMixin, GetMixin
|
||||
from database.model.company import Company
|
||||
@@ -41,7 +40,7 @@ class Queue(DbModelMixin, Document):
|
||||
)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
created = DateTimeField(required=True)
|
||||
tags = ListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
system_tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||
last_update = DateTimeField()
|
||||
|
||||
@@ -7,6 +7,10 @@ from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
|
||||
|
||||
class SettingKeys:
|
||||
server__uuid = "server.uuid"
|
||||
|
||||
|
||||
class Settings(DbModelMixin, Document):
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
@@ -47,7 +51,7 @@ class Settings(DbModelMixin, Document):
|
||||
""" Adds a new key/value settings. Fails if key already exists. """
|
||||
key = key.strip(sep)
|
||||
try:
|
||||
res = Settings(key=key, value=value).save(force_insert=True)
|
||||
res = cls(key=key, value=value).save(force_insert=True)
|
||||
return bool(res)
|
||||
except NotUniqueError:
|
||||
return False
|
||||
|
||||
@@ -18,7 +18,7 @@ from database.fields import (
|
||||
SafeSortedListField,
|
||||
)
|
||||
from database.model import AttributedDocument
|
||||
from database.model.base import ProperDictMixin
|
||||
from database.model.base import ProperDictMixin, GetMixin
|
||||
from database.model.model_labels import ModelLabels
|
||||
from database.model.project import Project
|
||||
from database.utils import get_options
|
||||
@@ -49,13 +49,13 @@ class TaskSystemTags(object):
|
||||
development = "development"
|
||||
|
||||
|
||||
class Script(EmbeddedDocument):
|
||||
class Script(EmbeddedDocument, ProperDictMixin):
|
||||
binary = StringField(default="python")
|
||||
repository = StringField(required=True)
|
||||
repository = StringField(default="")
|
||||
tag = StringField()
|
||||
branch = StringField()
|
||||
version_num = StringField()
|
||||
entry_point = StringField(required=True)
|
||||
entry_point = StringField(default="")
|
||||
working_dir = StringField()
|
||||
requirements = SafeDictField()
|
||||
diff = StringField()
|
||||
@@ -84,7 +84,23 @@ class Artifact(EmbeddedDocument):
|
||||
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
|
||||
|
||||
|
||||
class ParamsItem(EmbeddedDocument, ProperDictMixin):
|
||||
section = StringField(required=True)
|
||||
name = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
|
||||
name = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class Execution(EmbeddedDocument, ProperDictMixin):
|
||||
meta = {"strict": strict}
|
||||
test_split = IntField(default=0)
|
||||
parameters = SafeDictField(default=dict)
|
||||
model = StringField(reference_field="Model")
|
||||
@@ -100,9 +116,29 @@ class Execution(EmbeddedDocument, ProperDictMixin):
|
||||
class TaskType(object):
|
||||
training = "training"
|
||||
testing = "testing"
|
||||
inference = "inference"
|
||||
data_processing = "data_processing"
|
||||
application = "application"
|
||||
monitor = "monitor"
|
||||
controller = "controller"
|
||||
optimizer = "optimizer"
|
||||
service = "service"
|
||||
qc = "qc"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
external_task_types = set(get_options(TaskType))
|
||||
|
||||
|
||||
class Task(AttributedDocument):
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {
|
||||
"execution.parameters.": _numeric_locale,
|
||||
"last_metrics.": _numeric_locale,
|
||||
"hyperparams.": _numeric_locale,
|
||||
"configuration.": _numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
@@ -113,6 +149,7 @@ class Task(AttributedDocument):
|
||||
"parent",
|
||||
"project",
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
("company", "type", "system_tags", "status"),
|
||||
("company", "project", "type", "system_tags", "status"),
|
||||
("status", "last_update"), # for maintenance tasks
|
||||
@@ -140,6 +177,12 @@ class Task(AttributedDocument):
|
||||
},
|
||||
],
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
|
||||
datetime_fields=("status_changed",),
|
||||
pattern_fields=("name", "comment"),
|
||||
fields=("parent",),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(
|
||||
@@ -158,14 +201,19 @@ class Task(AttributedDocument):
|
||||
published = DateTimeField()
|
||||
parent = StringField()
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
output = EmbeddedDocumentField(Output, default=Output)
|
||||
output: Output = EmbeddedDocumentField(Output, default=Output)
|
||||
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
|
||||
tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
script = EmbeddedDocumentField(Script)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
script: Script = EmbeddedDocumentField(Script, default=Script)
|
||||
last_worker = StringField()
|
||||
last_worker_report = DateTimeField()
|
||||
last_update = DateTimeField()
|
||||
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
duration = IntField() # task duration in seconds
|
||||
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
|
||||
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
|
||||
runtime = SafeDictField(default=dict)
|
||||
|
||||
@@ -2,14 +2,16 @@ from mongoengine import Document, StringField, DynamicField
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import GetMixin
|
||||
from database.model.company import Company
|
||||
|
||||
|
||||
class User(DbModelMixin, Document):
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(list_fields=("id",))
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
|
||||
@@ -45,7 +45,7 @@ def project_dict(data, projection, separator=SEP):
|
||||
)
|
||||
|
||||
dst[path_part] = [
|
||||
copy_path(path_parts[depth + 1:], s, d)
|
||||
copy_path(path_parts[depth + 1 :], s, d)
|
||||
for s, d in zip(src_part, dst[path_part])
|
||||
]
|
||||
|
||||
@@ -96,6 +96,7 @@ class _ProxyManager:
|
||||
|
||||
class ProjectionHelper(object):
|
||||
pool = ThreadPoolExecutor()
|
||||
exclusion_prefix = "-"
|
||||
|
||||
@property
|
||||
def doc_projection(self):
|
||||
@@ -128,20 +129,28 @@ class ProjectionHelper(object):
|
||||
[]
|
||||
) # Projection information for reference fields (used in join queries)
|
||||
for field in projection:
|
||||
field_ = field.lstrip(self.exclusion_prefix)
|
||||
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
|
||||
if not field.startswith(ref_field):
|
||||
if not field_.startswith(ref_field):
|
||||
# Doesn't start with a reference field
|
||||
continue
|
||||
if field == ref_field:
|
||||
if field_ == ref_field:
|
||||
# Field is exactly a reference field. In this case we won't perform any inner projection (for that,
|
||||
# use '<reference field name>.*')
|
||||
continue
|
||||
subfield = field[len(ref_field):]
|
||||
subfield = field_[len(ref_field) :]
|
||||
if not subfield.startswith(SEP):
|
||||
# Starts with something that looks like a reference field, but isn't
|
||||
continue
|
||||
|
||||
ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
|
||||
ref_projection_info.append(
|
||||
(
|
||||
ref_field,
|
||||
ref_field_cls,
|
||||
("" if field_[0] == field[0] else self.exclusion_prefix)
|
||||
+ subfield[1:],
|
||||
)
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Not a reference field, just add to the top-level projection
|
||||
@@ -149,7 +158,7 @@ class ProjectionHelper(object):
|
||||
orig_field = field
|
||||
if field.endswith(".*"):
|
||||
field = field[:-2]
|
||||
if not field:
|
||||
if not field.lstrip(self.exclusion_prefix):
|
||||
raise errors.bad_request.InvalidFields(
|
||||
field=orig_field, object=doc_cls.__name__
|
||||
)
|
||||
@@ -199,7 +208,7 @@ class ProjectionHelper(object):
|
||||
# Make sure this doesn't contain any reference field we'll join anyway
|
||||
# (i.e. in case only_fields=[project, project.name])
|
||||
doc_projection = normalize_cls_projection(
|
||||
doc_cls, doc_projection.difference(ref_projection).union({"id"})
|
||||
doc_cls, doc_projection.difference(ref_projection)
|
||||
)
|
||||
|
||||
# Make sure that in case one or more field is a subfield of another field, we only use the the top-level field.
|
||||
@@ -218,7 +227,10 @@ class ProjectionHelper(object):
|
||||
|
||||
# Make sure we didn't get any invalid projection fields for this class
|
||||
invalid_fields = [
|
||||
f for f in doc_projection if f.split(SEP)[0] not in doc_cls.get_fields()
|
||||
f
|
||||
for f in doc_projection
|
||||
if f.partition(SEP)[0].lstrip(self.exclusion_prefix)
|
||||
not in doc_cls.get_fields()
|
||||
]
|
||||
if invalid_fields:
|
||||
raise errors.bad_request.InvalidFields(
|
||||
@@ -234,6 +246,13 @@ class ProjectionHelper(object):
|
||||
doc_projection.add(field)
|
||||
doc_projection = list(doc_projection)
|
||||
|
||||
# If there are include fields (not only exclude) then add an id field
|
||||
if (
|
||||
not all(p.startswith(self.exclusion_prefix) for p in doc_projection)
|
||||
and "id" not in doc_projection
|
||||
):
|
||||
doc_projection.append("id")
|
||||
|
||||
self._doc_projection = doc_projection
|
||||
self._ref_projection = ref_projection
|
||||
|
||||
@@ -314,6 +333,7 @@ class ProjectionHelper(object):
|
||||
]
|
||||
|
||||
if items:
|
||||
|
||||
def do_projection(item):
|
||||
ref_field_name, data, ids = item
|
||||
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
import copy
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
from mongoengine import Q
|
||||
from mongoengine.queryset.visitor import QueryCompilerVisitor, SimplificationVisitor, QCombination
|
||||
from mongoengine.queryset.visitor import (
|
||||
QueryCompilerVisitor,
|
||||
SimplificationVisitor,
|
||||
QCombination,
|
||||
QNode,
|
||||
)
|
||||
|
||||
|
||||
class RegexWrapper(object):
|
||||
@@ -17,17 +23,16 @@ class RegexWrapper(object):
|
||||
|
||||
|
||||
class RegexMixin(object):
|
||||
|
||||
def to_query(self, document):
|
||||
def to_query(self: Union["RegexMixin", QNode], document):
|
||||
query = self.accept(SimplificationVisitor())
|
||||
query = query.accept(RegexQueryCompilerVisitor(document))
|
||||
return query
|
||||
|
||||
def _combine(self, other, operation):
|
||||
def _combine(self: Union["RegexMixin", QNode], other, operation):
|
||||
"""Combine this node with another node into a QCombination
|
||||
object.
|
||||
"""
|
||||
if getattr(other, 'empty', True):
|
||||
if getattr(other, "empty", True):
|
||||
return self
|
||||
|
||||
if self.empty:
|
||||
|
||||
@@ -95,26 +95,18 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
|
||||
res[field] = None
|
||||
continue
|
||||
if desc:
|
||||
if callable(desc):
|
||||
if issubclass(desc, Document):
|
||||
if not desc.objects(id=value).only("id"):
|
||||
raise ParseCallError(
|
||||
"expecting %s id" % desc.__name__, id=value, field=field
|
||||
)
|
||||
elif callable(desc):
|
||||
try:
|
||||
desc(value)
|
||||
except TypeError:
|
||||
raise ParseCallError(f"expecting {desc.__name__}", field=field)
|
||||
except Exception as ex:
|
||||
raise ParseCallError(str(ex), field=field)
|
||||
else:
|
||||
if issubclass(desc, (list, tuple, dict)) and not isinstance(
|
||||
value, desc
|
||||
):
|
||||
raise ParseCallError(
|
||||
"expecting %s" % desc.__name__, field=field
|
||||
)
|
||||
if issubclass(desc, Document) and not desc.objects(id=value).only(
|
||||
"id"
|
||||
):
|
||||
raise ParseCallError(
|
||||
"expecting %s id" % desc.__name__, id=value, field=field
|
||||
)
|
||||
res[field] = value
|
||||
return res
|
||||
|
||||
|
||||
@@ -4,53 +4,54 @@ Apply elasticsearch mappings to given hosts.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
HERE = Path(__file__).resolve().parent
|
||||
|
||||
session = requests.Session()
|
||||
adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5))
|
||||
session.mount('http://', adapter)
|
||||
|
||||
def apply_mappings_to_cluster(
|
||||
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
|
||||
):
|
||||
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
|
||||
|
||||
def apply_mappings_to_host(host: str):
|
||||
def _send_mapping(f):
|
||||
def _send_template(f):
|
||||
with f.open() as json_data:
|
||||
data = json.load(json_data)
|
||||
es_server = host
|
||||
url = f"{es_server}/_template/{f.stem}"
|
||||
|
||||
session.delete(url)
|
||||
r = session.post(
|
||||
url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=json.dumps(data),
|
||||
)
|
||||
return {"mapping": f.stem, "result": r.text}
|
||||
template_name = f.stem
|
||||
res = es.indices.put_template(template_name, body=data)
|
||||
return {"mapping": template_name, "result": res}
|
||||
|
||||
p = HERE / "mappings"
|
||||
return [
|
||||
_send_mapping(f) for f in p.iterdir() if f.is_file() and f.suffix == ".json"
|
||||
]
|
||||
if key:
|
||||
files = (p / key).glob("*.json")
|
||||
else:
|
||||
files = p.glob("**/*.json")
|
||||
|
||||
es = Elasticsearch(hosts=hosts, **(es_args or {}))
|
||||
return [_send_template(f) for f in files]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
parser.add_argument("hosts", nargs="+")
|
||||
parser.add_argument("--key", help="host key, e.g. events, datasets etc.")
|
||||
parser.add_argument(
|
||||
"--hosts",
|
||||
nargs="+",
|
||||
help="list of es hosts from the same cluster, where each host is http[s]://[user:password@]host:port",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
for host in parse_args().hosts:
|
||||
print(">>>>> Applying mapping to " + host)
|
||||
res = apply_mappings_to_host(host)
|
||||
print(res)
|
||||
args = parse_args()
|
||||
print(">>>>> Applying mapping to " + str(args.hosts))
|
||||
res = apply_mappings_to_cluster(args.hosts, args.key)
|
||||
print(res)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
from furl import furl
|
||||
import logging
|
||||
from time import sleep
|
||||
from typing import Type, Optional, Sequence, Any, Union
|
||||
|
||||
import urllib3.exceptions
|
||||
from elasticsearch import Elasticsearch, exceptions
|
||||
|
||||
import es_factory
|
||||
from config import config
|
||||
from elastic.apply_mappings import apply_mappings_to_host
|
||||
from es_factory import get_cluster_config
|
||||
from elastic.apply_mappings import apply_mappings_to_cluster
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -15,13 +20,94 @@ class MissingElasticConfiguration(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def init_es_data():
|
||||
hosts_config = get_cluster_config("events").get("hosts")
|
||||
if not hosts_config:
|
||||
raise MissingElasticConfiguration("for cluster 'events'")
|
||||
class ElasticConnectionError(Exception):
|
||||
"""
|
||||
Exception when could not connect to elastic during init
|
||||
"""
|
||||
|
||||
for conf in hosts_config:
|
||||
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
|
||||
log.info(f"Applying mappings to host: {host}")
|
||||
res = apply_mappings_to_host(host)
|
||||
pass
|
||||
|
||||
|
||||
class ConnectionErrorFilter(logging.Filter):
|
||||
def __init__(
|
||||
self,
|
||||
level: Optional[Union[int, str]] = None,
|
||||
err_type: Optional[Type] = None,
|
||||
args_prefix: Optional[Sequence[Any]] = None,
|
||||
):
|
||||
super(ConnectionErrorFilter, self).__init__()
|
||||
if level is None:
|
||||
self.level = None
|
||||
else:
|
||||
try:
|
||||
self.level = int(level)
|
||||
except ValueError:
|
||||
self.level = logging.getLevelName(level)
|
||||
|
||||
self.err_type = err_type
|
||||
self.args = args_prefix and tuple(args_prefix)
|
||||
self.last_blocked = None
|
||||
|
||||
def filter(self, record):
|
||||
try:
|
||||
allow = (
|
||||
(self.err_type is None or record.exc_info[0] != self.err_type)
|
||||
and (self.level is None or record.levelno != self.level)
|
||||
and (self.args is None or record.args[: len(self.args)] != self.args)
|
||||
)
|
||||
if not allow:
|
||||
self.last_blocked = record
|
||||
return allow
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
def check_elastic_empty() -> bool:
|
||||
"""
|
||||
Check for elasticsearch connection
|
||||
Use probing settings and not the default es cluster ones
|
||||
so that we can handle correctly the connection rejects due to ES not fully started yet
|
||||
:return:
|
||||
"""
|
||||
cluster_conf = es_factory.get_cluster_config("events")
|
||||
max_retries = config.get("apiserver.elastic.probing.max_retries", 4)
|
||||
timeout = config.get("apiserver.elastic.probing.timeout", 30)
|
||||
|
||||
es_logger = logging.getLogger("elasticsearch")
|
||||
log_filter = ConnectionErrorFilter(
|
||||
err_type=urllib3.exceptions.NewConnectionError, args_prefix=("GET",)
|
||||
)
|
||||
|
||||
try:
|
||||
es_logger.addFilter(log_filter)
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
|
||||
return not es.indices.get_template(name="events*")
|
||||
except exceptions.NotFoundError as ex:
|
||||
log.error(ex)
|
||||
return True
|
||||
except exceptions.ConnectionError as ex:
|
||||
if retry >= max_retries - 1:
|
||||
raise ElasticConnectionError(
|
||||
f"Error connecting to Elasticsearch: {str(ex)}"
|
||||
)
|
||||
log.warn(
|
||||
f"Could not connect to ElasticSearch Service. Retry {retry+1} of {max_retries}. Waiting for {timeout}sec"
|
||||
)
|
||||
sleep(timeout)
|
||||
finally:
|
||||
es_logger.removeFilter(log_filter)
|
||||
|
||||
|
||||
def init_es_data():
|
||||
for name in es_factory.get_all_cluster_names():
|
||||
cluster_conf = es_factory.get_cluster_config(name)
|
||||
hosts_config = cluster_conf.get("hosts")
|
||||
if not hosts_config:
|
||||
raise MissingElasticConfiguration(f"for cluster '{name}'")
|
||||
|
||||
log.info(f"Applying mappings to ES host: {hosts_config}")
|
||||
args = cluster_conf.get("args", {})
|
||||
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
|
||||
log.info(res)
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
{
|
||||
"template": "events-*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"_routing": {
|
||||
"required": true
|
||||
},
|
||||
"properties": {
|
||||
"@timestamp": { "type": "date" },
|
||||
"task": { "type": "keyword" },
|
||||
"type": { "type": "keyword" },
|
||||
"worker": { "type": "keyword" },
|
||||
"timestamp": { "type": "date" },
|
||||
"iter": { "type": "long" },
|
||||
"metric": { "type": "keyword" },
|
||||
"variant": { "type": "keyword" },
|
||||
"value": { "type": "float" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
40
server/elastic/mappings/events/events.json
Normal file
40
server/elastic/mappings/events/events.json
Normal file
@@ -0,0 +1,40 @@
|
||||
{
|
||||
"index_patterns": "events-*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"@timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"task": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"type": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"worker": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"iter": {
|
||||
"type": "long"
|
||||
},
|
||||
"metric": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"variant": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
15
server/elastic/mappings/events/events_log.json
Normal file
15
server/elastic/mappings/events/events_log.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"index_patterns": "events-log-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"msg": {
|
||||
"type": "text",
|
||||
"index": false
|
||||
},
|
||||
"level": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
12
server/elastic/mappings/events/events_plot.json
Normal file
12
server/elastic/mappings/events/events_plot.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"index_patterns": "events-plot-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"plot_str": {
|
||||
"type": "text",
|
||||
"index": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"index_patterns": "events-training_debug_image-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"url": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
{
|
||||
"template": "events-log-*",
|
||||
"order" : 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"msg": { "type":"text", "index": false },
|
||||
"level": { "type":"keyword" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
{
|
||||
"template": "events-plot-*",
|
||||
"order" : 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"plot_str": { "type":"text", "index": false }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
{
|
||||
"template": "events-training_debug_image-*",
|
||||
"order" : 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"key": { "type": "keyword" },
|
||||
"url": { "type": "keyword" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
{
|
||||
"template": "queue_metrics_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"metrics": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"queue": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"average_waiting_time": {
|
||||
"type": "float"
|
||||
},
|
||||
"queue_length": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
{
|
||||
"template": "worker_stats_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"stat": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": { "type": "date" },
|
||||
"worker": { "type": "keyword" },
|
||||
"category": { "type": "keyword" },
|
||||
"metric": { "type": "keyword" },
|
||||
"variant": { "type": "keyword" },
|
||||
"value": { "type": "float" },
|
||||
"unit": { "type": "keyword" },
|
||||
"task": { "type": "keyword" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
25
server/elastic/mappings/workers/queue_metrics.json
Normal file
25
server/elastic/mappings/workers/queue_metrics.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"index_patterns": "queue_metrics_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"queue": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"average_waiting_time": {
|
||||
"type": "float"
|
||||
},
|
||||
"queue_length": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
37
server/elastic/mappings/workers/worker_stats.json
Normal file
37
server/elastic/mappings/workers/worker_stats.json
Normal file
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"index_patterns": "worker_stats_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"worker": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"category": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"metric": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"variant": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
},
|
||||
"unit": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"task": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -65,6 +65,10 @@ def connect(cluster_name):
|
||||
return _instances[cluster_name]
|
||||
|
||||
|
||||
def get_all_cluster_names():
|
||||
return list(config.get("hosts.elastic"))
|
||||
|
||||
|
||||
def get_cluster_config(cluster_name):
|
||||
"""
|
||||
Returns cluster config for the specified cluster path
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from pathlib import Path
|
||||
from typing import Sequence, Union
|
||||
|
||||
from config import config
|
||||
from config.info import get_default_company
|
||||
from database.model.auth import Role
|
||||
from service_repo.auth.fixed_user import FixedUser
|
||||
from .migration import _apply_migrations
|
||||
from .migration import _apply_migrations, check_mongo_empty, get_last_server_version
|
||||
from .pre_populate import PrePopulate
|
||||
from .user import ensure_fixed_user, _ensure_auth_user, _ensure_backend_user
|
||||
from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
|
||||
@@ -11,59 +13,76 @@ from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
|
||||
log = config.logger(__package__)
|
||||
|
||||
|
||||
def _pre_populate(company_id: str, zip_file: str):
|
||||
if not zip_file or not Path(zip_file).is_file():
|
||||
msg = f"Invalid pre-populate zip file: {zip_file}"
|
||||
if config.get("apiserver.pre_populate.fail_on_error", False):
|
||||
log.error(msg)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
log.warning(msg)
|
||||
else:
|
||||
log.info(f"Pre-populating using {zip_file}")
|
||||
|
||||
PrePopulate.import_from_zip(
|
||||
zip_file,
|
||||
artifacts_path=config.get("apiserver.pre_populate.artifacts_path", None),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_zip_files(zip_files: Union[Sequence[str], str]) -> Sequence[str]:
|
||||
if isinstance(zip_files, str):
|
||||
zip_files = [zip_files]
|
||||
for p in map(Path, zip_files):
|
||||
if p.is_file():
|
||||
yield p
|
||||
if p.is_dir():
|
||||
yield from p.glob("*.zip")
|
||||
log.warning(f"Invalid pre-populate entry {str(p)}, skipping")
|
||||
|
||||
|
||||
def pre_populate_data():
|
||||
for zip_file in _resolve_zip_files(config.get("apiserver.pre_populate.zip_files")):
|
||||
_pre_populate(company_id=get_default_company(), zip_file=zip_file)
|
||||
|
||||
PrePopulate.update_featured_projects_order()
|
||||
|
||||
|
||||
def init_mongo_data():
|
||||
try:
|
||||
empty_dbs = _apply_migrations(log)
|
||||
_apply_migrations(log)
|
||||
|
||||
_ensure_uuid()
|
||||
|
||||
company_id = _ensure_company(log)
|
||||
company_id = _ensure_company(get_default_company(), "trains", log)
|
||||
|
||||
_ensure_default_queue(company_id)
|
||||
|
||||
if empty_dbs and config.get("apiserver.mongo.pre_populate.enabled", False):
|
||||
zip_file = config.get("apiserver.mongo.pre_populate.zip_file")
|
||||
if not zip_file or not Path(zip_file).is_file():
|
||||
msg = f"Failed pre-populating database: invalid zip file {zip_file}"
|
||||
if config.get("apiserver.mongo.pre_populate.fail_on_error", False):
|
||||
log.error(msg)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
log.warning(msg)
|
||||
else:
|
||||
fixed_mode = FixedUser.enabled()
|
||||
|
||||
user_id = _ensure_backend_user(
|
||||
"__allegroai__", company_id, "Allegro.ai"
|
||||
)
|
||||
for user, credentials in config.get("secure.credentials", {}).items():
|
||||
user_data = {
|
||||
"name": user,
|
||||
"role": credentials.role,
|
||||
"email": f"{user}@example.com",
|
||||
"key": credentials.user_key,
|
||||
"secret": credentials.user_secret,
|
||||
}
|
||||
revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False)
|
||||
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke)
|
||||
if credentials.role == Role.user:
|
||||
_ensure_backend_user(user_id, company_id, credentials.display_name)
|
||||
|
||||
PrePopulate.import_from_zip(zip_file, user_id=user_id)
|
||||
|
||||
users = [
|
||||
{
|
||||
"name": "apiserver",
|
||||
"role": Role.system,
|
||||
"email": "apiserver@example.com",
|
||||
},
|
||||
{
|
||||
"name": "webserver",
|
||||
"role": Role.system,
|
||||
"email": "webserver@example.com",
|
||||
},
|
||||
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
|
||||
]
|
||||
|
||||
for user in users:
|
||||
credentials = config.get(f"secure.credentials.{user['name']}")
|
||||
user["key"] = credentials.user_key
|
||||
user["secret"] = credentials.user_secret
|
||||
_ensure_auth_user(user, company_id, log=log)
|
||||
|
||||
if FixedUser.enabled():
|
||||
if fixed_mode:
|
||||
log.info("Fixed users mode is enabled")
|
||||
FixedUser.validate()
|
||||
|
||||
if FixedUser.guest_enabled():
|
||||
_ensure_company(FixedUser.get_guest_user().company, "guests", log)
|
||||
|
||||
for user in FixedUser.from_config():
|
||||
try:
|
||||
ensure_fixed_user(user, company_id, log=log)
|
||||
ensure_fixed_user(user, log=log)
|
||||
except Exception as ex:
|
||||
log.error(f"Failed creating fixed user {user.name}: {ex}")
|
||||
except Exception as ex:
|
||||
|
||||
@@ -13,7 +13,26 @@ from database.model.version import Version as DatabaseVersion
|
||||
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
|
||||
|
||||
|
||||
def _apply_migrations(log: Logger) -> bool:
|
||||
def check_mongo_empty() -> bool:
|
||||
return not all(
|
||||
get_db(alias).collection_names()
|
||||
for alias in database.utils.get_options(Database)
|
||||
)
|
||||
|
||||
|
||||
def get_last_server_version() -> Version:
|
||||
try:
|
||||
previous_versions = sorted(
|
||||
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
|
||||
reverse=True,
|
||||
)
|
||||
except ValueError as ex:
|
||||
raise ValueError(f"Invalid database version number encountered: {ex}")
|
||||
|
||||
return previous_versions[0] if previous_versions else Version("0.0.0")
|
||||
|
||||
|
||||
def _apply_migrations(log: Logger):
|
||||
"""
|
||||
Apply migrations as found in the migration dir.
|
||||
Returns a boolean indicating whether the database was empty prior to migration.
|
||||
@@ -25,20 +44,8 @@ def _apply_migrations(log: Logger) -> bool:
|
||||
if not migration_dir.is_dir():
|
||||
raise ValueError(f"Invalid migration dir {migration_dir}")
|
||||
|
||||
empty_dbs = not any(
|
||||
get_db(alias).collection_names()
|
||||
for alias in database.utils.get_options(Database)
|
||||
)
|
||||
|
||||
try:
|
||||
previous_versions = sorted(
|
||||
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
|
||||
reverse=True,
|
||||
)
|
||||
except ValueError as ex:
|
||||
raise ValueError(f"Invalid database version number encountered: {ex}")
|
||||
|
||||
last_version = previous_versions[0] if previous_versions else Version("0.0.0")
|
||||
empty_dbs = check_mongo_empty()
|
||||
last_version = get_last_server_version()
|
||||
|
||||
try:
|
||||
new_scripts = {
|
||||
@@ -82,5 +89,3 @@ def _apply_migrations(log: Logger) -> bool:
|
||||
).save()
|
||||
|
||||
log.info("Finished mongodb migrations")
|
||||
|
||||
return empty_dbs
|
||||
|
||||
@@ -1,31 +1,390 @@
|
||||
import hashlib
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from os.path import splitext
|
||||
from typing import List, Optional, Any, Type, Set, Dict
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Optional,
|
||||
Any,
|
||||
Type,
|
||||
Set,
|
||||
Dict,
|
||||
Sequence,
|
||||
Tuple,
|
||||
BinaryIO,
|
||||
Union,
|
||||
Mapping,
|
||||
)
|
||||
from urllib.parse import unquote, urlparse
|
||||
from zipfile import ZipFile, ZIP_BZIP2
|
||||
|
||||
import dpath
|
||||
import mongoengine
|
||||
from tqdm import tqdm
|
||||
from boltons.iterutils import chunked_iter
|
||||
from furl import furl
|
||||
from mongoengine import Q
|
||||
|
||||
from bll.event import EventBLL
|
||||
from bll.task.param_utils import (
|
||||
split_param_name,
|
||||
hyperparams_default_section,
|
||||
hyperparams_legacy_type,
|
||||
)
|
||||
from config import config
|
||||
from config.info import get_default_company
|
||||
from database.model import EntityVisibility
|
||||
from database.model.model import Model
|
||||
from database.model.project import Project
|
||||
from database.model.task.task import Task, ArtifactModes, TaskStatus
|
||||
from database.utils import get_options
|
||||
from tools import safe_get
|
||||
from utilities import json
|
||||
from .user import _ensure_backend_user
|
||||
|
||||
|
||||
class PrePopulate:
|
||||
@classmethod
|
||||
def export_to_zip(
|
||||
cls, filename: str, experiments: List[str] = None, projects: List[str] = None
|
||||
):
|
||||
with ZipFile(filename, mode="w", compression=ZIP_BZIP2) as zfile:
|
||||
cls._export(zfile, experiments, projects)
|
||||
event_bll = EventBLL()
|
||||
events_file_suffix = "_events"
|
||||
export_tag_prefix = "Exported:"
|
||||
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
|
||||
metadata_filename = "metadata.json"
|
||||
zip_args = dict(mode="w", compression=ZIP_BZIP2)
|
||||
artifacts_ext = ".artifacts"
|
||||
img_source_regex = re.compile(
|
||||
r"['\"]source['\"]:\s?['\"](https?://(?:localhost:8081|files.*?)/.*?)['\"]",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
|
||||
class JsonLinesWriter:
|
||||
def __init__(self, file: BinaryIO):
|
||||
self.file = file
|
||||
self.empty = True
|
||||
|
||||
def __enter__(self):
|
||||
self._write("[")
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
self._write("\n]")
|
||||
|
||||
def _write(self, data: str):
|
||||
self.file.write(data.encode("utf-8"))
|
||||
|
||||
def write(self, line: str):
|
||||
if not self.empty:
|
||||
self._write(",")
|
||||
self._write("\n" + line)
|
||||
self.empty = False
|
||||
|
||||
@staticmethod
|
||||
def _get_last_update_time(entity) -> datetime:
|
||||
return getattr(entity, "last_update", None) or getattr(entity, "created")
|
||||
|
||||
@classmethod
|
||||
def import_from_zip(cls, filename: str, user_id: str = None):
|
||||
def _check_for_update(
|
||||
cls, map_file: Path, entities: dict, metadata_hash: str
|
||||
) -> Tuple[bool, Sequence[str]]:
|
||||
if not map_file.is_file():
|
||||
return True, []
|
||||
|
||||
files = []
|
||||
try:
|
||||
map_data = json.loads(map_file.read_text())
|
||||
files = map_data.get("files", [])
|
||||
for file in files:
|
||||
if not Path(file).is_file():
|
||||
return True, files
|
||||
|
||||
new_times = {
|
||||
item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc)
|
||||
for item in chain.from_iterable(entities.values())
|
||||
}
|
||||
old_times = map_data.get("entities", {})
|
||||
|
||||
if set(new_times.keys()) != set(old_times.keys()):
|
||||
return True, files
|
||||
|
||||
for id_, new_timestamp in new_times.items():
|
||||
if new_timestamp != old_times[id_]:
|
||||
return True, files
|
||||
|
||||
if metadata_hash != map_data.get("metadata_hash", ""):
|
||||
return True, files
|
||||
|
||||
except Exception as ex:
|
||||
print("Error reading map file. " + str(ex))
|
||||
return True, files
|
||||
|
||||
return False, files
|
||||
|
||||
@classmethod
|
||||
def _write_update_file(
|
||||
cls,
|
||||
map_file: Path,
|
||||
entities: dict,
|
||||
created_files: Sequence[str],
|
||||
metadata_hash: str,
|
||||
):
|
||||
map_file.write_text(
|
||||
json.dumps(
|
||||
dict(
|
||||
files=created_files,
|
||||
entities={
|
||||
entity.id: cls._get_last_update_time(entity)
|
||||
for entity in chain.from_iterable(entities.values())
|
||||
},
|
||||
metadata_hash=metadata_hash,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]:
|
||||
def is_fileserver_link(a: str) -> bool:
|
||||
a = a.lower()
|
||||
if a.startswith("https://files."):
|
||||
return True
|
||||
if a.startswith("http"):
|
||||
parsed = urlparse(a)
|
||||
if parsed.scheme in {"http", "https"} and parsed.netloc.endswith(
|
||||
"8081"
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
fileserver_links = [a for a in artifacts if is_fileserver_link(a)]
|
||||
print(
|
||||
f"Found {len(fileserver_links)} files on the fileserver from {len(artifacts)} total"
|
||||
)
|
||||
|
||||
return fileserver_links
|
||||
|
||||
@classmethod
|
||||
def export_to_zip(
|
||||
cls,
|
||||
filename: str,
|
||||
experiments: Sequence[str] = None,
|
||||
projects: Sequence[str] = None,
|
||||
artifacts_path: str = None,
|
||||
task_statuses: Sequence[str] = None,
|
||||
tag_exported_entities: bool = False,
|
||||
metadata: Mapping[str, Any] = None,
|
||||
) -> Sequence[str]:
|
||||
if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
|
||||
raise ValueError("Invalid task statuses")
|
||||
|
||||
file = Path(filename)
|
||||
entities = cls._resolve_entities(
|
||||
experiments=experiments, projects=projects, task_statuses=task_statuses
|
||||
)
|
||||
|
||||
hash_ = hashlib.md5()
|
||||
if metadata:
|
||||
meta_str = json.dumps(metadata)
|
||||
hash_.update(meta_str.encode())
|
||||
metadata_hash = hash_.hexdigest()
|
||||
else:
|
||||
meta_str, metadata_hash = "", ""
|
||||
|
||||
map_file = file.with_suffix(".map")
|
||||
updated, old_files = cls._check_for_update(
|
||||
map_file, entities=entities, metadata_hash=metadata_hash
|
||||
)
|
||||
if not updated:
|
||||
print(f"There are no updates from the last export")
|
||||
return old_files
|
||||
|
||||
for old in old_files:
|
||||
old_path = Path(old)
|
||||
if old_path.is_file():
|
||||
old_path.unlink()
|
||||
|
||||
with ZipFile(file, **cls.zip_args) as zfile:
|
||||
if metadata:
|
||||
zfile.writestr(cls.metadata_filename, meta_str)
|
||||
artifacts = cls._export(
|
||||
zfile,
|
||||
entities=entities,
|
||||
hash_=hash_,
|
||||
tag_entities=tag_exported_entities,
|
||||
)
|
||||
|
||||
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
|
||||
file.replace(file_with_hash)
|
||||
created_files = [str(file_with_hash)]
|
||||
|
||||
artifacts = cls._filter_artifacts(artifacts)
|
||||
if artifacts and artifacts_path and os.path.isdir(artifacts_path):
|
||||
artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
|
||||
with ZipFile(artifacts_file, **cls.zip_args) as zfile:
|
||||
cls._export_artifacts(zfile, artifacts, artifacts_path)
|
||||
created_files.append(str(artifacts_file))
|
||||
|
||||
cls._write_update_file(
|
||||
map_file,
|
||||
entities=entities,
|
||||
created_files=created_files,
|
||||
metadata_hash=metadata_hash,
|
||||
)
|
||||
|
||||
return created_files
|
||||
|
||||
@classmethod
|
||||
def import_from_zip(
|
||||
cls,
|
||||
filename: str,
|
||||
artifacts_path: str,
|
||||
company_id: Optional[str] = None,
|
||||
user_id: str = "",
|
||||
user_name: str = "",
|
||||
):
|
||||
metadata = None
|
||||
|
||||
with ZipFile(filename) as zfile:
|
||||
cls._import(zfile, user_id)
|
||||
try:
|
||||
with zfile.open(cls.metadata_filename) as f:
|
||||
metadata = json.loads(f.read())
|
||||
|
||||
meta_public = metadata.get("public")
|
||||
if company_id is None and meta_public is not None:
|
||||
company_id = "" if meta_public else get_default_company()
|
||||
|
||||
if not user_id:
|
||||
meta_user_id = metadata.get("user_id", "")
|
||||
meta_user_name = metadata.get("user_name", "")
|
||||
user_id, user_name = meta_user_id, meta_user_name
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not user_id:
|
||||
user_id, user_name = "__allegroai__", "Allegro.ai"
|
||||
|
||||
# Make sure we won't end up with an invalid company ID
|
||||
if company_id is None:
|
||||
company_id = ""
|
||||
|
||||
# Always use a public user for pre-populated data
|
||||
user_id = _ensure_backend_user(
|
||||
user_id=user_id, user_name=user_name, company_id="",
|
||||
)
|
||||
|
||||
cls._import(zfile, company_id, user_id, metadata)
|
||||
|
||||
if artifacts_path and os.path.isdir(artifacts_path):
|
||||
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
|
||||
if artifacts_file.is_file():
|
||||
print(f"Unzipping artifacts into {artifacts_path}")
|
||||
with ZipFile(artifacts_file) as zfile:
|
||||
zfile.extractall(artifacts_path)
|
||||
|
||||
@classmethod
|
||||
def upgrade_zip(cls, filename) -> Sequence:
|
||||
hash_ = hashlib.md5()
|
||||
task_file = cls._get_base_filename(Task) + ".json"
|
||||
temp_file = Path("temp.zip")
|
||||
file = Path(filename)
|
||||
with ZipFile(file) as reader, ZipFile(temp_file, **cls.zip_args) as writer:
|
||||
for file_info in reader.filelist:
|
||||
if file_info.orig_filename == task_file:
|
||||
with reader.open(file_info) as f:
|
||||
content = cls._upgrade_tasks(f)
|
||||
else:
|
||||
content = reader.read(file_info)
|
||||
writer.writestr(file_info, content)
|
||||
hash_.update(content)
|
||||
|
||||
base_file_name, _, old_hash = file.stem.rpartition("_")
|
||||
new_hash = hash_.hexdigest()
|
||||
if old_hash == new_hash:
|
||||
print(f"The file {filename} was not updated")
|
||||
temp_file.unlink()
|
||||
return []
|
||||
|
||||
new_file = file.with_name(f"{base_file_name}_{new_hash}{file.suffix}")
|
||||
temp_file.replace(new_file)
|
||||
upadated = [str(new_file)]
|
||||
|
||||
artifacts_file = file.with_suffix(cls.artifacts_ext)
|
||||
if artifacts_file.is_file():
|
||||
new_artifacts = new_file.with_suffix(cls.artifacts_ext)
|
||||
artifacts_file.replace(new_artifacts)
|
||||
upadated.append(str(new_artifacts))
|
||||
|
||||
print(f"File {str(file)} replaced with {str(new_file)}")
|
||||
file.unlink()
|
||||
|
||||
return upadated
|
||||
|
||||
@staticmethod
|
||||
def _upgrade_task_data(task_data: dict):
|
||||
for old_param_field, new_param_field, default_section in (
|
||||
("execution/parameters", "hyperparams", hyperparams_default_section),
|
||||
("execution/model_desc", "configuration", None),
|
||||
):
|
||||
legacy = safe_get(task_data, old_param_field)
|
||||
if not legacy:
|
||||
continue
|
||||
for full_name, value in legacy.items():
|
||||
section, name = split_param_name(full_name, default_section)
|
||||
new_path = list(filter(None, (new_param_field, section, name)))
|
||||
if not safe_get(task_data, new_path):
|
||||
new_param = dict(
|
||||
name=name, type=hyperparams_legacy_type, value=str(value)
|
||||
)
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
dpath.new(task_data, new_path, new_param)
|
||||
dpath.delete(task_data, old_param_field)
|
||||
|
||||
@classmethod
|
||||
def _upgrade_tasks(cls, f: BinaryIO) -> bytes:
|
||||
"""
|
||||
Build content array that contains fixed tasks from the passed file
|
||||
For each task the old execution.parameters and model.design are
|
||||
converted to the new structure.
|
||||
The fix is done on Task objects (not the dictionary) so that
|
||||
the fields are serialized back in the same order as they were in the original file
|
||||
"""
|
||||
with BytesIO() as temp:
|
||||
with cls.JsonLinesWriter(temp) as w:
|
||||
for line in cls.json_lines(f):
|
||||
task_data = Task.from_json(line).to_proper_dict()
|
||||
cls._upgrade_task_data(task_data)
|
||||
new_task = Task(**task_data)
|
||||
w.write(new_task.to_json())
|
||||
return temp.getvalue()
|
||||
|
||||
@classmethod
|
||||
def update_featured_projects_order(cls):
|
||||
featured_order = config.get("services.projects.featured_order", [])
|
||||
if not featured_order:
|
||||
return
|
||||
|
||||
def get_index(p: Project):
|
||||
for index, entry in enumerate(featured_order):
|
||||
if (
|
||||
entry.get("id", None) == p.id
|
||||
or entry.get("name", None) == p.name
|
||||
or ("name_regex" in entry and re.match(entry["name_regex"], p.name))
|
||||
):
|
||||
return index
|
||||
return 999
|
||||
|
||||
for project in Project.get_many_public(projection=["id", "name"]):
|
||||
featured_index = get_index(project)
|
||||
Project.objects(id=project.id).update(featured=featured_index)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_type(
|
||||
cls: Type[mongoengine.Document], ids: Optional[List[str]]
|
||||
) -> List[Any]:
|
||||
cls: Type[mongoengine.Document], ids: Optional[Sequence[str]]
|
||||
) -> Sequence[Any]:
|
||||
ids = set(ids)
|
||||
items = list(cls.objects(id__in=list(ids)))
|
||||
resolved = {i.id for i in items}
|
||||
@@ -43,20 +402,24 @@ class PrePopulate:
|
||||
|
||||
@classmethod
|
||||
def _resolve_entities(
|
||||
cls, experiments: List[str] = None, projects: List[str] = None
|
||||
cls,
|
||||
experiments: Sequence[str] = None,
|
||||
projects: Sequence[str] = None,
|
||||
task_statuses: Sequence[str] = None,
|
||||
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
|
||||
from database.model.project import Project
|
||||
from database.model.task.task import Task
|
||||
|
||||
entities = defaultdict(set)
|
||||
|
||||
if projects:
|
||||
print("Reading projects...")
|
||||
entities[Project].update(cls._resolve_type(Project, projects))
|
||||
print("--> Reading project experiments...")
|
||||
objs = Task.objects(
|
||||
project__in=list(set(filter(None, (p.id for p in entities[Project]))))
|
||||
query = Q(
|
||||
project__in=list(set(filter(None, (p.id for p in entities[Project])))),
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
)
|
||||
if task_statuses:
|
||||
query &= Q(status__in=list(set(task_statuses)))
|
||||
objs = Task.objects(query)
|
||||
entities[Task].update(o for o in objs if o.id not in (experiments or []))
|
||||
|
||||
if experiments:
|
||||
@@ -69,85 +432,297 @@ class PrePopulate:
|
||||
project_ids = {p.id for p in entities[Project]}
|
||||
entities[Project].update(o for o in objs if o.id not in project_ids)
|
||||
|
||||
model_ids = {
|
||||
model_id
|
||||
for task in entities[Task]
|
||||
for model_id in (task.output.model, task.execution.model)
|
||||
if model_id
|
||||
}
|
||||
if model_ids:
|
||||
print("Reading models...")
|
||||
entities[Model] = set(Model.objects(id__in=list(model_ids)))
|
||||
|
||||
return entities
|
||||
|
||||
@classmethod
|
||||
def _cleanup_task(cls, task):
|
||||
from database.model.task.task import TaskStatus
|
||||
def _filter_out_export_tags(cls, tags: Sequence[str]) -> Sequence[str]:
|
||||
if not tags:
|
||||
return tags
|
||||
return [tag for tag in tags if not tag.startswith(cls.export_tag_prefix)]
|
||||
|
||||
task.completed = None
|
||||
task.started = None
|
||||
if task.execution:
|
||||
task.execution.model = None
|
||||
task.execution.model_desc = None
|
||||
task.execution.model_labels = None
|
||||
if task.output:
|
||||
task.output.model = None
|
||||
@classmethod
|
||||
def _cleanup_model(cls, model: Model):
|
||||
model.company = ""
|
||||
model.user = ""
|
||||
model.tags = cls._filter_out_export_tags(model.tags)
|
||||
|
||||
task.status = TaskStatus.created
|
||||
@classmethod
|
||||
def _cleanup_task(cls, task: Task):
|
||||
task.comment = "Auto generated by Allegro.ai"
|
||||
task.created = datetime.utcnow()
|
||||
task.last_iteration = 0
|
||||
task.last_update = task.created
|
||||
task.status_changed = task.created
|
||||
task.status_message = ""
|
||||
task.status_reason = ""
|
||||
task.user = ""
|
||||
task.company = ""
|
||||
task.tags = cls._filter_out_export_tags(task.tags)
|
||||
if task.output:
|
||||
task.output.destination = None
|
||||
|
||||
@classmethod
|
||||
def _cleanup_project(cls, project: Project):
|
||||
project.user = ""
|
||||
project.company = ""
|
||||
project.tags = cls._filter_out_export_tags(project.tags)
|
||||
|
||||
@classmethod
|
||||
def _cleanup_entity(cls, entity_cls, entity):
|
||||
from database.model.task.task import Task
|
||||
if entity_cls == Task:
|
||||
cls._cleanup_task(entity)
|
||||
elif entity_cls == Model:
|
||||
cls._cleanup_model(entity)
|
||||
elif entity == Project:
|
||||
cls._cleanup_project(entity)
|
||||
|
||||
@classmethod
|
||||
def _add_tag(cls, items: Sequence[Union[Project, Task, Model]], tag: str):
|
||||
try:
|
||||
for item in items:
|
||||
item.update(upsert=False, tags=sorted(item.tags + [tag]))
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _export_task_events(
|
||||
cls, task: Task, base_filename: str, writer: ZipFile, hash_
|
||||
) -> Sequence[str]:
|
||||
artifacts = []
|
||||
filename = f"{base_filename}_{task.id}{cls.events_file_suffix}.json"
|
||||
print(f"Writing task events into {writer.filename}:{filename}")
|
||||
with BytesIO() as f:
|
||||
with cls.JsonLinesWriter(f) as w:
|
||||
scroll_id = None
|
||||
while True:
|
||||
res = cls.event_bll.get_task_events(
|
||||
task.company, task.id, scroll_id=scroll_id
|
||||
)
|
||||
if not res.events:
|
||||
break
|
||||
scroll_id = res.next_scroll_id
|
||||
for event in res.events:
|
||||
event_type = event.get("type")
|
||||
if event_type == "training_debug_image":
|
||||
url = cls._get_fixed_url(event.get("url"))
|
||||
if url:
|
||||
event["url"] = url
|
||||
artifacts.append(url)
|
||||
elif event_type == "plot":
|
||||
plot_str: str = event.get("plot_str", "")
|
||||
for match in cls.img_source_regex.findall(plot_str):
|
||||
url = cls._get_fixed_url(match)
|
||||
if match != url:
|
||||
plot_str = plot_str.replace(match, url)
|
||||
artifacts.append(url)
|
||||
w.write(json.dumps(event))
|
||||
data = f.getvalue()
|
||||
hash_.update(data)
|
||||
writer.writestr(filename, data)
|
||||
|
||||
return artifacts
|
||||
|
||||
@staticmethod
|
||||
def _get_fixed_url(url: Optional[str]) -> Optional[str]:
|
||||
if not (url and url.lower().startswith("s3://")):
|
||||
return url
|
||||
try:
|
||||
fixed = furl(url)
|
||||
fixed.scheme = "https"
|
||||
fixed.host += ".s3.amazonaws.com"
|
||||
return fixed.url
|
||||
except Exception as ex:
|
||||
print(f"Failed processing link {url}. " + str(ex))
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def _export_entity_related_data(
|
||||
cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
|
||||
):
|
||||
if entity_cls == Task:
|
||||
return [
|
||||
*cls._get_task_output_artifacts(entity),
|
||||
*cls._export_task_events(entity, base_filename, writer, hash_),
|
||||
]
|
||||
|
||||
if entity_cls == Model:
|
||||
entity.uri = cls._get_fixed_url(entity.uri)
|
||||
return [entity.uri] if entity.uri else []
|
||||
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _get_task_output_artifacts(cls, task: Task) -> Sequence[str]:
|
||||
if not task.execution.artifacts:
|
||||
return []
|
||||
|
||||
for a in task.execution.artifacts:
|
||||
if a.mode == ArtifactModes.output:
|
||||
a.uri = cls._get_fixed_url(a.uri)
|
||||
|
||||
return [
|
||||
a.uri
|
||||
for a in task.execution.artifacts
|
||||
if a.mode == ArtifactModes.output and a.uri
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _export_artifacts(
|
||||
cls, writer: ZipFile, artifacts: Sequence[str], artifacts_path: str
|
||||
):
|
||||
unique_paths = set(unquote(str(furl(artifact).path)) for artifact in artifacts)
|
||||
print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
|
||||
for path in unique_paths:
|
||||
path = path.lstrip("/")
|
||||
full_path = os.path.join(artifacts_path, path)
|
||||
if os.path.isfile(full_path):
|
||||
writer.write(full_path, path)
|
||||
else:
|
||||
print(f"Artifact {full_path} not found")
|
||||
|
||||
@staticmethod
|
||||
def _get_base_filename(cls_: type):
|
||||
return f"{cls_.__module__}.{cls_.__name__}"
|
||||
|
||||
@classmethod
|
||||
def _export(
|
||||
cls, writer: ZipFile, experiments: List[str] = None, projects: List[str] = None
|
||||
):
|
||||
entities = cls._resolve_entities(experiments, projects)
|
||||
|
||||
for cls_, items in entities.items():
|
||||
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
|
||||
) -> Sequence[str]:
|
||||
"""
|
||||
Export the requested experiments, projects and models and return the list of artifact files
|
||||
Always do the export on sorted items since the order of items influence hash
|
||||
"""
|
||||
artifacts = []
|
||||
now = datetime.utcnow()
|
||||
for cls_ in sorted(entities, key=attrgetter("__name__")):
|
||||
items = sorted(entities[cls_], key=attrgetter("id"))
|
||||
if not items:
|
||||
continue
|
||||
filename = f"{cls_.__module__}.{cls_.__name__}.json"
|
||||
base_filename = cls._get_base_filename(cls_)
|
||||
for item in items:
|
||||
artifacts.extend(
|
||||
cls._export_entity_related_data(
|
||||
cls_, item, base_filename, writer, hash_
|
||||
)
|
||||
)
|
||||
filename = base_filename + ".json"
|
||||
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
|
||||
with writer.open(filename, "w") as f:
|
||||
f.write("[\n".encode("utf-8"))
|
||||
last = len(items) - 1
|
||||
for i, item in enumerate(items):
|
||||
cls._cleanup_entity(cls_, item)
|
||||
f.write(item.to_json().encode("utf-8"))
|
||||
if i != last:
|
||||
f.write(",".encode("utf-8"))
|
||||
f.write("\n".encode("utf-8"))
|
||||
f.write("]\n".encode("utf-8"))
|
||||
with BytesIO() as f:
|
||||
with cls.JsonLinesWriter(f) as w:
|
||||
for item in items:
|
||||
cls._cleanup_entity(cls_, item)
|
||||
w.write(item.to_json())
|
||||
data = f.getvalue()
|
||||
hash_.update(data)
|
||||
writer.writestr(filename, data)
|
||||
|
||||
if tag_entities:
|
||||
cls._add_tag(items, now.strftime(cls.export_tag))
|
||||
|
||||
return artifacts
|
||||
|
||||
@staticmethod
|
||||
def _import(reader: ZipFile, user_id: str = None):
|
||||
for file_info in reader.filelist:
|
||||
full_name = splitext(file_info.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
module_name, _, class_name = full_name.rpartition(".")
|
||||
module = importlib.import_module(module_name)
|
||||
cls_: Type[mongoengine.Document] = getattr(module, class_name)
|
||||
def json_lines(file: BinaryIO):
|
||||
for line in file:
|
||||
clean = (
|
||||
line.decode("utf-8")
|
||||
.rstrip("\r\n")
|
||||
.strip()
|
||||
.lstrip("[")
|
||||
.rstrip(",]")
|
||||
.strip()
|
||||
)
|
||||
if not clean:
|
||||
continue
|
||||
yield clean
|
||||
|
||||
with reader.open(file_info) as f:
|
||||
for item in tqdm(
|
||||
f.readlines(),
|
||||
desc=f"Writing {cls_.__name__.lower()}s into database",
|
||||
unit="doc",
|
||||
):
|
||||
item = (
|
||||
item.decode("utf-8")
|
||||
.strip()
|
||||
.lstrip("[")
|
||||
.rstrip("]")
|
||||
.rstrip(",")
|
||||
.strip()
|
||||
)
|
||||
if not item:
|
||||
continue
|
||||
doc = cls_.from_json(item)
|
||||
if user_id is not None and hasattr(doc, "user"):
|
||||
doc.user = user_id
|
||||
doc.save(force_insert=True)
|
||||
@classmethod
|
||||
def _import(
|
||||
cls,
|
||||
reader: ZipFile,
|
||||
company_id: str = "",
|
||||
user_id: str = None,
|
||||
metadata: Mapping[str, Any] = None,
|
||||
):
|
||||
"""
|
||||
Import entities and events from the zip file
|
||||
Start from entities since event import will require the tasks already in DB
|
||||
"""
|
||||
event_file_ending = cls.events_file_suffix + ".json"
|
||||
entity_files = (
|
||||
fi
|
||||
for fi in reader.filelist
|
||||
if not fi.orig_filename.endswith(event_file_ending)
|
||||
and fi.orig_filename != cls.metadata_filename
|
||||
)
|
||||
event_files = (
|
||||
fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending)
|
||||
)
|
||||
for files, reader_func in (
|
||||
(entity_files, partial(cls._import_entity, metadata=metadata or {})),
|
||||
(event_files, cls._import_events),
|
||||
):
|
||||
for file_info in files:
|
||||
with reader.open(file_info) as f:
|
||||
full_name = splitext(file_info.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
reader_func(f, full_name, company_id, user_id)
|
||||
|
||||
@classmethod
|
||||
def _import_entity(
|
||||
cls,
|
||||
f: BinaryIO,
|
||||
full_name: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
metadata: Mapping[str, Any],
|
||||
):
|
||||
module_name, _, class_name = full_name.rpartition(".")
|
||||
module = importlib.import_module(module_name)
|
||||
cls_: Type[mongoengine.Document] = getattr(module, class_name)
|
||||
print(f"Writing {cls_.__name__.lower()}s into database")
|
||||
|
||||
override_project_count = 0
|
||||
for item in cls.json_lines(f):
|
||||
doc = cls_.from_json(item, created=True)
|
||||
if hasattr(doc, "user"):
|
||||
doc.user = user_id
|
||||
if hasattr(doc, "company"):
|
||||
doc.company = company_id
|
||||
if isinstance(doc, Project):
|
||||
override_project_name = metadata.get("project_name", None)
|
||||
if override_project_name:
|
||||
if override_project_count:
|
||||
override_project_name = (
|
||||
f"{override_project_name} {override_project_count + 1}"
|
||||
)
|
||||
override_project_count += 1
|
||||
doc.name = override_project_name
|
||||
|
||||
doc.logo_url = metadata.get("logo_url", None)
|
||||
doc.logo_blob = metadata.get("logo_blob", None)
|
||||
|
||||
cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
|
||||
set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
|
||||
)
|
||||
|
||||
doc.save()
|
||||
|
||||
if isinstance(doc, Task):
|
||||
cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
|
||||
|
||||
@classmethod
|
||||
def _import_events(cls, f: BinaryIO, full_name: str, company_id: str, _):
|
||||
_, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
|
||||
print(f"Writing events for task {task_id} into database")
|
||||
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
|
||||
events = [json.loads(item) for item in events_chunk]
|
||||
cls.event_bll.add_events(
|
||||
company_id, events=events, worker="", allow_locked_tasks=True
|
||||
)
|
||||
|
||||
@@ -9,28 +9,34 @@ from database.model.user import User
|
||||
from service_repo.auth.fixed_user import FixedUser
|
||||
|
||||
|
||||
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger):
|
||||
ensure_credentials = {"key", "secret"}.issubset(user_data)
|
||||
if ensure_credentials:
|
||||
user = AuthUser.objects(
|
||||
credentials__match=Credentials(
|
||||
key=user_data["key"], secret=user_data["secret"]
|
||||
)
|
||||
).first()
|
||||
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: bool = False):
|
||||
key, secret = user_data.get("key"), user_data.get("secret")
|
||||
if not (key and secret):
|
||||
credentials = None
|
||||
else:
|
||||
creds = Credentials(key=key, secret=secret)
|
||||
|
||||
user = AuthUser.objects(credentials__match=creds).first()
|
||||
if user:
|
||||
if revoke:
|
||||
user.credentials = []
|
||||
user.save()
|
||||
return user.id
|
||||
|
||||
credentials = [] if revoke else [creds]
|
||||
|
||||
user_id = user_data.get("id", f"__{user_data['name']}__")
|
||||
|
||||
log.info(f"Creating user: {user_data['name']}")
|
||||
|
||||
user = AuthUser(
|
||||
id=user_data.get("id", f"__{user_data['name']}__"),
|
||||
id=user_id,
|
||||
name=user_data["name"],
|
||||
company=company_id,
|
||||
role=user_data["role"],
|
||||
email=user_data["email"],
|
||||
created=datetime.utcnow(),
|
||||
credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])]
|
||||
if ensure_credentials
|
||||
else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
user.save()
|
||||
@@ -52,23 +58,23 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
|
||||
return user_id
|
||||
|
||||
|
||||
def ensure_fixed_user(user: FixedUser, company_id: str, log: Logger):
|
||||
if User.objects(id=user.user_id).first():
|
||||
def ensure_fixed_user(user: FixedUser, log: Logger):
|
||||
db_user = User.objects(company=user.company, id=user.user_id).first()
|
||||
if db_user:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
log.info(f"Updating user name: {user.name}")
|
||||
given_name, _, family_name = user.name.partition(" ")
|
||||
db_user.update(name=user.name, given_name=given_name, family_name=family_name)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
data = attr.asdict(user)
|
||||
data["id"] = user.user_id
|
||||
data["email"] = f"{user.user_id}@example.com"
|
||||
data["role"] = Role.user
|
||||
data["role"] = Role.guest if user.is_guest else Role.user
|
||||
|
||||
_ensure_auth_user(user_data=data, company_id=company_id, log=log)
|
||||
_ensure_auth_user(user_data=data, company_id=user.company, log=log)
|
||||
|
||||
given_name, _, family_name = user.name.partition(" ")
|
||||
|
||||
User(
|
||||
id=user.user_id,
|
||||
company=company_id,
|
||||
name=user.name,
|
||||
given_name=given_name,
|
||||
family_name=family_name,
|
||||
).save()
|
||||
return _ensure_backend_user(user.user_id, user.company, user.name)
|
||||
|
||||
@@ -3,21 +3,18 @@ from uuid import uuid4
|
||||
|
||||
from bll.queue import QueueBLL
|
||||
from config import config
|
||||
from config.info import get_default_company
|
||||
from database.model.company import Company
|
||||
from database.model.queue import Queue
|
||||
from database.model.settings import Settings
|
||||
from database.model.settings import Settings, SettingKeys
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
def _ensure_company(log: Logger):
|
||||
company_id = get_default_company()
|
||||
def _ensure_company(company_id, company_name, log: Logger):
|
||||
company = Company.objects(id=company_id).only("id").first()
|
||||
if company:
|
||||
return company_id
|
||||
|
||||
company_name = "trains"
|
||||
log.info(f"Creating company: {company_name}")
|
||||
company = Company(id=company_id, name=company_name)
|
||||
company.save()
|
||||
@@ -37,4 +34,4 @@ def _ensure_default_queue(company):
|
||||
|
||||
|
||||
def _ensure_uuid():
|
||||
Settings.add_value("server.uuid", str(uuid4()))
|
||||
Settings.add_value(SettingKeys.server__uuid, str(uuid4()))
|
||||
|
||||
58
server/mongo/migrations/0.15.0.py
Normal file
58
server/mongo/migrations/0.15.0.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from collections import Collection
|
||||
from typing import Sequence
|
||||
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
|
||||
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
|
||||
for collection_name in db.list_collection_names():
|
||||
if collection_name not in names:
|
||||
continue
|
||||
collection: Collection = db[collection_name]
|
||||
collection.drop_indexes()
|
||||
|
||||
|
||||
def migrate_auth(db: Database):
|
||||
"""
|
||||
Remove the old indices from the collections since
|
||||
they may come out of sync with the latest changes
|
||||
in the code and mongo libraries update
|
||||
"""
|
||||
_drop_all_indices_from_collections(db, ["user"])
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
"""
|
||||
1. Sort tags and system tags
|
||||
2. Remove the old indices from the collections since
|
||||
they may come out of sync with the latest changes
|
||||
in the code and mongo libraries update
|
||||
"""
|
||||
|
||||
fields = ("tags", "system_tags")
|
||||
query = {"$or": [{field: {"$exists": True, "$ne": []}} for field in fields]}
|
||||
for collection_name in ("task", "model", "project", "queue"):
|
||||
collection = db[collection_name]
|
||||
for doc in collection.find(filter=query, projection=fields):
|
||||
update = {
|
||||
field: sorted(doc[field])
|
||||
for field in fields
|
||||
if doc.get(field)
|
||||
}
|
||||
if update:
|
||||
collection.update_one({"_id": doc["_id"]}, {"$set": update})
|
||||
|
||||
_drop_all_indices_from_collections(
|
||||
db,
|
||||
[
|
||||
"company",
|
||||
"model",
|
||||
"project",
|
||||
"queue",
|
||||
"settings",
|
||||
"task",
|
||||
"task__trash",
|
||||
"user",
|
||||
"versions",
|
||||
],
|
||||
)
|
||||
36
server/mongo/migrations/0.16.0.py
Normal file
36
server/mongo/migrations/0.16.0.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
from bll.task.param_utils import (
|
||||
hyperparams_legacy_type,
|
||||
hyperparams_default_section,
|
||||
split_param_name,
|
||||
)
|
||||
from tools import safe_get
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
hyperparam_fields = ("execution.parameters", "hyperparams")
|
||||
configuration_fields = ("execution.model_desc", "configuration")
|
||||
collection: Collection = db["task"]
|
||||
for doc in collection.find(projection=hyperparam_fields + configuration_fields):
|
||||
set_commands = {}
|
||||
for (old_field, new_field), default_section in zip(
|
||||
(hyperparam_fields, configuration_fields),
|
||||
(hyperparams_default_section, None),
|
||||
):
|
||||
legacy = safe_get(doc, old_field, separator=".")
|
||||
if not legacy:
|
||||
continue
|
||||
for full_name, value in legacy.items():
|
||||
section, name = split_param_name(full_name, default_section)
|
||||
new_path = list(filter(None, (new_field, section, name)))
|
||||
# if safe_get(doc, new_path) is not None:
|
||||
# continue
|
||||
new_value = dict(
|
||||
name=name, type=hyperparams_legacy_type, value=str(value)
|
||||
)
|
||||
if section is not None:
|
||||
new_value["section"] = section
|
||||
set_commands[".".join(new_path)] = new_value
|
||||
if set_commands:
|
||||
collection.update_one({"_id": doc["_id"]}, {"$set": set_commands})
|
||||
11
server/mongo/migrations/0.16.1.py
Normal file
11
server/mongo/migrations/0.16.1.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
collection: Collection = db["project"]
|
||||
featured = "featured"
|
||||
query = {featured: {"$exists": False}}
|
||||
for doc in collection.find(filter=query, projection=()):
|
||||
collection.update_one(
|
||||
{"_id": doc["_id"]}, {"$set": {featured: 9999}},
|
||||
)
|
||||
@@ -1,7 +1,8 @@
|
||||
attrs>=19.1.0
|
||||
boltons>=19.1.0
|
||||
boto3==1.14.13
|
||||
dpath>=1.4.2,<2.0
|
||||
elasticsearch>=5.0.0,<6.0.0
|
||||
elasticsearch>=7.0.0,<8.0.0
|
||||
fastjsonschema>=2.8
|
||||
Flask-Compress>=1.4.0
|
||||
Flask-Cors>=3.0.5
|
||||
@@ -14,17 +15,17 @@ Jinja2==2.10
|
||||
jsonmodels>=2.3
|
||||
jsonschema>=2.6.0
|
||||
luqum>=0.7.2
|
||||
mongoengine==0.16.2
|
||||
mongoengine==0.19.1
|
||||
nested_dict>=1.61
|
||||
psutil>=5.6.5
|
||||
pyhocon>=0.3.35
|
||||
pyjwt>=1.3.0
|
||||
pymongo==3.6.1 # 3.7 has a bug multiple users logged in
|
||||
pymongo==3.10.1
|
||||
python-rapidjson>=0.6.3
|
||||
redis>=2.10.5
|
||||
related>=0.7.2
|
||||
requests>=2.13.0
|
||||
semantic_version>=2.8.0,<3
|
||||
semantic_version>=2.8.3,<3
|
||||
six
|
||||
tqdm
|
||||
validators>=0.12.4
|
||||
@@ -328,6 +328,11 @@ fixed_users_mode {
|
||||
description: "Fixed users mode enabled"
|
||||
type: boolean
|
||||
}
|
||||
server_errors {
|
||||
description: "Server initialization errors"
|
||||
type: object
|
||||
additionalProperties: True
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
16
server/schema/services/debug.conf
Normal file
16
server/schema/services/debug.conf
Normal file
@@ -0,0 +1,16 @@
|
||||
_description: "debugging utilities"
|
||||
ping {
|
||||
authorize: false
|
||||
"2.9" {
|
||||
description: "Ping server"
|
||||
request {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties: {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -530,7 +530,7 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.7" {
|
||||
"2.9" {
|
||||
description: "Get 'log' events for this task"
|
||||
request {
|
||||
type: object
|
||||
@@ -548,15 +548,16 @@
|
||||
}
|
||||
navigate_earlier {
|
||||
type: boolean
|
||||
description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
|
||||
description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order, unless order='asc'). Otherwise from the earliest to the latest ones (in timestamp ascending order, unless order='desc'). The default is True"
|
||||
}
|
||||
refresh {
|
||||
type: boolean
|
||||
description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
|
||||
from_timestamp {
|
||||
type: number
|
||||
description: "Epoch time in UTC ms to use as the navigation start. Optional. If not provided, reference timestamp is determined by the 'navigate_earlier' parameter (if true, reference timestamp is the last timestamp and if false, reference timestamp is the first timestamp)"
|
||||
}
|
||||
scroll_id {
|
||||
order {
|
||||
type: string
|
||||
description: "Scroll ID of previous call (used for getting more results)"
|
||||
description: "If set, changes the order in which log events are returned based on the value of 'navigate_earlier'"
|
||||
enum: [asc, desc]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -576,10 +577,6 @@
|
||||
type: number
|
||||
description: "Total number of log events available for this query"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID for getting more results"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -856,7 +853,7 @@
|
||||
description: "Task ID"
|
||||
}
|
||||
samples {
|
||||
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
|
||||
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 6000."
|
||||
type: integer
|
||||
}
|
||||
key {
|
||||
@@ -894,7 +891,7 @@
|
||||
]
|
||||
properties {
|
||||
tasks {
|
||||
description: "List of task Task IDs"
|
||||
description: "List of task Task IDs. Maximum amount of tasks is 10"
|
||||
type: array
|
||||
items {
|
||||
type: string
|
||||
@@ -902,7 +899,7 @@
|
||||
}
|
||||
}
|
||||
samples {
|
||||
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
|
||||
description: "The amount of histogram points to return. Optional, the default value is 6000"
|
||||
type: integer
|
||||
}
|
||||
key {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
48
server/schema/services/organization.conf
Normal file
48
server/schema/services/organization.conf
Normal file
@@ -0,0 +1,48 @@
|
||||
_description: "This service provides organization level operations"
|
||||
|
||||
get_tags {
|
||||
"2.8" {
|
||||
description: "Get all the user and system tags used for the company tasks and models"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
include_system {
|
||||
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
filter {
|
||||
description: "Filter on entities to collect tags from"
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__Snot' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of unique tag values"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -196,6 +196,52 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
tags_request {
|
||||
type: object
|
||||
properties {
|
||||
include_system {
|
||||
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
projects {
|
||||
description: "The list of projects under which the tags are searched. If not passed or empty then all the projects are searched"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
filter {
|
||||
description: "Filter on entities to collect tags from"
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__Snot' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tags_response {
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of unique tag values"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
create {
|
||||
@@ -359,6 +405,11 @@ get_all_ex {
|
||||
enum: [ active, archived ]
|
||||
default: active
|
||||
}
|
||||
non_public {
|
||||
description: "Return only non-public projects"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -481,8 +532,8 @@ get_unique_metric_variants {
|
||||
}
|
||||
}
|
||||
get_hyper_parameters {
|
||||
"2.2" {
|
||||
description: """Get a list of all hyper parameter names used in tasks within the given project."""
|
||||
"2.9" {
|
||||
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
@@ -506,9 +557,9 @@ get_hyper_parameters {
|
||||
type: object
|
||||
properties {
|
||||
parameters {
|
||||
description: "A list of hyper parameter names"
|
||||
description: "A list of parameter sections and names"
|
||||
type: array
|
||||
items { type: string }
|
||||
items {type: object}
|
||||
}
|
||||
remaining {
|
||||
description: "Remaining results"
|
||||
@@ -522,3 +573,69 @@ get_hyper_parameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_task_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the tasks under the specified projects"
|
||||
request = ${_definitions.tags_request}
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
|
||||
get_model_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the models under the specified projects"
|
||||
request = ${_definitions.tags_request}
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
|
||||
make_public {
|
||||
"2.9" {
|
||||
description: """Convert company projects to public"""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
description: "Ids of the projects to convert"
|
||||
type: array
|
||||
items { type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of projects updated"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
make_private {
|
||||
"2.9" {
|
||||
description: """Convert public projects to private"""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
description: "Ids of the projects to convert. Only the projects originated by the company can be converted"
|
||||
type: array
|
||||
items { type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of projects updated"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -69,6 +69,17 @@ info {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.8": ${info."2.1"} {
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
uid {
|
||||
description: "Server UID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
endpoints {
|
||||
"2.1" {
|
||||
|
||||
@@ -254,6 +254,15 @@ _definitions {
|
||||
enum: [
|
||||
training
|
||||
testing
|
||||
inference
|
||||
data_processing
|
||||
application
|
||||
monitor
|
||||
controller
|
||||
optimizer
|
||||
service
|
||||
qc
|
||||
custom
|
||||
]
|
||||
}
|
||||
last_metrics_event {
|
||||
@@ -288,7 +297,80 @@ _definitions {
|
||||
"$ref": "#/definitions/last_metrics_event"
|
||||
}
|
||||
}
|
||||
|
||||
params_item {
|
||||
type: object
|
||||
properties {
|
||||
section {
|
||||
description: "Section that the parameter belongs to"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Name of the parameter. The combination of section and name should be unique"
|
||||
type: string
|
||||
}
|
||||
value {
|
||||
description: "Value of the parameter"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "Type of the parameter. Optional"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "The parameter description. Optional"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
configuration_item {
|
||||
type: object
|
||||
properties {
|
||||
name {
|
||||
description: "Name of the parameter. Should be unique"
|
||||
type: string
|
||||
}
|
||||
value {
|
||||
description: "Value of the parameter"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "Type of the parameter. Optional"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "The parameter description. Optional"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
param_key {
|
||||
type: object
|
||||
properties {
|
||||
section {
|
||||
description: "Section that the parameter belongs to"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Name of the parameter. If the name is ommitted then the corresponding operation is performed on the whole section"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
section_params {
|
||||
description: "Task section params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/params_item"
|
||||
}
|
||||
}
|
||||
replace_hyperparams_enum {
|
||||
type: string
|
||||
enum: [
|
||||
none,
|
||||
section,
|
||||
all
|
||||
]
|
||||
}
|
||||
task {
|
||||
type: object
|
||||
properties {
|
||||
@@ -409,9 +491,24 @@ _definitions {
|
||||
"$ref": "#/definitions/last_metrics_variants"
|
||||
}
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_by_id {
|
||||
"2.1" {
|
||||
description: "Gets task information"
|
||||
@@ -475,7 +572,11 @@ get_all {
|
||||
minimum: 1
|
||||
}
|
||||
order_by {
|
||||
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
|
||||
description: """List of field names to order by. When search_text is used,
|
||||
'@text_score' can be used as a field representing the text score of returned documents.
|
||||
Use '-' prefix to specify descending order. Optional, recommended when using page.
|
||||
If the first order field is a hyper parameter or metric then string values are ordered
|
||||
according to numeric ordering rules where applicable"""
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
@@ -550,6 +651,31 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
get_types {
|
||||
"2.8" {
|
||||
description: "Get the list of task types used in the specified projects"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
projects {
|
||||
description: "The list of projects which tasks will be analyzed. If not passed or empty then all the company and public tasks will be analyzed"
|
||||
type: array
|
||||
items: {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
types {
|
||||
description: "Unique list of the task types used in the requested projects"
|
||||
type: array
|
||||
items: {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
clone {
|
||||
"2.5" {
|
||||
description: "Clone an existing task"
|
||||
@@ -587,10 +713,28 @@ clone {
|
||||
description: "The project of the cloned task. If not provided then taken from the original task"
|
||||
type: string
|
||||
}
|
||||
new_task_hyperparams {
|
||||
description: "The hyper params for the new task. If not provided then taken from the original task"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
new_task_configuration {
|
||||
description: "The configuration for the new task. If not provided then taken from the original task"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
execution_overrides {
|
||||
description: "The execution params for the cloned task. The params not specified are taken from the original task"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
validate_references {
|
||||
description: "If set to 'false' then the task fields that are copied from the original task are not validated. The default is false."
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -656,6 +800,20 @@ create {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -717,6 +875,20 @@ validate {
|
||||
description: "Task execution params"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
script {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
@@ -867,6 +1039,20 @@ edit {
|
||||
description: "Task execution params"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
script {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
@@ -901,6 +1087,11 @@ reset {
|
||||
properties.force = ${_references.force_arg} {
|
||||
description: "If not true, call fails if the task status is 'completed'"
|
||||
}
|
||||
properties.clear_all {
|
||||
description: "Clear script and execution sections completely"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
type: object
|
||||
@@ -1394,4 +1585,263 @@ add_or_update_artifacts {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
make_public {
|
||||
"2.9" {
|
||||
description: """Convert company tasks to public"""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
description: "Ids of the tasks to convert"
|
||||
type: array
|
||||
items { type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
make_private {
|
||||
"2.9" {
|
||||
description: """Convert public tasks to private"""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
description: "Ids of the tasks to convert. Only the tasks originated by the company can be converted"
|
||||
type: array
|
||||
items { type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_hyper_params {
|
||||
"2.9": {
|
||||
description: "Get the list of task hyper parameters"
|
||||
request {
|
||||
type: object
|
||||
required: [tasks]
|
||||
properties {
|
||||
tasks {
|
||||
description: "Task IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
params {
|
||||
type: object
|
||||
description: "Hyper parameters (keyed by task ID)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
edit_hyper_params {
|
||||
"2.9" {
|
||||
description: "Add or update task hyper parameters"
|
||||
request {
|
||||
type: object
|
||||
required: [ task, hyperparams ]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper parameters. The new ones will be added and the already existing ones will be updated"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/params_item"}
|
||||
}
|
||||
replace_hyperparams {
|
||||
description: """Can be set to one of the following:
|
||||
'all' - all the hyper parameters will be replaced with the provided ones
|
||||
'section' - the sections that present in the new parameters will be replaced with the provided parameters
|
||||
'none' (the default value) - only the specific parameters will be updated or added"""
|
||||
"$ref": "#/definitions/replace_hyperparams_enum"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Indicates if the task was updated successfully"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_hyper_params {
|
||||
"2.9": {
|
||||
description: "Delete task hyper parameters"
|
||||
request {
|
||||
type: object
|
||||
required: [ task, hyperparams ]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
hyperparams {
|
||||
description: "List of hyper parameters to delete. In case a parameter with an empty name is passed all the section will be deleted"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/param_key" }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
deleted {
|
||||
description: "Indicates if the task was updated successfully"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_configurations {
|
||||
"2.9": {
|
||||
description: "Get the list of task configurations"
|
||||
request {
|
||||
type: object
|
||||
required: [tasks]
|
||||
properties {
|
||||
tasks {
|
||||
description: "Task IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
names {
|
||||
description: "Names of the configuration items to retreive. If not passed or empty then all the configurations will be retreived."
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
configurations {
|
||||
type: object
|
||||
description: "Configurations (keyed by task ID)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_configuration_names {
|
||||
"2.9": {
|
||||
description: "Get the list of task configuration items names"
|
||||
request {
|
||||
type: object
|
||||
required: [tasks]
|
||||
properties {
|
||||
tasks {
|
||||
description: "Task IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
configurations {
|
||||
type: object
|
||||
description: "Names of task configuration items (keyed by task ID)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
edit_configuration {
|
||||
"2.9" {
|
||||
description: "Add or update task configuration"
|
||||
request {
|
||||
type: object
|
||||
required: [ task, configuration ]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration items. The new ones will be added and the already existing ones will be updated"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/configuration_item"}
|
||||
}
|
||||
replace_configuration {
|
||||
description: "If set then the all the configuration items will be replaced with the provided ones. Otherwise only the provided configuration items will be updated or added"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Indicates if the task was updated successfully"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_configuration {
|
||||
"2.9": {
|
||||
description: "Delete task configuration items"
|
||||
request {
|
||||
type: object
|
||||
required: [ task, configuration ]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
configuration {
|
||||
description: "List of configuration itemss to delete"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
deleted {
|
||||
description: "Indicates if the task was updated successfully"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -145,6 +145,19 @@ get_all_ex {
|
||||
internal: true
|
||||
"2.1": ${get_all."2.1"} {
|
||||
}
|
||||
"2.8": ${get_all."2.1"} {
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
active_in_projects {
|
||||
description: "List of project IDs. If provided, return only users that were active in these projects. If empty list is provided, return users that were active in all projects"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
get_all {
|
||||
|
||||
@@ -135,6 +135,10 @@
|
||||
description: "Task currently being run by the worker"
|
||||
"$ref": "#/definitions/current_task_entry"
|
||||
}
|
||||
project {
|
||||
description: "Project in which currently executing task resides"
|
||||
"$ref": "#/definitions/id_name_entry"
|
||||
}
|
||||
queue {
|
||||
description: "Queue from which running task was taken"
|
||||
"$ref": "#/definitions/queue_entry"
|
||||
@@ -144,6 +148,11 @@
|
||||
type: array
|
||||
items { "$ref": "#/definitions/queue_entry" }
|
||||
}
|
||||
tags {
|
||||
description: "User tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,11 +160,11 @@
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Worker ID"
|
||||
description: "ID"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Worker name"
|
||||
description: "Name"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
@@ -301,6 +310,11 @@
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
tags {
|
||||
description: "User tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -363,6 +377,11 @@
|
||||
description: "The machine statistics."
|
||||
"$ref": "#/definitions/machine_stats"
|
||||
}
|
||||
tags {
|
||||
description: "New user tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
import atexit
|
||||
from argparse import ArgumentParser
|
||||
from hashlib import md5
|
||||
|
||||
from flask import Flask, request, Response
|
||||
from flask_compress import Compress
|
||||
from flask_cors import CORS
|
||||
from semantic_version import Version
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
import database
|
||||
from apierrors.base import BaseError
|
||||
from bll.statistics.stats_reporter import StatisticsReporter
|
||||
from config import config
|
||||
from elastic.initialize import init_es_data
|
||||
from mongo.initialize import init_mongo_data
|
||||
from config import config, info
|
||||
from elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
|
||||
from mongo.initialize import (
|
||||
init_mongo_data,
|
||||
pre_populate_data,
|
||||
check_mongo_empty,
|
||||
get_last_server_version,
|
||||
)
|
||||
from service_repo import ServiceRepo, APICall
|
||||
from service_repo.auth import AuthType
|
||||
from service_repo.errors import PathParsingError
|
||||
from sync import distributed_lock
|
||||
from timing_context import TimingContext
|
||||
from updates import check_updates_thread
|
||||
from utilities import json
|
||||
@@ -33,8 +41,47 @@ app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get("apiserver.pretty_json")
|
||||
|
||||
database.initialize()
|
||||
|
||||
init_es_data()
|
||||
init_mongo_data()
|
||||
# build a key that uniquely identifies specific mongo instance
|
||||
hosts_string = ";".join(sorted(database.get_hosts()))
|
||||
key = "db_init_" + md5(hosts_string.encode()).hexdigest()
|
||||
with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 120)):
|
||||
upgrade_monitoring = config.get(
|
||||
"apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
|
||||
)
|
||||
try:
|
||||
empty_es = check_elastic_empty()
|
||||
except ElasticConnectionError as err:
|
||||
if not upgrade_monitoring:
|
||||
raise
|
||||
log.error(err)
|
||||
info.es_connection_error = True
|
||||
|
||||
empty_db = check_mongo_empty()
|
||||
if (
|
||||
upgrade_monitoring
|
||||
and not empty_db
|
||||
and (info.es_connection_error or empty_es)
|
||||
and get_last_server_version() < Version("0.16.0")
|
||||
):
|
||||
log.info(f"ES database seems not migrated")
|
||||
info.missed_es_upgrade = True
|
||||
|
||||
if info.es_connection_error and not info.missed_es_upgrade:
|
||||
raise Exception(
|
||||
"Error starting server: failed connecting to ElasticSearch service"
|
||||
)
|
||||
|
||||
if not info.missed_es_upgrade:
|
||||
init_es_data()
|
||||
init_mongo_data()
|
||||
|
||||
if (
|
||||
not info.missed_es_upgrade
|
||||
and empty_db
|
||||
and config.get("apiserver.pre_populate.enabled", False)
|
||||
):
|
||||
pre_populate_data()
|
||||
|
||||
|
||||
ServiceRepo.load("services")
|
||||
log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user