mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
68 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5bdbcfcd8d | ||
|
|
a2e2052b30 | ||
|
|
0146ded4f4 | ||
|
|
dccf9dd8f8 | ||
|
|
7816b402bb | ||
|
|
cd4ce30f7c | ||
|
|
8c7e230898 | ||
|
|
42ba696518 | ||
|
|
3f84e60a1f | ||
|
|
baba8b5b73 | ||
|
|
77397c4f21 | ||
|
|
8678091d8f | ||
|
|
aa22170ab4 | ||
|
|
901ec37290 | ||
|
|
21f2ea8b17 | ||
|
|
8219e3d4e2 | ||
|
|
3ed71a61d5 | ||
|
|
18a88a8e8f | ||
|
|
318a72987c | ||
|
|
5ce202cc99 | ||
|
|
d09528bc26 | ||
|
|
42d2a41dbe | ||
|
|
82be1840b0 | ||
|
|
27352c5cb6 | ||
|
|
1ea6408d41 | ||
|
|
5e095af3aa | ||
|
|
ab3dceed92 | ||
|
|
3bf5126d84 | ||
|
|
ab2ab7b23a | ||
|
|
c9184d125b | ||
|
|
0c0fdb72b9 | ||
|
|
86378053d4 | ||
|
|
b1cbba0cf1 | ||
|
|
f31526042d | ||
|
|
3f8d5bc346 | ||
|
|
11d76e7d8c | ||
|
|
e76c0fbc63 | ||
|
|
fdc9956da3 | ||
|
|
f4addaa653 | ||
|
|
667964cc82 | ||
|
|
e1309e30b7 | ||
|
|
9403942ef7 | ||
|
|
84a75d9e70 | ||
|
|
c85ab66ae6 | ||
|
|
bf7f0f646b | ||
|
|
dcdf2a3d58 | ||
|
|
f8d8fc40a6 | ||
|
|
45d434a123 | ||
|
|
1834abe5bc | ||
|
|
d6321588f3 | ||
|
|
c17b10ff1d | ||
|
|
b125a56f86 | ||
|
|
c43ce3a17b | ||
|
|
b0b09616a8 | ||
|
|
ede5586ccc | ||
|
|
a1dcdffa53 | ||
|
|
35a11db58e | ||
|
|
d9bdebefc7 | ||
|
|
f29884f05a | ||
|
|
0f72d662f8 | ||
|
|
6202219034 | ||
|
|
bb3218f65d | ||
|
|
cbcaa7c789 | ||
|
|
427322a424 | ||
|
|
0e7d7d36a9 | ||
|
|
06032a6d66 | ||
|
|
b48f4eb2eb | ||
|
|
383b2666c4 |
58
README.md
58
README.md
@@ -1,12 +1,16 @@
|
||||
# Trains Server
|
||||
|
||||
## Auto-Magical Experiment Manager & Version Control for AI
|
||||
## Auto-Magical Experiment Manager & Version Control for AI - ε Devops Included!
|
||||
|
||||
[](https://img.shields.io/badge/license-SSPL-green.svg)
|
||||
[](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
|
||||
[](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
|
||||
[](https://img.shields.io/badge/status-beta-yellow.svg)
|
||||
|
||||
### Help improve Trains by filling our 2-min [user survey](https://allegro.ai/lp/trains-user-survey/)
|
||||
|
||||
## :rocket: Trains-Agent Services is now included, for more information see [services](https://github.com/allegroai/trains-server#services)
|
||||
|
||||
## Introduction
|
||||
|
||||
The **trains-server** is the backend service infrastructure for [Trains](https://github.com/allegroai/trains).
|
||||
@@ -60,14 +64,15 @@ For example, to see if port `8080` is in use:
|
||||
|
||||
Launch **trains-server** in any of the following formats:
|
||||
|
||||
- Pre-built [AWS EC2 AMI](https://github.com/allegroai/trains-server/blob/master/docs/install_aws.md)
|
||||
- Pre-built [AWS EC2 AMI](https://allegro.ai/docs/deploying_trains/trains_server_aws_ec2_ami/)
|
||||
- Pre-built [GCP Custom Image](https://allegro.ai/docs/deploying_trains/trains_server_gcp/)
|
||||
- Pre-built Docker Image
|
||||
- [Linux](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
|
||||
- [macOS](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
|
||||
- [Windows 10](https://github.com/allegroai/trains-server/blob/master/docs/install_win.md)
|
||||
- [Linux](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
|
||||
- [macOS](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
|
||||
- [Windows 10](https://allegro.ai/docs/deploying_trains/trains_server_win/)
|
||||
- Kubernetes
|
||||
- [Kubernetes Helm](https://github.com/allegroai/trains-server-helm#prerequisites)
|
||||
- Manual [Kubernetes installation](https://github.com/allegroai/trains-server-k8s#prerequisites)
|
||||
- [Kubernetes Helm](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes_helm/)
|
||||
- Manual [Kubernetes installation](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes/)
|
||||
|
||||
## Connecting Trains to your trains-server
|
||||
|
||||
@@ -95,12 +100,32 @@ you can [use](https://github.com/allegroai/trains#using-trains) **Trains** in yo
|
||||
for example http://localhost:8080.
|
||||
For more information about the Trains client, see [**Trains**](https://github.com/allegroai/trains).
|
||||
|
||||
## Trains-Agent Services <a name="services"></a>
|
||||
|
||||
As of version 0.15 of **trains-server**, dockerized deployment includes a **Trains-Agent Services** container running as
|
||||
part of the docker container collection.
|
||||
|
||||
Trains-Agent Services is an extension of Trains-Agent that provides the ability to launch long-lasting jobs
|
||||
that previously had to be executed on local / dedicated machines. It allows a single agent to
|
||||
launch multiple dockers (Tasks) for different use cases. To name a few use cases, auto-scaler service (spinning instances
|
||||
when the need arises and the budget allows), Controllers (Implementing pipelines and more sophisticated DevOps logic),
|
||||
Optimizer (such as Hyper-parameter Optimization or sweeping), and Application (such as interactive Bokeh apps for
|
||||
increased data transparency)
|
||||
|
||||
Trains-Agent Services container will spin **any** task enqueued into the dedicated `services` queue.
|
||||
Every task launched by Trains-Agent Services will be registered as a new node in the system,
|
||||
providing tracking and transparency capabilities.
|
||||
You can also run the Trains-Agent Services manually, see details in [trains-agent services mode](https://github.com/allegroai/trains-agent#trains-agent-services-mode-)
|
||||
|
||||
**Note**: It is the user's responsibility to make sure the proper tasks are pushed into the `services` queue.
|
||||
Do not enqueue training / inference tasks into the `services` queue, as it will put unnecessary load on the server.
|
||||
|
||||
## Advanced Functionality
|
||||
|
||||
**trains-server** provides a few additional useful features, which can be manually enabled:
|
||||
|
||||
* [Web login authentication](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#web-auth)
|
||||
* [Non-responsive experiments watchdog](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#watchdog-the-non-responsive-task-watchdog-settings)
|
||||
* [Web login authentication](https://allegro.ai/docs/faq/faq/#web-auth)
|
||||
* [Non-responsive experiments watchdog](https://allegro.ai/docs/faq/faq/#watchdog)
|
||||
|
||||
## Restarting trains-server
|
||||
|
||||
@@ -149,18 +174,29 @@ To upgrade your existing **trains-server** deployment:
|
||||
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
|
||||
```
|
||||
|
||||
1. Configure the Trains-Agent Services (not supported on Windows installation).
|
||||
If `TRAINS_HOST_IP` is not provided, Trains-Agent Services will use the external
|
||||
public address of the **trains-server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
|
||||
the Trains-Agent Services will not be able to access any private repositories for running service tasks.
|
||||
|
||||
```bash
|
||||
export TRAINS_HOST_IP=server_host_ip_here
|
||||
export TRAINS_AGENT_GIT_USER=git_username_here
|
||||
export TRAINS_AGENT_GIT_PASS=git_password_here
|
||||
```
|
||||
|
||||
1. Spin up the docker containers, it will automatically pull the latest **trains-server** build
|
||||
```bash
|
||||
docker-compose -f docker-compose.yml pull
|
||||
docker-compose -f docker-compose.yml up
|
||||
```
|
||||
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#common-docker-upgrade-errors).**
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/docs/faq/faq/#common-docker-upgrade-errors).**
|
||||
|
||||
|
||||
## Community & Support
|
||||
|
||||
If you have any questions, look to the Trains server [FAQ](https://github.com/allegroai/trains-server/blob/master/docs/faq.md), or
|
||||
If you have any questions, look to the Trains [FAQ](https://allegro.ai/docs/faq/faq/), or
|
||||
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/trains) with '**trains**' tag.
|
||||
|
||||
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/trains-server/issues).
|
||||
|
||||
@@ -15,6 +15,8 @@ services:
|
||||
volumes:
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
- /opt/trains/data/fileserver:/mnt/fileserver
|
||||
- /opt/trains/config:/opt/trains/config
|
||||
|
||||
depends_on:
|
||||
- redis
|
||||
- mongo
|
||||
@@ -38,15 +40,11 @@ services:
|
||||
cluster.name: trains
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
discovery.zen.minimum_master_nodes: "1"
|
||||
discovery.type: "single-node"
|
||||
http.compression_level: "7"
|
||||
node.ingest: "true"
|
||||
node.name: trains
|
||||
reindex.remote.whitelist: '*.*'
|
||||
script.inline: "true"
|
||||
script.painless.regex.enabled: "true"
|
||||
script.update: "true"
|
||||
thread_pool.bulk.queue_size: "2000"
|
||||
thread_pool.search.queue_size: "10000"
|
||||
xpack.monitoring.enabled: "false"
|
||||
xpack.security.enabled: "false"
|
||||
ulimits:
|
||||
@@ -56,10 +54,10 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/trains/data/elastic:/usr/share/elasticsearch/data
|
||||
- /opt/trains/data/elastic_7:/usr/share/elasticsearch/data
|
||||
ports:
|
||||
- "9200:9200"
|
||||
mongo:
|
||||
|
||||
@@ -22,6 +22,9 @@ services:
|
||||
TRAINS_MONGODB_SERVICE_PORT: 27017
|
||||
TRAINS_REDIS_SERVICE_HOST: redis
|
||||
TRAINS_REDIS_SERVICE_PORT: 6379
|
||||
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-win10}
|
||||
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
|
||||
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
|
||||
ports:
|
||||
- "8008:8008"
|
||||
networks:
|
||||
@@ -37,15 +40,11 @@ services:
|
||||
cluster.name: trains
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
discovery.zen.minimum_master_nodes: "1"
|
||||
discovery.type: "single-node"
|
||||
http.compression_level: "7"
|
||||
node.ingest: "true"
|
||||
node.name: trains
|
||||
reindex.remote.whitelist: '*.*'
|
||||
script.inline: "true"
|
||||
script.painless.regex.enabled: "true"
|
||||
script.update: "true"
|
||||
thread_pool.bulk.queue_size: "2000"
|
||||
thread_pool.search.queue_size: "10000"
|
||||
xpack.monitoring.enabled: "false"
|
||||
xpack.security.enabled: "false"
|
||||
ulimits:
|
||||
@@ -55,10 +54,10 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- c:/opt/trains/data/elastic:/usr/share/elasticsearch/data
|
||||
- c:/opt/trains/data/elastic_7:/usr/share/elasticsearch/data
|
||||
ports:
|
||||
- "9200:9200"
|
||||
|
||||
@@ -73,6 +72,8 @@ services:
|
||||
volumes:
|
||||
- c:/opt/trains/logs:/var/log/trains
|
||||
- c:/opt/trains/data/fileserver:/mnt/fileserver
|
||||
- c:/opt/trains/config:/opt/trains/config
|
||||
|
||||
ports:
|
||||
- "8081:8081"
|
||||
|
||||
@@ -84,7 +85,8 @@ services:
|
||||
restart: unless-stopped
|
||||
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
|
||||
volumes:
|
||||
- mongodata:/data
|
||||
- c:/opt/trains/data/mongo/db:/data/db
|
||||
- c:/opt/trains/data/mongo/configdb:/data/configdb
|
||||
ports:
|
||||
- "27017:27017"
|
||||
|
||||
@@ -115,6 +117,3 @@ services:
|
||||
networks:
|
||||
backend:
|
||||
driver: bridge
|
||||
|
||||
volumes:
|
||||
mongodata:
|
||||
|
||||
@@ -10,6 +10,7 @@ services:
|
||||
volumes:
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
- /opt/trains/config:/opt/trains/config
|
||||
- /opt/trains/data/fileserver:/mnt/fileserver
|
||||
depends_on:
|
||||
- redis
|
||||
- mongo
|
||||
@@ -22,8 +23,10 @@ services:
|
||||
TRAINS_MONGODB_SERVICE_PORT: 27017
|
||||
TRAINS_REDIS_SERVICE_HOST: redis
|
||||
TRAINS_REDIS_SERVICE_PORT: 6379
|
||||
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
|
||||
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
|
||||
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-linux}
|
||||
TRAINS__apiserver__pre_populate__enabled: "true"
|
||||
TRAINS__apiserver__pre_populate__zip_files: "/opt/trains/db-pre-populate"
|
||||
TRAINS__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
|
||||
ports:
|
||||
- "8008:8008"
|
||||
networks:
|
||||
@@ -39,15 +42,11 @@ services:
|
||||
cluster.name: trains
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
discovery.zen.minimum_master_nodes: "1"
|
||||
discovery.type: "single-node"
|
||||
http.compression_level: "7"
|
||||
node.ingest: "true"
|
||||
node.name: trains
|
||||
reindex.remote.whitelist: '*.*'
|
||||
script.inline: "true"
|
||||
script.painless.regex.enabled: "true"
|
||||
script.update: "true"
|
||||
thread_pool.bulk.queue_size: "2000"
|
||||
thread_pool.search.queue_size: "10000"
|
||||
xpack.monitoring.enabled: "false"
|
||||
xpack.security.enabled: "false"
|
||||
ulimits:
|
||||
@@ -57,10 +56,10 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/trains/data/elastic:/usr/share/elasticsearch/data
|
||||
- /opt/trains/data/elastic_7:/usr/share/elasticsearch/data
|
||||
ports:
|
||||
- "9200:9200"
|
||||
|
||||
@@ -75,6 +74,7 @@ services:
|
||||
volumes:
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
- /opt/trains/data/fileserver:/mnt/fileserver
|
||||
- /opt/trains/config:/opt/trains/config
|
||||
ports:
|
||||
- "8081:8081"
|
||||
|
||||
@@ -108,13 +108,43 @@ services:
|
||||
container_name: trains-webserver
|
||||
image: allegroai/trains:latest
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
depends_on:
|
||||
- apiserver
|
||||
ports:
|
||||
- "8080:80"
|
||||
|
||||
agent-services:
|
||||
networks:
|
||||
- backend
|
||||
container_name: trains-agent-services
|
||||
image: allegroai/trains-agent-services:latest
|
||||
restart: unless-stopped
|
||||
privileged: true
|
||||
environment:
|
||||
TRAINS_HOST_IP: ${TRAINS_HOST_IP}
|
||||
TRAINS_WEB_HOST: ${TRAINS_WEB_HOST:-}
|
||||
TRAINS_API_HOST: http://apiserver:8008
|
||||
TRAINS_FILES_HOST: ${TRAINS_FILES_HOST:-}
|
||||
TRAINS_API_ACCESS_KEY: ${TRAINS_API_ACCESS_KEY:-}
|
||||
TRAINS_API_SECRET_KEY: ${TRAINS_API_SECRET_KEY:-}
|
||||
TRAINS_AGENT_GIT_USER: ${TRAINS_AGENT_GIT_USER}
|
||||
TRAINS_AGENT_GIT_PASS: ${TRAINS_AGENT_GIT_PASS}
|
||||
TRAINS_AGENT_UPDATE_VERSION: ${TRAINS_AGENT_UPDATE_VERSION:->=0.15.0}
|
||||
TRAINS_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
|
||||
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
|
||||
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
|
||||
AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION:-}
|
||||
AZURE_STORAGE_ACCOUNT: ${AZURE_STORAGE_ACCOUNT:-}
|
||||
AZURE_STORAGE_KEY: ${AZURE_STORAGE_KEY:-}
|
||||
GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
|
||||
TRAINS_WORKER_ID: "trains-services"
|
||||
TRAINS_AGENT_DOCKER_HOST_MOUNT: "/opt/trains/agent:/root/.trains"
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- /opt/trains/agent:/root/.trains
|
||||
depends_on:
|
||||
- apiserver
|
||||
|
||||
networks:
|
||||
backend:
|
||||
driver: bridge
|
||||
|
||||
@@ -26,7 +26,7 @@ The minimum recommended amount of RAM is 8GB. For example, **t3.large** or **t3a
|
||||
|
||||
To upgrade **trains-server** on an existing EC2 instance based on one of these AMIs, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
|
||||
|
||||
### Upgrading AMIs to v0.12
|
||||
### Note on upgrading AMIs to v0.12
|
||||
|
||||
This upgrade includes the automatically updated AMI in Version 0.12. It also includes an additional REDIS docker to the **trains-server** setup.
|
||||
|
||||
@@ -50,45 +50,102 @@ To upgrade the AMI:
|
||||
|
||||
The following sections contain lists of AMI Image IDs, per region, for each released **trains-server** version.
|
||||
|
||||
### Latest version AMI - v0.14.1 (auto update)<a name="autoupdate"></a>
|
||||
### Latest version AMI - v0.15.1 (auto update)<a name="autoupdate"></a>
|
||||
|
||||
For easier upgrades, the following AMIs automatically update to the latest release every reboot:
|
||||
|
||||
* **eu-north-1** : ami-033fd0d9163e0a36e
|
||||
* **ap-south-1** : ami-0cdd7f9880336f9b5
|
||||
* **eu-west-3** : ami-085f508eef9f650d5
|
||||
* **eu-west-2** : ami-0936deb94a193502d
|
||||
* **eu-west-1** : ami-0c20b3620f1c6fff7
|
||||
* **ap-northeast-2** : ami-0707c9790f8a224c4
|
||||
* **ap-northeast-1** : ami-04595f0745c090328
|
||||
* **sa-east-1** : ami-03898299742d43ad4
|
||||
* **ca-central-1** : ami-0502dfea5223d572a
|
||||
* **ap-southeast-1** : ami-02aa1f9308404e464
|
||||
* **ap-southeast-2** : ami-0b66189b90df79b38
|
||||
* **eu-central-1** : ami-0eb919d8234c49cdc
|
||||
* **us-east-2** : ami-02fb63fca1b9f0d4b
|
||||
* **us-west-1** : ami-01fdda7351725a689
|
||||
* **us-west-2** : ami-004a2f40cdc095870
|
||||
* **us-east-1** : ami-0a8acd1172ffebc7e
|
||||
* **eu-north-1** : ami-0f30c84b905d354b9
|
||||
* **ap-south-1** : ami-050e7acec52c8c74e
|
||||
* **eu-west-3** : ami-03911c5b5bc77ef75
|
||||
* **eu-west-2** : ami-0a5ed8aa2573ccc70
|
||||
* **eu-west-1** : ami-0a53c65e922ec0611
|
||||
* **ap-northeast-2** : ami-08cd017a37b8e8aab
|
||||
* **ap-northeast-1** : ami-056b3ca1ad5af9322
|
||||
* **sa-east-1** : ami-01ddc9325bafb400c
|
||||
* **ca-central-1** : ami-0fc3cbbd982b18b45
|
||||
* **ap-southeast-1** : ami-04c7a358df7002ef5
|
||||
* **ap-southeast-2** : ami-0eeaf54231b4ae22a
|
||||
* **eu-central-1** : ami-00b8e44041f8175fd
|
||||
* **us-east-2** : ami-0ac7deebb3f738f6d
|
||||
* **us-west-1** : ami-06bc07deb8b8c44d6
|
||||
* **us-west-2** : ami-01ba85ffe79a422f1
|
||||
* **us-east-1** : ami-04cf5a66cb4928ac3
|
||||
|
||||
### v0.15.1 (static update)
|
||||
|
||||
* **eu-north-1** : ami-0cd314e267426d1b7
|
||||
* **ap-south-1** : ami-086182cbe29151f96
|
||||
* **eu-west-3** : ami-0062366012182815b
|
||||
* **eu-west-2** : ami-022b8f2e32a9d18d0
|
||||
* **eu-west-1** : ami-0d8cf60446e09aa3d
|
||||
* **ap-northeast-2** : ami-0d4c168a815b56889
|
||||
* **ap-northeast-1** : ami-0daf7887db1053ae4
|
||||
* **sa-east-1** : ami-020a759a3ba4ff22b
|
||||
* **ca-central-1** : ami-0c10b5e04b707f3e3
|
||||
* **ap-southeast-1** : ami-0f61bb3529a165fcd
|
||||
* **ap-southeast-2** : ami-032dcdc82749c66c5
|
||||
* **eu-central-1** : ami-08f364f32d2eb3bae
|
||||
* **us-east-2** : ami-0b7efc3591803eba4
|
||||
* **us-west-1** : ami-08b2df27b0ada6faf
|
||||
* **us-west-2** : ami-0693029c4bad28816
|
||||
* **us-east-1** : ami-0200954fa9c2819ff
|
||||
|
||||
### v0.15.0 (static update)
|
||||
|
||||
* **eu-north-1** : ami-0bef15c03eab64c0c
|
||||
* **ap-south-1** : ami-06ac6248e583e2cd2
|
||||
* **eu-west-3** : ami-0541d86ef47a5714e
|
||||
* **eu-west-2** : ami-01381ef4c4ed22482
|
||||
* **eu-west-1** : ami-064626a0dd38b21f1
|
||||
* **ap-northeast-2** : ami-0a2490a7a3a8aa675
|
||||
* **ap-northeast-1** : ami-063f1de819a2524b8
|
||||
* **sa-east-1** : ami-07980486741b94987
|
||||
* **ca-central-1** : ami-0ced3b8b21ded839e
|
||||
* **ap-southeast-1** : ami-0c493c5093fde8741
|
||||
* **ap-southeast-2** : ami-0320a727eccb8dc6c
|
||||
* **eu-central-1** : ami-0aa85cfc78674c526
|
||||
* **us-east-2** : ami-01791485051e1880c
|
||||
* **us-west-1** : ami-0d8eade4d5888ea73
|
||||
* **us-west-2** : ami-02ceaef72cdf60f7e
|
||||
* **us-east-1** : ami-0fc3f9d1d0eba1d62
|
||||
|
||||
### v0.14.2 (static update)
|
||||
|
||||
* **eu-north-1** : ami-006d491e9e8869248
|
||||
* **ap-south-1** : ami-0e55ec221687f98e7
|
||||
* **eu-west-3** : ami-06ad9cf3c05c83e91
|
||||
* **eu-west-2** : ami-0d05839268e748cff
|
||||
* **eu-west-1** : ami-0d14c297789ce0d7a
|
||||
* **ap-northeast-2** : ami-0d7fd775f0e76cc6f
|
||||
* **ap-northeast-1** : ami-0c0a6e1daeb3f7a9c
|
||||
* **sa-east-1** : ami-01e0c5e30e94ec887
|
||||
* **ca-central-1** : ami-07a31896832734897
|
||||
* **ap-southeast-1** : ami-0886d5b2d4b7fccd5
|
||||
* **ap-southeast-2** : ami-0397d5a2db3c356fe
|
||||
* **eu-central-1** : ami-0629f26eea22f5c17
|
||||
* **us-east-2** : ami-0499c3d7bb45a1a6e
|
||||
* **us-west-1** : ami-02fa8a961a4daf9f0
|
||||
* **us-west-2** : ami-05c711cfab4342468
|
||||
* **us-east-1** : ami-0b97d99a08012c726
|
||||
|
||||
### v0.14.1 (static update)
|
||||
|
||||
* **eu-north-1** : ami-0ccdf4700a6c989ad
|
||||
* **ap-south-1** : ami-0f6a8de6441d64a68
|
||||
* **eu-west-3** : ami-0c51b9a8e7b3371cc
|
||||
* **eu-west-2** : ami-099c094598f72bcb5
|
||||
* **eu-west-1** : ami-0af20d5e4ab764212
|
||||
* **ap-northeast-2** : ami-011455e8d852e02d6
|
||||
* **ap-northeast-1** : ami-0211827ee11d6ed9c
|
||||
* **sa-east-1** : ami-07509b07aa4554dc2
|
||||
* **ca-central-1** : ami-07153c171d97e460e
|
||||
* **ap-southeast-1** : ami-042d61c497063675b
|
||||
* **ap-southeast-2** : ami-0dcf27f88bd2dd622
|
||||
* **eu-central-1** : ami-0ae29f89d9bcb1a95
|
||||
* **us-east-2** : ami-053144df2cea2bd97
|
||||
* **us-west-1** : ami-0f703537206ee05f1
|
||||
* **us-west-2** : ami-007c954572c86a583
|
||||
* **us-east-1** : ami-07c59cbc7541f58e9
|
||||
* **eu-north-1** : ami-036defe1885dced2e
|
||||
* **ap-south-1** : ami-0b403aa1da6a5dc17
|
||||
* **eu-west-3** : ami-0d30c2d330d1255c4
|
||||
* **eu-west-2** : ami-06f0e8d075e50a029
|
||||
* **eu-west-1** : ami-0da721d874f282b6d
|
||||
* **ap-northeast-2** : ami-03bffe94675dd5f8c
|
||||
* **ap-northeast-1** : ami-0f96520d646423673
|
||||
* **sa-east-1** : ami-0c2f706a3b7d97282
|
||||
* **ca-central-1** : ami-0da74525dcfd74e32
|
||||
* **ap-southeast-1** : ami-066368a21cf6d232b
|
||||
* **ap-southeast-2** : ami-0bfd09170067f7318
|
||||
* **eu-central-1** : ami-06aa99b1c41492986
|
||||
* **us-east-2** : ami-065c1880f59d03272
|
||||
* **us-west-1** : ami-0b7f6b896f5058eba
|
||||
* **us-west-2** : ami-0041e10ca68eef29a
|
||||
* **us-east-1** : ami-0b7125e4305bbd7eb
|
||||
|
||||
### v0.14.0 (static update)
|
||||
* **eu-north-1** : ami-02de71586ec496e38
|
||||
|
||||
76
docs/install_gcp.md
Normal file
76
docs/install_gcp.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# Deploying Trains Server on Google Cloud Platform
|
||||
|
||||
To easily deploy Trains Server on GCP, use one of our pre-built GCP Custom Images.
|
||||
We provide Custom Images for each released version of Trains Server, see [Released versions](#released-versions) below.
|
||||
|
||||
Once your GCP instance is up and running using our Custom Image, [configure the Trains client](https://github.com/allegroai/trains/blob/master/README.md#configuration) to use your **trains-server**.
|
||||
|
||||
#### Default Trains Server Service ports
|
||||
The service port numbers on our Trains Server GCP Custom Image are:
|
||||
|
||||
- Web application: `8080`
|
||||
- API Server: `8008`
|
||||
- File Server: `8081`
|
||||
|
||||
#### Default Trains Server Storage paths
|
||||
The persistent storage configuration:
|
||||
|
||||
- MongoDB: `/opt/trains/data/mongo/`
|
||||
- ElasticSearch: `/opt/trains/data/elastic/`
|
||||
- File Server: `/mnt/fileserver/`
|
||||
|
||||
For examples and use cases, check the [Trains usage examples](https://github.com/allegroai/trains/blob/master/docs/trains_examples.md).
|
||||
|
||||
## Importing the Custom Image to your GCP account
|
||||
|
||||
In order to launch an instance using the Trains Server GCP Custom Image, you'll need to import the image to your custom images list.
|
||||
|
||||
**Note:** there's **no need** to upload the image file to Google Cloud Storage - we already provide links to image files stored in Google Storage
|
||||
|
||||
To import the image to your custom images list:
|
||||
1. In the Cloud Console, go to the [Images](https://console.cloud.google.com/compute/images) page.
|
||||
1. At the top of the page, click **Create image**.
|
||||
1. In the **Name** field, specify a unique name for the image.
|
||||
1. Optionally, specify an image family for your new image, or configure specific encryption settings for the image.
|
||||
1. Click the **Source** menu and select **Cloud Storage file**.
|
||||
1. Enter the Trains Server image bucket path (see [Trains Server GCP Custom Image](#released-versions)), for example:
|
||||
`allegro-files/trains-server/trains-server.tar.gz`
|
||||
1. Click the **Create** button to import the image. The process can take several minutes depending on the size of the boot disk image.
|
||||
|
||||
For more information see [Import the image to your custom images list](https://cloud.google.com/compute/docs/import/import-existing-image#import_image) in the [Compute Engine Documentation](https://cloud.google.com/compute/docs).
|
||||
|
||||
## Launching an instance with a Custom Image
|
||||
|
||||
For instructions on launching an instance using a GCP Custom Image, see the [Manually importing virtual disks](https://cloud.google.com/compute/docs/import/import-existing-image#overview) in the [Compute Engine Documentation](https://cloud.google.com/compute/docs).
|
||||
For more information on Custom Images, see [Custom Images](https://cloud.google.com/compute/docs/images#custom_images) in the Compute Engine Documentation.
|
||||
|
||||
The minimum recommended requirements for Trains Server are:
|
||||
- 2 vCPUs
|
||||
- 7.5GB RAM
|
||||
|
||||
## Upgrading
|
||||
|
||||
To upgrade **trains-server** on an existing GCP instance based on one of these Custom Images, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
|
||||
|
||||
## Network and Security
|
||||
|
||||
Please make sure your instance is properly secured.
|
||||
|
||||
If not specifically set, a GCP instance will use default firewall rules that allow public access to various ports.
|
||||
If your instance is open for public access, we recommend you follow best practices for access management, including:
|
||||
- Allow access only to the specific ports used by Trains Server (see [Default Trains Server Service ports](#default-trains-server-service-ports)). Remember to allow access to port `443` if `https` access is configured for your instance.
|
||||
- Configure Trains Server to use fixed user names and passwords (see [Can I add web login authentication to trains-server?](./faq.md#web-auth))
|
||||
|
||||
## Released versions
|
||||
|
||||
The following sections contain lists of Custom Image URLs (exported in different formats) for each released **trains-server** version.
|
||||
|
||||
### Latest version image
|
||||
|
||||
- https://storage.googleapis.com/allegro-files/trains-server/trains-server.tar.gz
|
||||
|
||||
### All released images
|
||||
|
||||
- v0.15.1 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-15-1.tar.gz
|
||||
- v0.15.0 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-15-0.tar.gz
|
||||
- v0.14.1 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-14-1.tar.gz
|
||||
@@ -1,6 +1,6 @@
|
||||
# Launching the **trains-server** Docker in Linux or macOS
|
||||
|
||||
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
|
||||
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
|
||||
|
||||
For Linux users:
|
||||
|
||||
@@ -16,20 +16,20 @@ To launch **trains-server** on Linux or macOS:
|
||||
|
||||
1. Verify the Docker CE installation. Execute the command:
|
||||
|
||||
sudo docker run hello-world
|
||||
|
||||
docker run hello-world
|
||||
|
||||
The expected is output is:
|
||||
|
||||
Hello from Docker!
|
||||
This message shows that your installation appears to be working correctly.
|
||||
To generate this message, Docker took the following steps:
|
||||
|
||||
|
||||
1. The Docker client contacted the Docker daemon.
|
||||
2. The Docker daemon pulled the "hello-world" image from the Docker Hub. (amd64)
|
||||
3. The Docker daemon created a new container from that image which runs the executable that produces the output you are currently reading.
|
||||
4. The Docker daemon streamed that output to the Docker client, which sent it to your terminal.
|
||||
|
||||
1. For Linux only, install `docker-compose`. Execute the following commands (for more information, see [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
|
||||
1. For Linux only, install `docker-compose`. Execute the following commands (for more information, see [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
|
||||
|
||||
sudo curl -L "https://github.com/docker/compose/releases/download/1.24.1/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
|
||||
sudo chmod +x /usr/local/bin/docker-compose
|
||||
@@ -42,12 +42,12 @@ To launch **trains-server** on Linux or macOS:
|
||||
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
|
||||
sudo sysctl -w vm.max_map_count=262144
|
||||
sudo service docker restart
|
||||
|
||||
|
||||
macOS:
|
||||
|
||||
|
||||
screen ~/Library/Containers/com.docker.docker/Data/vms/0/tty
|
||||
sysctl -w vm.max_map_count=262144
|
||||
|
||||
|
||||
|
||||
1. Remove any previous installation of **trains-server**.
|
||||
|
||||
@@ -64,28 +64,28 @@ To launch **trains-server** on Linux or macOS:
|
||||
sudo mkdir -p /opt/trains/logs
|
||||
sudo mkdir -p /opt/trains/config
|
||||
sudo mkdir -p /opt/trains/data/fileserver
|
||||
|
||||
|
||||
1. For macOS only, open the Docker app, select **Preferences**, and then on the **File Sharing** tab, add `/opt/trains`.
|
||||
|
||||
|
||||
1. Grant access to the Dockers.
|
||||
|
||||
Linux:
|
||||
|
||||
sudo chown -R 1000:1000 /opt/trains
|
||||
|
||||
|
||||
macOS:
|
||||
|
||||
|
||||
sudo chown -R $(whoami):staff /opt/trains
|
||||
|
||||
1. Download the **trains-server** docker-compose YAML file.
|
||||
|
||||
cd /opt/trains
|
||||
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
|
||||
|
||||
|
||||
1. Run `docker-compose` with the downloaded configuration file.
|
||||
|
||||
sudo docker-compose -f docker-compose.yml up
|
||||
|
||||
docker-compose -f docker-compose.yml up
|
||||
|
||||
Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
|
||||
|
||||
* Web server on port `8080`
|
||||
@@ -94,4 +94,4 @@ To launch **trains-server** on Linux or macOS:
|
||||
|
||||
## Next Step
|
||||
|
||||
Configure the [Trains client for trains-server](https://github.com/allegroai/trains/blob/master/README.md#configuration).
|
||||
Configure the [Trains client for trains-server](https://github.com/allegroai/trains/blob/master/README.md#configuration).
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
@@ -16,6 +17,9 @@ DEFAULT_EXTRA_CONFIG_PATH = "/opt/trains/config"
|
||||
EXTRA_CONFIG_PATH_ENV_KEY = "TRAINS_CONFIG_DIR"
|
||||
EXTRA_CONFIG_PATH_SEP = ":"
|
||||
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_SEP = "__"
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX = f"TRAINS{EXTRA_CONFIG_VALUES_ENV_KEY_SEP}"
|
||||
|
||||
|
||||
class BasicConfig:
|
||||
NotSet = object()
|
||||
@@ -46,7 +50,23 @@ class BasicConfig:
|
||||
path = ".".join((self.prefix, Path(name).stem))
|
||||
return logging.getLogger(path)
|
||||
|
||||
def _read_env_paths(self, key):
|
||||
@staticmethod
|
||||
def _read_extra_env_config_values():
|
||||
""" Loads extra configuration from environment-injected values """
|
||||
result = ConfigTree()
|
||||
prefix = EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX
|
||||
|
||||
keys = sorted(k for k in os.environ if k.startswith(prefix))
|
||||
for key in keys:
|
||||
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".").lower()
|
||||
result = ConfigTree.merge_configs(
|
||||
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _read_env_paths(key):
|
||||
value = getenv(EXTRA_CONFIG_PATH_ENV_KEY, DEFAULT_EXTRA_CONFIG_PATH)
|
||||
if value is None:
|
||||
return
|
||||
@@ -64,12 +84,17 @@ class BasicConfig:
|
||||
|
||||
def _load(self, verbose=True):
|
||||
extra_config_paths = self._read_env_paths(EXTRA_CONFIG_PATH_ENV_KEY) or []
|
||||
extra_config_values = self._read_extra_env_config_values()
|
||||
configs = [
|
||||
self._read_recursive(path, verbose=verbose)
|
||||
for path in [self.folder] + extra_config_paths
|
||||
]
|
||||
|
||||
self._config = reduce(
|
||||
lambda config, path: ConfigTree.merge_configs(
|
||||
config, self._read_recursive(path, verbose=verbose), copy_trees=True
|
||||
lambda last, config: ConfigTree.merge_configs(
|
||||
last, config, copy_trees=True
|
||||
),
|
||||
[self.folder] + extra_config_paths,
|
||||
configs + [extra_config_values],
|
||||
ConfigTree(),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
download {
|
||||
# Add response headers requesting no caching for served files
|
||||
disable_browser_caching: false
|
||||
|
||||
# Cache timeout to be set for downloaded files
|
||||
cache_timeout_sec: 300
|
||||
}
|
||||
|
||||
cors {
|
||||
|
||||
@@ -10,12 +10,14 @@ from flask_cors import CORS
|
||||
|
||||
from config import config
|
||||
|
||||
DEFAULT_UPLOAD_FOLDER = "/mnt/fileserver"
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app, **config.get("fileserver.cors"))
|
||||
Compress(app)
|
||||
|
||||
if os.environ.get("TRAINS_UPLOAD_FOLDER"):
|
||||
app.config["UPLOAD_FOLDER"] = os.environ.get("TRAINS_UPLOAD_FOLDER")
|
||||
app.config["UPLOAD_FOLDER"] = os.environ.get("TRAINS_UPLOAD_FOLDER") or DEFAULT_UPLOAD_FOLDER
|
||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = config.get("fileserver.download.cache_timeout_sec", 5 * 60)
|
||||
|
||||
|
||||
@app.route("/", methods=["POST"])
|
||||
@@ -57,12 +59,13 @@ def main():
|
||||
parser.add_argument(
|
||||
"--upload-folder",
|
||||
"-u",
|
||||
default="/mnt/fileserver",
|
||||
default=DEFAULT_UPLOAD_FOLDER,
|
||||
help="Upload folder (default %(default)s)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
app.config["UPLOAD_FOLDER"] = args.upload_folder
|
||||
if app.config.get("UPLOAD_FOLDER") is None:
|
||||
app.config["UPLOAD_FOLDER"] = args.upload_folder
|
||||
|
||||
app.run(debug=args.debug, host=args.ip, port=args.port, threaded=True)
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "2.7.0"
|
||||
__version__ = "2.9.0"
|
||||
|
||||
@@ -47,6 +47,7 @@ _error_codes = {
|
||||
128: ('invalid_task_output', 'invalid task output'),
|
||||
129: ('task_publish_in_progress', 'Task publish in progress'),
|
||||
130: ('task_not_found', 'task not found'),
|
||||
131: ('events_not_added', 'events not added'),
|
||||
|
||||
# Models
|
||||
200: ('model_error', 'general task error'),
|
||||
|
||||
@@ -13,6 +13,7 @@ from luqum.parser import parser, ParseError
|
||||
from validators import email as email_validator, domain as domain_validator
|
||||
|
||||
from apierrors import errors
|
||||
from utilities.json import loads, dumps
|
||||
|
||||
|
||||
def make_default(field_cls, default_value):
|
||||
@@ -206,10 +207,10 @@ class DomainField(fields.StringField):
|
||||
raise errors.bad_request.InvalidDomainName()
|
||||
|
||||
|
||||
class StringEnum(Enum):
|
||||
def __str__(self):
|
||||
return self.value
|
||||
class JsonSerializableMixin:
|
||||
def to_json(self: ModelBase):
|
||||
return dumps(self.to_struct())
|
||||
|
||||
# noinspection PyMethodParameters
|
||||
def _generate_next_value_(name, start, count, last_values):
|
||||
return name
|
||||
@classmethod
|
||||
def from_json(cls: Type[ModelBase], s):
|
||||
return cls(**loads(s))
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from jsonmodels import models, fields
|
||||
from jsonmodels.validators import Length
|
||||
from mongoengine.base import BaseDocument
|
||||
|
||||
from apimodels import DictField
|
||||
from apimodels import DictField, ListField
|
||||
|
||||
|
||||
class MongoengineFieldsDict(DictField):
|
||||
@@ -12,14 +13,14 @@ class MongoengineFieldsDict(DictField):
|
||||
"""
|
||||
|
||||
mongoengine_update_operators = (
|
||||
'inc',
|
||||
'dec',
|
||||
'push',
|
||||
'push_all',
|
||||
'pop',
|
||||
'pull',
|
||||
'pull_all',
|
||||
'add_to_set',
|
||||
"inc",
|
||||
"dec",
|
||||
"push",
|
||||
"push_all",
|
||||
"pop",
|
||||
"pull",
|
||||
"pull_all",
|
||||
"add_to_set",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -30,16 +31,16 @@ class MongoengineFieldsDict(DictField):
|
||||
|
||||
@classmethod
|
||||
def _normalize_mongo_field_path(cls, path, value):
|
||||
parts = path.split('__')
|
||||
parts = path.split("__")
|
||||
if len(parts) > 1:
|
||||
if parts[0] == 'set':
|
||||
if parts[0] == "set":
|
||||
parts = parts[1:]
|
||||
elif parts[0] == 'unset':
|
||||
elif parts[0] == "unset":
|
||||
parts = parts[1:]
|
||||
value = None
|
||||
elif parts[0] in cls.mongoengine_update_operators:
|
||||
return None, None
|
||||
return '.'.join(parts), cls._normalize_mongo_value(value)
|
||||
return ".".join(parts), cls._normalize_mongo_value(value)
|
||||
|
||||
def parse_value(self, value):
|
||||
value = super(MongoengineFieldsDict, self).parse_value(value)
|
||||
@@ -62,3 +63,7 @@ class PagedRequest(models.Base):
|
||||
|
||||
class IdResponse(models.Base):
|
||||
id = fields.StringField(required=True)
|
||||
|
||||
|
||||
class MakePublicRequest(models.Base):
|
||||
ids = ListField(items_types=str, validators=[Length(minimum_value=1)])
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
from typing import Sequence
|
||||
from enum import auto
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length
|
||||
from jsonmodels.validators import Length, Min, Max
|
||||
|
||||
from apimodels import ListField, IntField, ActualEnumField
|
||||
from bll.event.event_metrics import EventType
|
||||
from bll.event.scalar_key import ScalarKeyEnum
|
||||
from utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class HistogramRequestBase(Base):
|
||||
samples: int = IntField(default=10000)
|
||||
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
|
||||
|
||||
@@ -21,7 +23,7 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
items_types=str, validators=[Length(minimum_value=1, maximum_value=10)]
|
||||
)
|
||||
|
||||
|
||||
@@ -40,6 +42,19 @@ class DebugImagesRequest(Base):
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class LogOrderEnum(StringEnum):
|
||||
asc = auto()
|
||||
desc = auto()
|
||||
|
||||
|
||||
class LogEventsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
batch_size: int = IntField(default=500)
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
from_timestamp: Optional[int] = IntField()
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum)
|
||||
|
||||
|
||||
class IterationEvents(Base):
|
||||
iter: int = IntField()
|
||||
events: Sequence[dict] = ListField(items_types=dict)
|
||||
|
||||
@@ -6,6 +6,10 @@ from apimodels.base import UpdateResponse
|
||||
from apimodels.tasks import PublishResponse as TaskPublishResponse
|
||||
|
||||
|
||||
class GetFrameworksRequest(models.Base):
|
||||
projects = fields.ListField(items_types=[str])
|
||||
|
||||
|
||||
class CreateModelRequest(models.Base):
|
||||
name = fields.StringField(required=True)
|
||||
uri = fields.StringField(required=True)
|
||||
|
||||
11
server/apimodels/organization.py
Normal file
11
server/apimodels/organization.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from jsonmodels import fields, models
|
||||
|
||||
|
||||
class Filter(models.Base):
|
||||
tags = fields.ListField([str])
|
||||
system_tags = fields.ListField([str])
|
||||
|
||||
|
||||
class TagsRequest(models.Base):
|
||||
include_system = fields.BoolField(default=False)
|
||||
filter = fields.EmbeddedField(Filter)
|
||||
@@ -1,5 +1,8 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apimodels import ListField
|
||||
from apimodels.organization import TagsRequest
|
||||
|
||||
|
||||
class ProjectReq(models.Base):
|
||||
project = fields.StringField()
|
||||
@@ -10,7 +13,5 @@ class GetHyperParamReq(ProjectReq):
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
|
||||
class GetHyperParamResp(models.Base):
|
||||
parameters = fields.ListField(str)
|
||||
remaining = fields.IntField()
|
||||
total = fields.IntField()
|
||||
class ProjectTagsRequest(TagsRequest):
|
||||
projects = ListField(str)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from jsonmodels import models
|
||||
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
|
||||
from jsonmodels.validators import Enum
|
||||
from jsonmodels.validators import Enum, Length
|
||||
|
||||
from apimodels import DictField, ListField
|
||||
from apimodels.base import UpdateResponse
|
||||
@@ -92,6 +94,10 @@ class PingRequest(TaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class GetTypesRequest(models.Base):
|
||||
projects = ListField(items_types=[str])
|
||||
|
||||
|
||||
class CloneRequest(TaskRequest):
|
||||
new_task_name = StringField()
|
||||
new_task_comment = StringField()
|
||||
@@ -99,7 +105,10 @@ class CloneRequest(TaskRequest):
|
||||
new_task_system_tags = ListField([str])
|
||||
new_task_parent = StringField()
|
||||
new_task_project = StringField()
|
||||
new_hyperparams = DictField()
|
||||
new_configuration = DictField()
|
||||
execution_overrides = DictField()
|
||||
validate_references = BoolField(default=False)
|
||||
|
||||
|
||||
class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||
@@ -109,3 +118,76 @@ class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||
class AddOrUpdateArtifactsResponse(models.Base):
|
||||
added = ListField([str])
|
||||
updated = ListField([str])
|
||||
|
||||
|
||||
class ResetRequest(UpdateRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
|
||||
|
||||
class MultiTaskRequest(models.Base):
|
||||
tasks = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class GetHyperParamsRequest(MultiTaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class HyperParamItem(models.Base):
|
||||
section = StringField(required=True, validators=Length(minimum_value=1))
|
||||
name = StringField(required=True, validators=Length(minimum_value=1))
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class ReplaceHyperparams(object):
|
||||
none = "none"
|
||||
section = "section"
|
||||
all = "all"
|
||||
|
||||
|
||||
class EditHyperParamsRequest(TaskRequest):
|
||||
hyperparams: Sequence[HyperParamItem] = ListField(
|
||||
[HyperParamItem], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_hyperparams = StringField(
|
||||
validators=Enum(*get_options(ReplaceHyperparams)),
|
||||
default=ReplaceHyperparams.none,
|
||||
)
|
||||
|
||||
|
||||
class HyperParamKey(models.Base):
|
||||
section = StringField(required=True, validators=Length(minimum_value=1))
|
||||
name = StringField(nullable=True)
|
||||
|
||||
|
||||
class DeleteHyperParamsRequest(TaskRequest):
|
||||
hyperparams: Sequence[HyperParamKey] = ListField(
|
||||
[HyperParamKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
|
||||
|
||||
class GetConfigurationsRequest(MultiTaskRequest):
|
||||
names = ListField([str])
|
||||
|
||||
|
||||
class GetConfigurationNamesRequest(MultiTaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class Configuration(models.Base):
|
||||
name = StringField(required=True, validators=Length(minimum_value=1))
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class EditConfigurationRequest(TaskRequest):
|
||||
configuration: Sequence[Configuration] = ListField(
|
||||
[Configuration], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_configuration = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteConfigurationRequest(TaskRequest):
|
||||
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
import six
|
||||
@@ -13,13 +12,14 @@ from jsonmodels.fields import (
|
||||
)
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apimodels import make_default, ListField, EnumField
|
||||
from apimodels import make_default, ListField, EnumField, JsonSerializableMixin
|
||||
|
||||
DEFAULT_TIMEOUT = 10 * 60
|
||||
|
||||
|
||||
class WorkerRequest(Base):
|
||||
worker = StringField(required=True)
|
||||
tags = ListField(str)
|
||||
|
||||
|
||||
class RegisterRequest(WorkerRequest):
|
||||
@@ -61,26 +61,21 @@ class IdNameEntry(Base):
|
||||
name = StringField()
|
||||
|
||||
|
||||
class WorkerEntry(Base):
|
||||
class WorkerEntry(Base, JsonSerializableMixin):
|
||||
key = StringField() # not required due to migration issues
|
||||
id = StringField(required=True)
|
||||
user = EmbeddedField(IdNameEntry)
|
||||
company = EmbeddedField(IdNameEntry)
|
||||
ip = StringField()
|
||||
task = EmbeddedField(IdNameEntry)
|
||||
project = EmbeddedField(IdNameEntry)
|
||||
queue = StringField() # queue from which current task was taken
|
||||
queues = ListField(str) # list of queues this worker listens to
|
||||
register_time = DateTimeField(required=True)
|
||||
register_timeout = IntField(required=True)
|
||||
last_activity_time = DateTimeField(required=True)
|
||||
last_report_time = DateTimeField()
|
||||
|
||||
def to_json(self):
|
||||
return json.dumps(self.to_struct())
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, s):
|
||||
return cls(**json.loads(s))
|
||||
tags = ListField(str)
|
||||
|
||||
|
||||
class CurrentTaskEntry(IdNameEntry):
|
||||
|
||||
@@ -3,27 +3,25 @@ from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from operator import attrgetter, itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import database
|
||||
from apierrors import errors
|
||||
from bll.redis_cache_manager import RedisCacheManager
|
||||
from apimodels import JsonSerializableMixin
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
from bll.redis_cache_manager import RedisCacheManager
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.fields import StringField, ListField, IntField
|
||||
|
||||
from database.model.task.metrics import MetricEventStats
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
from utilities.json import loads, dumps
|
||||
|
||||
|
||||
class VariantScrollState(Base):
|
||||
@@ -45,17 +43,10 @@ class MetricScrollState(Base):
|
||||
self.last_min_iter = self.last_max_iter = None
|
||||
|
||||
|
||||
class DebugImageEventsScrollState(Base):
|
||||
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
|
||||
|
||||
def to_json(self):
|
||||
return dumps(self.to_struct())
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, s):
|
||||
return cls(**loads(s))
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class DebugImagesResult(object):
|
||||
@@ -65,7 +56,12 @@ class DebugImagesResult(object):
|
||||
|
||||
class DebugImagesIterator:
|
||||
EVENT_TYPE = "training_debug_image"
|
||||
STATE_EXPIRATION_SECONDS = 3600
|
||||
|
||||
@property
|
||||
def state_expiration_sec(self):
|
||||
return config.get(
|
||||
f"services.events.events_retrieval.state_expiration_sec", 3600
|
||||
)
|
||||
|
||||
@property
|
||||
def _max_workers(self):
|
||||
@@ -76,7 +72,7 @@ class DebugImagesIterator:
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugImageEventsScrollState,
|
||||
redis=redis,
|
||||
expiration_interval=self.STATE_EXPIRATION_SECONDS,
|
||||
expiration_interval=self.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_task_events(
|
||||
@@ -92,27 +88,31 @@ class DebugImagesIterator:
|
||||
if not self.es.indices.exists(es_index):
|
||||
return DebugImagesResult()
|
||||
|
||||
unique_metrics = set(metrics)
|
||||
state = self.cache_manager.get_state(state_id) if state_id else None
|
||||
if not state:
|
||||
state = DebugImageEventsScrollState(
|
||||
id=database.utils.id(),
|
||||
metrics=self._init_metric_states(es_index, list(unique_metrics)),
|
||||
)
|
||||
else:
|
||||
state_metrics = set((m.task, m.name) for m in state.metrics)
|
||||
if state_metrics != unique_metrics:
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"while getting debug images events", scroll_id=state_id
|
||||
)
|
||||
def init_state(state_: DebugImageEventsScrollState):
|
||||
unique_metrics = set(metrics)
|
||||
state_.metrics = self._init_metric_states(es_index, list(unique_metrics))
|
||||
|
||||
def validate_state(state_: DebugImageEventsScrollState):
|
||||
"""
|
||||
Validate that the metrics stored in the state are the same
|
||||
as requested in the current call.
|
||||
Refresh the state if requested
|
||||
"""
|
||||
state_metrics = set((m.task, m.name) for m in state_.metrics)
|
||||
if state_metrics != set(metrics):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task metrics stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reinit_outdated_metric_states(company_id, es_index, state)
|
||||
for metric_state in state.metrics:
|
||||
self._reinit_outdated_metric_states(company_id, es_index, state_)
|
||||
for metric_state in state_.metrics:
|
||||
metric_state.reset()
|
||||
|
||||
res = DebugImagesResult(next_scroll_id=state.id)
|
||||
try:
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
) as state:
|
||||
res = DebugImagesResult(next_scroll_id=state.id)
|
||||
with ThreadPoolExecutor(self._max_workers) as pool:
|
||||
res.metric_events = list(
|
||||
pool.map(
|
||||
@@ -125,10 +125,8 @@ class DebugImagesIterator:
|
||||
state.metrics,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
self.cache_manager.set_state(state)
|
||||
|
||||
return res
|
||||
return res
|
||||
|
||||
def _reinit_outdated_metric_states(
|
||||
self, company_id, es_index, state: DebugImageEventsScrollState
|
||||
@@ -210,7 +208,11 @@ class DebugImagesIterator:
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}]
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"terms": {"metric": metrics}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"aggs": {
|
||||
@@ -253,7 +255,7 @@ class DebugImagesIterator:
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@@ -300,6 +302,7 @@ class DebugImagesIterator:
|
||||
must_conditions = [
|
||||
{"term": {"task": metric.task}},
|
||||
{"term": {"metric": metric.name}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
must_not_conditions = []
|
||||
|
||||
@@ -370,7 +373,7 @@ class DebugImagesIterator:
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iter_count,
|
||||
"order": {"_term": "desc" if navigate_earlier else "asc"},
|
||||
"order": {"_key": "desc" if navigate_earlier else "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
@@ -389,7 +392,7 @@ class DebugImagesIterator:
|
||||
},
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=metric.task)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return metric.task, metric.name, []
|
||||
|
||||
|
||||
@@ -3,9 +3,8 @@ from collections import defaultdict
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence
|
||||
from typing import Sequence, Set, Tuple, Optional
|
||||
|
||||
import attr
|
||||
import six
|
||||
from elasticsearch import helpers
|
||||
from mongoengine import Q
|
||||
@@ -16,12 +15,14 @@ import es_factory
|
||||
from apierrors import errors
|
||||
from bll.event.debug_images_iterator import DebugImagesIterator
|
||||
from bll.event.event_metrics import EventMetrics, EventType
|
||||
from bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
|
||||
from bll.task import TaskBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.task import Task, TaskStatus
|
||||
from redis_manager import redman
|
||||
from timing_context import TimingContext
|
||||
from tools import safe_get
|
||||
from utilities.dicts import flatten_nested_items
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
@@ -29,15 +30,9 @@ EVENT_TYPES = set(map(attrgetter("value"), EventType))
|
||||
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult(object):
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.ib(factory=list)
|
||||
|
||||
|
||||
class EventBLL(object):
|
||||
id_fields = ("task", "iter", "metric", "variant", "key")
|
||||
empty_scroll = "FFFF"
|
||||
|
||||
def __init__(self, events_es=None, redis=None):
|
||||
self.es = events_es or es_factory.connect("events")
|
||||
@@ -47,12 +42,28 @@ class EventBLL(object):
|
||||
)
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
|
||||
self.log_events_iterator = LogEventsIterator(es=self.es)
|
||||
|
||||
@property
|
||||
def metrics(self) -> EventMetrics:
|
||||
return self._metrics
|
||||
|
||||
def add_events(self, company_id, events, worker, allow_locked_tasks=False):
|
||||
@staticmethod
|
||||
def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
|
||||
"""Verify that task exists and can be updated"""
|
||||
if not task_ids:
|
||||
return set()
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
|
||||
query = Q(id__in=task_ids, company=company_id)
|
||||
if not allow_locked_tasks:
|
||||
query &= Q(status__nin=LOCKED_TASK_STATUSES)
|
||||
res = Task.objects(query).only("id")
|
||||
return {r.id for r in res}
|
||||
|
||||
def add_events(
|
||||
self, company_id, events, worker, allow_locked_tasks=False
|
||||
) -> Tuple[int, int, dict]:
|
||||
actions = []
|
||||
task_ids = set()
|
||||
task_iteration = defaultdict(lambda: 0)
|
||||
@@ -62,19 +73,34 @@ class EventBLL(object):
|
||||
task_last_events = nested_dict(
|
||||
3, dict
|
||||
) # task_id -> metric_hash -> event_type -> MetricEvent
|
||||
|
||||
errors_per_type = defaultdict(int)
|
||||
valid_tasks = self._get_valid_tasks(
|
||||
company_id,
|
||||
task_ids={
|
||||
event["task"] for event in events if event.get("task") is not None
|
||||
},
|
||||
allow_locked_tasks=allow_locked_tasks,
|
||||
)
|
||||
for event in events:
|
||||
# remove spaces from event type
|
||||
if "type" not in event:
|
||||
raise errors.BadRequest("Event must have a 'type' field", event=event)
|
||||
event_type = event.get("type")
|
||||
if event_type is None:
|
||||
errors_per_type["Event must have a 'type' field"] += 1
|
||||
continue
|
||||
|
||||
event_type = event["type"].replace(" ", "_")
|
||||
event_type = event_type.replace(" ", "_")
|
||||
if event_type not in EVENT_TYPES:
|
||||
raise errors.BadRequest(
|
||||
"Invalid event type {}".format(event_type),
|
||||
event=event,
|
||||
types=EVENT_TYPES,
|
||||
)
|
||||
errors_per_type[f"Invalid event type {event_type}"] += 1
|
||||
continue
|
||||
|
||||
task_id = event.get("task")
|
||||
if task_id is None:
|
||||
errors_per_type["Event must have a 'task' field"] += 1
|
||||
continue
|
||||
|
||||
if task_id not in valid_tasks:
|
||||
errors_per_type["Invalid task id"] += 1
|
||||
continue
|
||||
|
||||
event["type"] = event_type
|
||||
|
||||
@@ -110,7 +136,6 @@ class EventBLL(object):
|
||||
es_action = {
|
||||
"_op_type": "index", # overwrite if exists with same ID
|
||||
"_index": index_name,
|
||||
"_type": "event",
|
||||
"_source": event,
|
||||
}
|
||||
|
||||
@@ -120,89 +145,74 @@ class EventBLL(object):
|
||||
else:
|
||||
es_action["_id"] = dbutils.id()
|
||||
|
||||
task_id = event.get("task")
|
||||
if task_id is not None:
|
||||
es_action["_routing"] = task_id
|
||||
task_ids.add(task_id)
|
||||
if (
|
||||
iter is not None
|
||||
and event.get("metric") not in self._skip_iteration_for_metric
|
||||
):
|
||||
task_iteration[task_id] = max(iter, task_iteration[task_id])
|
||||
task_ids.add(task_id)
|
||||
if (
|
||||
iter is not None
|
||||
and event.get("metric") not in self._skip_iteration_for_metric
|
||||
):
|
||||
task_iteration[task_id] = max(iter, task_iteration[task_id])
|
||||
|
||||
self._update_last_metric_events_for_task(
|
||||
last_events=task_last_events[task_id], event=event,
|
||||
self._update_last_metric_events_for_task(
|
||||
last_events=task_last_events[task_id], event=event,
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_id], event=event
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_id], event=event
|
||||
)
|
||||
else:
|
||||
es_action["_routing"] = task_id
|
||||
|
||||
actions.append(es_action)
|
||||
|
||||
if task_ids:
|
||||
# verify task_ids
|
||||
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
|
||||
extra_msg = None
|
||||
query = Q(id__in=task_ids, company=company_id)
|
||||
if not allow_locked_tasks:
|
||||
query &= Q(status__nin=LOCKED_TASK_STATUSES)
|
||||
extra_msg = "or task published"
|
||||
res = Task.objects(query).only("id")
|
||||
if len(res) < len(task_ids):
|
||||
invalid_task_ids = tuple(set(task_ids) - set(r.id for r in res))
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
extra_msg, company=company_id, ids=invalid_task_ids
|
||||
added = 0
|
||||
if actions:
|
||||
chunk_size = 500
|
||||
with translate_errors_context(), TimingContext("es", "events_add_batch"):
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
# thread_count=8,
|
||||
refresh=True,
|
||||
)
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += chunk_size
|
||||
else:
|
||||
errors_per_type["Error when indexing events batch"] += 1
|
||||
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in task_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update
|
||||
# all of them and not only those who's events were successful
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
now=now,
|
||||
iter_max=task_iteration.get(task_id),
|
||||
last_scalar_events=task_last_scalar_events.get(task_id),
|
||||
last_events=task_last_events.get(task_id),
|
||||
)
|
||||
|
||||
errors_in_bulk = []
|
||||
added = 0
|
||||
chunk_size = 500
|
||||
with translate_errors_context(), TimingContext("es", "events_add_batch"):
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
# thread_count=8,
|
||||
refresh=True,
|
||||
)
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += chunk_size
|
||||
else:
|
||||
errors_in_bulk.append(info)
|
||||
if not updated:
|
||||
remaining_tasks.add(task_id)
|
||||
continue
|
||||
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in task_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update all of them and not only those
|
||||
# who's events were successful
|
||||
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
now=now,
|
||||
iter_max=task_iteration.get(task_id),
|
||||
last_scalar_events=task_last_scalar_events.get(task_id),
|
||||
last_events=task_last_events.get(task_id),
|
||||
)
|
||||
|
||||
if not updated:
|
||||
remaining_tasks.add(task_id)
|
||||
continue
|
||||
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(
|
||||
remaining_tasks, company_id, last_update=now
|
||||
)
|
||||
|
||||
# Compensate for always adding chunk_size on success (last chunk is probably smaller)
|
||||
added = min(added, len(actions))
|
||||
|
||||
return added, errors_in_bulk
|
||||
if not added:
|
||||
raise errors.bad_request.EventsNotAdded(**errors_per_type)
|
||||
|
||||
errors_count = sum(errors_per_type.values())
|
||||
return added, errors_count, errors_per_type
|
||||
|
||||
def _update_last_scalar_events_for_task(self, last_events, event):
|
||||
"""
|
||||
@@ -220,9 +230,25 @@ class EventBLL(object):
|
||||
metric_hash = dbutils.hash_field_name(metric)
|
||||
variant_hash = dbutils.hash_field_name(variant)
|
||||
|
||||
timestamp = last_events[metric_hash][variant_hash].get("timestamp", None)
|
||||
if timestamp is None or timestamp < event["timestamp"]:
|
||||
last_events[metric_hash][variant_hash] = event
|
||||
last_event = last_events[metric_hash][variant_hash]
|
||||
event_iter = event.get("iter", 0)
|
||||
event_timestamp = event.get("timestamp", 0)
|
||||
value = event.get("value")
|
||||
if value is not None and (
|
||||
(event_iter, event_timestamp)
|
||||
>= (
|
||||
last_event.get("iter", event_iter),
|
||||
last_event.get("timestamp", event_timestamp),
|
||||
)
|
||||
):
|
||||
event_data = {
|
||||
k: event[k]
|
||||
for k in ("value", "metric", "variant", "iter", "timestamp")
|
||||
if k in event
|
||||
}
|
||||
event_data["min_value"] = min(value, last_event.get("min_value", value))
|
||||
event_data["max_value"] = max(value, last_event.get("max_value", value))
|
||||
last_events[metric_hash][variant_hash] = event_data
|
||||
|
||||
def _update_last_metric_events_for_task(self, last_events, event):
|
||||
"""
|
||||
@@ -265,7 +291,13 @@ class EventBLL(object):
|
||||
flatten_nested_items(
|
||||
last_scalar_events,
|
||||
nesting=2,
|
||||
include_leaves=["value", "metric", "variant"],
|
||||
include_leaves=[
|
||||
"value",
|
||||
"min_value",
|
||||
"max_value",
|
||||
"metric",
|
||||
"variant",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -290,6 +322,9 @@ class EventBLL(object):
|
||||
batch_size=10000,
|
||||
scroll_id=None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "task_log_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
@@ -310,14 +345,9 @@ class EventBLL(object):
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, scroll="1h", routing=task_id
|
||||
)
|
||||
|
||||
events = [hit["_source"] for hit in es_res["hits"]["hits"]]
|
||||
next_scroll_id = es_res["_scroll_id"]
|
||||
total_events = es_res["hits"]["total"]
|
||||
es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return events, next_scroll_id, total_events
|
||||
|
||||
def get_last_iterations_per_event_metric_variant(
|
||||
@@ -345,7 +375,7 @@ class EventBLL(object):
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": num_last_iterations,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -361,7 +391,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "task_last_iter_metric_variant"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@@ -381,6 +411,9 @@ class EventBLL(object):
|
||||
size: int = 500,
|
||||
scroll_id: str = None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
@@ -390,13 +423,11 @@ class EventBLL(object):
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
query = {"bool": defaultdict(list)}
|
||||
|
||||
must = []
|
||||
if last_iterations_per_plot is None:
|
||||
must = query["bool"]["must"]
|
||||
must.append({"terms": {"task": tasks}})
|
||||
else:
|
||||
should = query["bool"]["should"]
|
||||
should = []
|
||||
for i, task_id in enumerate(tasks):
|
||||
last_iters = self.get_last_iterations_per_event_metric_variant(
|
||||
es_index, task_id, last_iterations_per_plot, event_type
|
||||
@@ -419,32 +450,41 @@ class EventBLL(object):
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
|
||||
|
||||
routing = ",".join(tasks)
|
||||
es_req = {
|
||||
"sort": sort,
|
||||
"size": min(size, 10000),
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_plots"):
|
||||
es_res = self.es.search(
|
||||
index=es_index,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
routing=routing,
|
||||
scroll="1h",
|
||||
index=es_index, body=es_req, ignore=404, scroll="1h",
|
||||
)
|
||||
|
||||
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
|
||||
# scroll id may be missing when queering a totally empty DB
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
total_events = es_res["hits"]["total"]
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
|
||||
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
|
||||
"""
|
||||
Return events and next scroll id from the scrolled query
|
||||
Release the scroll once it is exhausted
|
||||
"""
|
||||
total_events = safe_get(es_res, "hits/total/value", default=0)
|
||||
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
if next_scroll_id and not events:
|
||||
self.es.clear_scroll(scroll_id=next_scroll_id)
|
||||
next_scroll_id = self.empty_scroll
|
||||
|
||||
return events, total_events, next_scroll_id
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id,
|
||||
@@ -457,6 +497,8 @@ class EventBLL(object):
|
||||
size=500,
|
||||
scroll_id=None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
@@ -470,20 +512,16 @@ class EventBLL(object):
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
query = {"bool": defaultdict(list)}
|
||||
|
||||
if metric or variant:
|
||||
must = query["bool"]["must"]
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
must = []
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
|
||||
if last_iter_count is None:
|
||||
must = query["bool"]["must"]
|
||||
must.append({"terms": {"task": task_ids}})
|
||||
else:
|
||||
should = query["bool"]["should"]
|
||||
should = []
|
||||
for i, task_id in enumerate(task_ids):
|
||||
last_iters = self.get_last_iters(
|
||||
es_index, task_id, event_type, last_iter_count
|
||||
@@ -502,27 +540,23 @@ class EventBLL(object):
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
|
||||
|
||||
routing = ",".join(task_ids)
|
||||
es_req = {
|
||||
"sort": sort,
|
||||
"size": min(size, 10000),
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.search(
|
||||
index=es_index,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
routing=routing,
|
||||
scroll="1h",
|
||||
index=es_index, body=es_req, ignore=404, scroll="1h",
|
||||
)
|
||||
|
||||
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
|
||||
next_scroll_id = es_res["_scroll_id"]
|
||||
total_events = es_res["hits"]["total"]
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
@@ -558,7 +592,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
metrics = {}
|
||||
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
|
||||
@@ -590,14 +624,14 @@ class EventBLL(object):
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
"order": {"_term": "asc"},
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
"order": {"_term": "asc"},
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_value": {
|
||||
@@ -627,7 +661,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
metrics = []
|
||||
max_timestamp = 0
|
||||
@@ -674,7 +708,7 @@ class EventBLL(object):
|
||||
"sort": ["iter"],
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
vectors = []
|
||||
iterations = []
|
||||
@@ -695,7 +729,7 @@ class EventBLL(object):
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iters,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -705,7 +739,7 @@ class EventBLL(object):
|
||||
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_last_iter"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@@ -727,8 +761,6 @@ class EventBLL(object):
|
||||
es_index = EventMetrics.get_index_name(company_id, "*")
|
||||
es_req = {"query": {"term": {"task": task_id}}}
|
||||
with translate_errors_context(), TimingContext("es", "delete_task_events"):
|
||||
es_res = self.es.delete_by_query(
|
||||
index=es_index, body=es_req, routing=task_id, refresh=True
|
||||
)
|
||||
es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)
|
||||
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Callable, Iterable
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
@@ -16,7 +15,7 @@ from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
from utilities import safe_get
|
||||
from tools import safe_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -30,14 +29,18 @@ class EventType(Enum):
|
||||
|
||||
|
||||
class EventMetrics:
|
||||
MAX_TASKS_COUNT = 50
|
||||
MAX_METRICS_COUNT = 200
|
||||
MAX_VARIANTS_COUNT = 500
|
||||
MAX_METRICS_COUNT = 100
|
||||
MAX_VARIANTS_COUNT = 100
|
||||
MAX_AGGS_ELEMENTS_COUNT = 50
|
||||
MAX_SAMPLE_BUCKETS = 6000
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
@property
|
||||
def _max_concurrency(self):
|
||||
return config.get("services.events.max_metrics_concurrency", 4)
|
||||
|
||||
@staticmethod
|
||||
def get_index_name(company_id, event_type):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
@@ -51,15 +54,48 @@ class EventMetrics:
|
||||
The amount of points in each histogram should not exceed
|
||||
the requested samples
|
||||
"""
|
||||
es_index = self.get_index_name(company_id, "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
return self._run_get_scalar_metrics_as_parallel(
|
||||
company_id,
|
||||
task_ids=[task_id],
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
get_func=self._get_scalar_average,
|
||||
return self._get_scalar_average_per_iter_core(
|
||||
task_id, es_index, samples, ScalarKey.resolve(key)
|
||||
)
|
||||
|
||||
def _get_scalar_average_per_iter_core(
|
||||
self,
|
||||
task_id: str,
|
||||
es_index: str,
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
run_parallel: bool = True,
|
||||
) -> dict:
|
||||
intervals = self._get_task_metric_intervals(
|
||||
es_index=es_index, task_id=task_id, samples=samples, field=key.field
|
||||
)
|
||||
if not intervals:
|
||||
return {}
|
||||
interval_groups = self._group_task_metric_intervals(intervals)
|
||||
|
||||
get_scalar_average = partial(
|
||||
self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
|
||||
)
|
||||
if run_parallel:
|
||||
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
pool.map(get_scalar_average, interval_groups)
|
||||
)
|
||||
else:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
get_scalar_average(group) for group in interval_groups
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
for metric_key, metric_values in metrics:
|
||||
ret[metric_key].update(metric_values)
|
||||
|
||||
return ret
|
||||
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self,
|
||||
company_id,
|
||||
@@ -72,159 +108,115 @@ class EventMetrics:
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
if len(task_ids) > self.MAX_TASKS_COUNT:
|
||||
raise errors.BadRequest(
|
||||
f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
|
||||
len(task_ids),
|
||||
)
|
||||
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id__in=task_ids),
|
||||
allow_public=allow_public,
|
||||
override_projection=("id", "name"),
|
||||
override_projection=("id", "name", "company"),
|
||||
return_dicts=False,
|
||||
)
|
||||
if len(task_objs) < len(task_ids):
|
||||
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
|
||||
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
|
||||
|
||||
task_name_by_id = {t.id: t.name for t in task_objs}
|
||||
|
||||
ret = self._run_get_scalar_metrics_as_parallel(
|
||||
company_id,
|
||||
task_ids=task_ids,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
get_func=self._get_scalar_average_per_task,
|
||||
)
|
||||
companies = {t.company for t in task_objs}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
for metric_data in ret.values():
|
||||
for variant_data in metric_data.values():
|
||||
for task_id, task_data in variant_data.items():
|
||||
task_data["name"] = task_name_by_id[task_id]
|
||||
|
||||
return ret
|
||||
|
||||
TaskMetric = Tuple[str, str, str]
|
||||
|
||||
MetricInterval = Tuple[int, Sequence[TaskMetric]]
|
||||
MetricData = Tuple[str, dict]
|
||||
|
||||
def _split_metrics_by_max_aggs_count(
|
||||
self, task_metrics: Sequence[TaskMetric]
|
||||
) -> Iterable[Sequence[TaskMetric]]:
|
||||
"""
|
||||
Return task metrics in groups where amount of task metrics in each group
|
||||
is roughly limited by MAX_AGGS_ELEMENTS_COUNT. The split is done on metrics and
|
||||
variants while always preserving all their tasks in the same group
|
||||
"""
|
||||
if len(task_metrics) < self.MAX_AGGS_ELEMENTS_COUNT:
|
||||
yield task_metrics
|
||||
return
|
||||
|
||||
tm_grouped = bucketize(task_metrics, key=itemgetter(1, 2))
|
||||
groups = []
|
||||
for group in tm_grouped.values():
|
||||
groups.append(group)
|
||||
if sum(map(len, groups)) >= self.MAX_AGGS_ELEMENTS_COUNT:
|
||||
yield list(itertools.chain(*groups))
|
||||
groups = []
|
||||
|
||||
if groups:
|
||||
yield list(itertools.chain(*groups))
|
||||
|
||||
return
|
||||
|
||||
def _run_get_scalar_metrics_as_parallel(
|
||||
self,
|
||||
company_id: str,
|
||||
task_ids: Sequence[str],
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
get_func: Callable[
|
||||
[MetricInterval, Sequence[str], str, ScalarKey], Sequence[MetricData]
|
||||
],
|
||||
) -> dict:
|
||||
"""
|
||||
Group metrics per interval length and execute get_func for each group in parallel
|
||||
:param company_id: id of the company
|
||||
:params task_ids: ids of the tasks to collect data for
|
||||
:param samples: maximum number of samples per metric
|
||||
:param get_func: callable that given metric names for the same interval
|
||||
performs histogram aggregation for the metrics and return the aggregated data
|
||||
"""
|
||||
es_index = self.get_index_name(company_id, "training_stats_scalar")
|
||||
es_index = self.get_index_name(next(iter(companies)), "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
intervals = self._get_metric_intervals(
|
||||
es_index=es_index, task_ids=task_ids, samples=samples, field=key.field
|
||||
get_scalar_average_per_iter = partial(
|
||||
self._get_scalar_average_per_iter_core,
|
||||
es_index=es_index,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
run_parallel=False,
|
||||
)
|
||||
|
||||
if not intervals:
|
||||
return {}
|
||||
|
||||
intervals = list(
|
||||
itertools.chain.from_iterable(
|
||||
zip(itertools.repeat(i), self._split_metrics_by_max_aggs_count(tms))
|
||||
for i, tms in intervals
|
||||
)
|
||||
)
|
||||
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
pool.map(
|
||||
partial(get_func, task_ids=task_ids, es_index=es_index, key=key),
|
||||
intervals,
|
||||
)
|
||||
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
|
||||
task_metrics = zip(
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
for metric_key, metric_values in metrics:
|
||||
ret[metric_key].update(metric_values)
|
||||
res = defaultdict(lambda: defaultdict(dict))
|
||||
for task_id, task_data in task_metrics:
|
||||
task_name = task_name_by_id[task_id]
|
||||
for metric_key, metric_data in task_data.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
variant_data["name"] = task_name
|
||||
res[metric_key][variant_key][task_id] = variant_data
|
||||
|
||||
return ret
|
||||
return res
|
||||
|
||||
def _get_metric_intervals(
|
||||
self, es_index, task_ids: Sequence[str], samples: int, field: str = "iter"
|
||||
MetricInterval = Tuple[str, str, int, int]
|
||||
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
|
||||
|
||||
@classmethod
|
||||
def _group_task_metric_intervals(
|
||||
cls, intervals: Sequence[MetricInterval]
|
||||
) -> Sequence[MetricIntervalGroup]:
|
||||
"""
|
||||
Group task metric intervals so that the following conditions are meat:
|
||||
- All the metrics in the same group have the same interval (with 10% rounding)
|
||||
- The amount of metrics in the group does not exceed MAX_AGGS_ELEMENTS_COUNT
|
||||
- The total count of samples in the group does not exceed MAX_SAMPLE_BUCKETS
|
||||
"""
|
||||
metric_interval_groups = []
|
||||
interval_group = []
|
||||
group_interval_upper_bound = 0
|
||||
group_max_interval = 0
|
||||
group_samples = 0
|
||||
for metric, variant, interval, size in sorted(intervals, key=itemgetter(2)):
|
||||
if (
|
||||
interval > group_interval_upper_bound
|
||||
or (group_samples + size) > cls.MAX_SAMPLE_BUCKETS
|
||||
or len(interval_group) >= cls.MAX_AGGS_ELEMENTS_COUNT
|
||||
):
|
||||
if interval_group:
|
||||
metric_interval_groups.append((group_max_interval, interval_group))
|
||||
interval_group = []
|
||||
group_max_interval = interval
|
||||
group_interval_upper_bound = interval + int(interval * 0.1)
|
||||
group_samples = 0
|
||||
interval_group.append((metric, variant))
|
||||
group_samples += size
|
||||
group_max_interval = max(group_max_interval, interval)
|
||||
if interval_group:
|
||||
metric_interval_groups.append((group_max_interval, interval_group))
|
||||
|
||||
return metric_interval_groups
|
||||
|
||||
def _get_task_metric_intervals(
|
||||
self, es_index, task_id: str, samples: int, field: str = "iter"
|
||||
) -> Sequence[MetricInterval]:
|
||||
"""
|
||||
Calculate interval per task metric variant so that the resulting
|
||||
amount of points does not exceed sample.
|
||||
Return metric variants grouped by interval value with 10% rounding
|
||||
For samples==0 return empty list
|
||||
Return the list og metric variant intervals as the following tuple:
|
||||
(metric, variant, interval, samples)
|
||||
"""
|
||||
default_intervals = [(1, [])]
|
||||
if not samples:
|
||||
return default_intervals
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"terms": {"task": task_ids}},
|
||||
"query": {"term": {"task": task_id}},
|
||||
"aggs": {
|
||||
"tasks": {
|
||||
"terms": {"field": "task", "size": self.MAX_TASKS_COUNT},
|
||||
"metrics": {
|
||||
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": self.MAX_METRICS_COUNT,
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"count": {"value_count": {"field": field}},
|
||||
"min_index": {"min": {"field": field}},
|
||||
"max_index": {"max": {"field": field}},
|
||||
},
|
||||
}
|
||||
"count": {"value_count": {"field": field}},
|
||||
"min_index": {"min": {"field": field}},
|
||||
"max_index": {"max": {"field": field}},
|
||||
},
|
||||
}
|
||||
},
|
||||
@@ -233,88 +225,75 @@ class EventMetrics:
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, routing=",".join(task_ids)
|
||||
)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return default_intervals
|
||||
return []
|
||||
|
||||
intervals = [
|
||||
(
|
||||
task["key"],
|
||||
metric["key"],
|
||||
variant["key"],
|
||||
self._calculate_metric_interval(variant, samples),
|
||||
)
|
||||
for task in aggs_result["tasks"]["buckets"]
|
||||
for metric in task["metrics"]["buckets"]
|
||||
return [
|
||||
self._build_metric_interval(metric["key"], variant["key"], variant, samples)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
for variant in metric["variants"]["buckets"]
|
||||
]
|
||||
|
||||
metric_intervals = []
|
||||
upper_border = 0
|
||||
interval_metrics = None
|
||||
for task, metric, variant, interval in sorted(intervals, key=itemgetter(3)):
|
||||
if not interval_metrics or interval > upper_border:
|
||||
interval_metrics = []
|
||||
metric_intervals.append((interval, interval_metrics))
|
||||
upper_border = interval + int(interval * 0.1)
|
||||
interval_metrics.append((task, metric, variant))
|
||||
|
||||
return metric_intervals
|
||||
|
||||
@staticmethod
|
||||
def _calculate_metric_interval(metric_variant: dict, samples: int) -> int:
|
||||
def _build_metric_interval(
|
||||
metric: str, variant: str, data: dict, samples: int
|
||||
) -> Tuple[str, str, int, int]:
|
||||
"""
|
||||
Calculate index interval per metric_variant variant so that the
|
||||
total amount of intervals does not exceeds the samples
|
||||
Return the interval and resulting amount of intervals
|
||||
"""
|
||||
count = safe_get(metric_variant, "count/value")
|
||||
if not count or count < samples:
|
||||
return 1
|
||||
count = safe_get(data, "count/value", default=0)
|
||||
if count < samples:
|
||||
return metric, variant, 1, count
|
||||
|
||||
min_index = safe_get(metric_variant, "min_index/value", default=0)
|
||||
max_index = safe_get(metric_variant, "max_index/value", default=min_index)
|
||||
return max(1, int(max_index - min_index + 1) // samples)
|
||||
min_index = safe_get(data, "min_index/value", default=0)
|
||||
max_index = safe_get(data, "max_index/value", default=min_index)
|
||||
return (
|
||||
metric,
|
||||
variant,
|
||||
max(1, int(max_index - min_index + 1) // samples),
|
||||
samples,
|
||||
)
|
||||
|
||||
MetricData = Tuple[str, dict]
|
||||
|
||||
def _get_scalar_average(
|
||||
self,
|
||||
metrics_interval: MetricInterval,
|
||||
task_ids: Sequence[str],
|
||||
metrics_interval: MetricIntervalGroup,
|
||||
task_id: str,
|
||||
es_index: str,
|
||||
key: ScalarKey,
|
||||
) -> Sequence[MetricData]:
|
||||
"""
|
||||
Retrieve scalar histograms per several metric variants that share the same interval
|
||||
Note: the function works with a single task only
|
||||
"""
|
||||
|
||||
assert len(task_ids) == 1
|
||||
interval, task_metrics = metrics_interval
|
||||
interval, metrics = metrics_interval
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": self.MAX_METRICS_COUNT,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
aggs_result = self._query_aggregation_for_metrics_and_tasks(
|
||||
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
|
||||
aggs_result = self._query_aggregation_for_task_metrics(
|
||||
es_index, aggs=aggs, task_id=task_id, metrics=metrics
|
||||
)
|
||||
|
||||
if not aggs_result:
|
||||
@@ -335,61 +314,6 @@ class EventMetrics:
|
||||
]
|
||||
return metrics
|
||||
|
||||
def _get_scalar_average_per_task(
|
||||
self,
|
||||
metrics_interval: MetricInterval,
|
||||
task_ids: Sequence[str],
|
||||
es_index: str,
|
||||
key: ScalarKey,
|
||||
) -> Sequence[MetricData]:
|
||||
"""
|
||||
Retrieve scalar histograms per several metric variants that share the same interval
|
||||
"""
|
||||
interval, task_metrics = metrics_interval
|
||||
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
|
||||
"aggs": {
|
||||
"tasks": {
|
||||
"terms": {
|
||||
"field": "task",
|
||||
"size": self.MAX_TASKS_COUNT,
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
aggs_result = self._query_aggregation_for_metrics_and_tasks(
|
||||
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
|
||||
)
|
||||
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
metrics = [
|
||||
(
|
||||
metric["key"],
|
||||
{
|
||||
variant["key"]: {
|
||||
task["key"]: key.get_iterations_data(task)
|
||||
for task in variant["tasks"]["buckets"]
|
||||
}
|
||||
for variant in metric["variants"]["buckets"]
|
||||
},
|
||||
)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
]
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
def _add_aggregation_average(aggregation):
|
||||
average_agg = {"avg_val": {"avg": {"field": "value"}}}
|
||||
@@ -398,69 +322,55 @@ class EventMetrics:
|
||||
for key, value in aggregation.items()
|
||||
}
|
||||
|
||||
def _query_aggregation_for_metrics_and_tasks(
|
||||
def _query_aggregation_for_task_metrics(
|
||||
self,
|
||||
es_index: str,
|
||||
aggs: dict,
|
||||
task_ids: Sequence[str],
|
||||
task_metrics: Sequence[TaskMetric],
|
||||
task_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
) -> dict:
|
||||
"""
|
||||
Return the result of elastic search query for the given aggregation filtered
|
||||
by the given task_ids and metrics
|
||||
"""
|
||||
if task_metrics:
|
||||
condition = {
|
||||
"should": [
|
||||
self._build_metric_terms(task, metric, variant)
|
||||
for task, metric, variant in task_metrics
|
||||
]
|
||||
}
|
||||
else:
|
||||
condition = {"must": [{"terms": {"task": task_ids}}]}
|
||||
must = [{"term": {"task": task_id}}]
|
||||
if metrics:
|
||||
should = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for metric, variant in metrics
|
||||
]
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"_source": {"excludes": []},
|
||||
"query": {"bool": condition},
|
||||
"query": {"bool": {"must": must}},
|
||||
"aggs": aggs,
|
||||
"version": True,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, routing=",".join(task_ids)
|
||||
)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
return es_res.get("aggregations")
|
||||
|
||||
@staticmethod
|
||||
def _build_metric_terms(task: str, metric: str, variant: str) -> dict:
|
||||
"""
|
||||
Build query term for a metric + variant
|
||||
"""
|
||||
return {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
def get_tasks_metrics(
|
||||
self, company_id, task_ids: Sequence, event_type: EventType
|
||||
) -> Sequence[Tuple]:
|
||||
) -> Sequence:
|
||||
"""
|
||||
For the requested tasks return all the metrics that
|
||||
reported events of the requested types
|
||||
"""
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type.value)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return [(tid, []) for tid in task_ids]
|
||||
return {}
|
||||
|
||||
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
|
||||
with ThreadPoolExecutor(max_concurrency) as pool:
|
||||
with ThreadPoolExecutor(self._max_concurrency) as pool:
|
||||
res = pool.map(
|
||||
partial(
|
||||
self._get_task_metrics, es_index=es_index, event_type=event_type,
|
||||
@@ -488,7 +398,7 @@ class EventMetrics:
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
return [
|
||||
metric["key"]
|
||||
|
||||
113
server/bll/event/log_events_iterator.py
Normal file
113
server/bll/event/log_events_iterator.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from typing import Optional, Tuple, Sequence
|
||||
|
||||
import attr
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
from database.errors import translate_errors_context
|
||||
from timing_context import TimingContext
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.Factory(list)
|
||||
|
||||
|
||||
class LogEventsIterator:
|
||||
EVENT_TYPE = "log"
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
from_timestamp: Optional[int] = None,
|
||||
) -> TaskEventsResult:
|
||||
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
res = TaskEventsResult()
|
||||
res.events, res.total_events = self._get_events(
|
||||
es_index=es_index,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_timestamp=from_timestamp,
|
||||
)
|
||||
return res
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
es_index,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
from_timestamp: Optional[int],
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous timestamp either in the
|
||||
direction of earlier events (navigate_earlier=True) or in the direction of later events.
|
||||
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
|
||||
For the last timestamp all the events are brought (even if the resulting size
|
||||
exceeds batch_size) so that this timestamp events will not be lost between the calls.
|
||||
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
|
||||
"""
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": {"term": {"task": task_id}},
|
||||
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if from_timestamp:
|
||||
es_req["search_after"] = [from_timestamp]
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_result = self.es.search(index=es_index, body=es_req)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"timestamp": events[-1]["timestamp"]}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = self.es.search(index=es_index, body=es_req)
|
||||
hits = es_result["hits"]["hits"]
|
||||
if not hits or len(hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
last_events = [hit["_source"] for hit in es_result["hits"]["hits"]]
|
||||
already_present_ids = set(ev["_id"] for ev in events)
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[
|
||||
*events,
|
||||
*(ev for ev in last_events if ev["_id"] not in already_present_ids),
|
||||
],
|
||||
hits_total,
|
||||
)
|
||||
@@ -4,7 +4,7 @@ Module for polymorphism over different types of X axes in scalar aggregations
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import auto
|
||||
|
||||
from apimodels import StringEnum
|
||||
from utilities.stringenum import StringEnum
|
||||
from bll.util import extract_properties_to_lists
|
||||
from config import config
|
||||
|
||||
@@ -111,7 +111,7 @@ class TimestampKey(ScalarKey):
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": interval,
|
||||
"fixed_interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
}
|
||||
}
|
||||
@@ -150,7 +150,7 @@ class ISOTimeKey(ScalarKey):
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": interval,
|
||||
"fixed_interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
"format": "strict_date_time",
|
||||
}
|
||||
|
||||
18
server/bll/model/__init__.py
Normal file
18
server/bll/model/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from database.model.model import Model
|
||||
from database.utils import get_company_or_none_constraint
|
||||
|
||||
|
||||
class ModelBLL:
|
||||
def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
|
||||
"""
|
||||
Return the list of unique frameworks used by company and public models
|
||||
If project ids passed then only models from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
return Model.objects(query).distinct(field="framework")
|
||||
193
server/bll/organization/__init__.py
Normal file
193
server/bll/organization/__init__.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from typing import Sequence, Union, Type, Dict
|
||||
|
||||
from mongoengine import Q
|
||||
from redis import Redis
|
||||
|
||||
from config import config
|
||||
from database.model.base import GetMixin
|
||||
from database.model.model import Model
|
||||
from database.model.task.task import Task
|
||||
from redis_manager import redman
|
||||
from utilities import json
|
||||
|
||||
log = config.logger(__file__)
|
||||
_settings_prefix = "services.organization"
|
||||
|
||||
|
||||
class _TagsCache:
|
||||
_tags_field = "tags"
|
||||
_system_tags_field = "system_tags"
|
||||
|
||||
def __init__(self, db_cls: Union[Type[Model], Type[Task]], redis: Redis):
|
||||
self.db_cls = db_cls
|
||||
self.redis = redis
|
||||
|
||||
@property
|
||||
def _tags_cache_expiration_seconds(self):
|
||||
return config.get(f"{_settings_prefix}.tags_cache.expiration_seconds", 3600)
|
||||
|
||||
def _get_tags_from_db(
|
||||
self,
|
||||
company: str,
|
||||
field: str,
|
||||
project: str = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
) -> set:
|
||||
query = Q(company=company)
|
||||
if filter_:
|
||||
for name, vals in filter_.items():
|
||||
if vals:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
if project:
|
||||
query &= Q(project=project)
|
||||
|
||||
return self.db_cls.objects(query).distinct(field)
|
||||
|
||||
def _get_tags_cache_key(
|
||||
self,
|
||||
company: str,
|
||||
field: str,
|
||||
project: str = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
):
|
||||
"""
|
||||
Project None means 'from all company projects'
|
||||
The key is built in the way that scanning company keys for 'all company projects'
|
||||
will not return the keys related to the particular company projects and vice versa.
|
||||
So that we can have a fine grain control on what redis keys to invalidate
|
||||
"""
|
||||
filter_str = None
|
||||
if filter_:
|
||||
filter_str = "_".join(
|
||||
["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())]
|
||||
)
|
||||
key_parts = [company, project, self.db_cls.__name__, field, filter_str]
|
||||
return "_".join(filter(None, key_parts))
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company: str,
|
||||
include_system: bool = False,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
project: str = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Get tags and optionally system tags for the company
|
||||
Return the dictionary of tags per tags field name
|
||||
The function retrieves both cached values from Redis in one call
|
||||
and re calculates any of them if missing in Redis
|
||||
"""
|
||||
fields = [self._tags_field]
|
||||
if include_system:
|
||||
fields.append(self._system_tags_field)
|
||||
redis_keys = [
|
||||
self._get_tags_cache_key(company, field=f, project=project, filter_=filter_)
|
||||
for f in fields
|
||||
]
|
||||
cached = self.redis.mget(redis_keys)
|
||||
ret = {}
|
||||
for field, tag_data, key in zip(fields, cached, redis_keys):
|
||||
if tag_data is not None:
|
||||
tags = json.loads(tag_data)
|
||||
else:
|
||||
tags = list(self._get_tags_from_db(company, field, project, filter_))
|
||||
self.redis.setex(
|
||||
key,
|
||||
time=self._tags_cache_expiration_seconds,
|
||||
value=json.dumps(tags),
|
||||
)
|
||||
ret[field] = set(tags)
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(self, company: str, project: str, tags=None, system_tags=None):
|
||||
"""
|
||||
Updates tags. If reset is set then both tags and system_tags
|
||||
are recalculated. Otherwise only those that are not 'None'
|
||||
"""
|
||||
fields = [
|
||||
field
|
||||
for field, update in (
|
||||
(self._tags_field, tags),
|
||||
(self._system_tags_field, system_tags),
|
||||
)
|
||||
if update is not None
|
||||
]
|
||||
if not fields:
|
||||
return
|
||||
|
||||
self._delete_redis_keys(company, projects=[project], fields=fields)
|
||||
|
||||
def reset_tags(self, company: str, projects: Sequence[str]):
|
||||
self._delete_redis_keys(
|
||||
company,
|
||||
projects=projects,
|
||||
fields=(self._tags_field, self._system_tags_field),
|
||||
)
|
||||
|
||||
def _delete_redis_keys(
|
||||
self, company: str, projects: [Sequence[str]], fields: Sequence[str]
|
||||
):
|
||||
redis_keys = list(
|
||||
chain.from_iterable(
|
||||
self.redis.keys(
|
||||
self._get_tags_cache_key(company, field=f, project=p) + "*"
|
||||
)
|
||||
for f in fields
|
||||
for p in set(projects) | {None}
|
||||
)
|
||||
)
|
||||
if redis_keys:
|
||||
self.redis.delete(*redis_keys)
|
||||
|
||||
|
||||
class Tags(Enum):
|
||||
Task = "task"
|
||||
Model = "model"
|
||||
|
||||
|
||||
class OrgBLL:
|
||||
def __init__(self, redis=None):
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self._task_tags = _TagsCache(Task, self.redis)
|
||||
self._model_tags = _TagsCache(Model, self.redis)
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company: str,
|
||||
entity: Tags,
|
||||
include_system: bool = False,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
projects: Sequence[str] = None,
|
||||
) -> dict:
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
if not projects:
|
||||
return tags_cache.get_tags(
|
||||
company, include_system=include_system, filter_=filter_
|
||||
)
|
||||
|
||||
ret = defaultdict(set)
|
||||
for project in projects:
|
||||
project_tags = tags_cache.get_tags(
|
||||
company, include_system=include_system, filter_=filter_, project=project
|
||||
)
|
||||
for field, tags in project_tags.items():
|
||||
ret[field] |= tags
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(
|
||||
self, company: str, entity: Tags, project: str, tags=None, system_tags=None,
|
||||
):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.update_tags(company, project, tags, system_tags)
|
||||
|
||||
def reset_tags(self, company: str, entity: Tags, projects: Sequence[str]):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.reset_tags(company, projects=projects)
|
||||
|
||||
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
|
||||
return self._task_tags if entity == Tags.Task else self._model_tags
|
||||
1
server/bll/project/__init__.py
Normal file
1
server/bll/project/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .project_bll import ProjectBLL
|
||||
33
server/bll/project/project_bll.py
Normal file
33
server/bll/project/project_bll.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from config import config
|
||||
from database.model.model import Model
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ProjectBLL:
|
||||
@classmethod
|
||||
def get_active_users(
|
||||
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
|
||||
) -> set:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
"""
|
||||
with TimingContext("mongo", "active_users_in_projects"):
|
||||
res = set()
|
||||
query = Q(company=company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
|
||||
return res
|
||||
@@ -18,7 +18,6 @@ log = config.logger(__file__)
|
||||
|
||||
class QueueMetrics:
|
||||
class EsKeys:
|
||||
DOC_TYPE = "metrics"
|
||||
WAITING_TIME_FIELD = "average_waiting_time"
|
||||
QUEUE_LENGTH_FIELD = "queue_length"
|
||||
TIMESTAMP_FIELD = "timestamp"
|
||||
@@ -66,7 +65,6 @@ class QueueMetrics:
|
||||
entries = [e for e in queue.entries if e.added]
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_type=self.EsKeys.DOC_TYPE,
|
||||
_source={
|
||||
self.EsKeys.TIMESTAMP_FIELD: timestamp,
|
||||
self.EsKeys.QUEUE_FIELD: queue.id,
|
||||
@@ -93,7 +91,6 @@ class QueueMetrics:
|
||||
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
|
||||
doc_type=self.EsKeys.DOC_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@@ -109,7 +106,7 @@ class QueueMetrics:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": cls.EsKeys.TIMESTAMP_FIELD,
|
||||
"interval": f"{interval}s",
|
||||
"fixed_interval": f"{interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
from typing import Optional, TypeVar, Generic, Type
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, TypeVar, Generic, Type, Callable
|
||||
|
||||
from redis import StrictRedis
|
||||
|
||||
import database
|
||||
from timing_context import TimingContext
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _do_nothing(_: T):
|
||||
return
|
||||
|
||||
|
||||
class RedisCacheManager(Generic[T]):
|
||||
"""
|
||||
Class for store/retreive of state objects from redis
|
||||
Class for store/retrieve of state objects from redis
|
||||
|
||||
self.state_class - class of the state
|
||||
self.redis - instance of redis
|
||||
@@ -42,3 +48,32 @@ class RedisCacheManager(Generic[T]):
|
||||
|
||||
def _get_redis_key(self, state_id):
|
||||
return f"{self.state_class}/{state_id}"
|
||||
|
||||
@contextmanager
|
||||
def get_or_create_state(
|
||||
self,
|
||||
state_id=None,
|
||||
init_state: Callable[[T], None] = _do_nothing,
|
||||
validate_state: Callable[[T], None] = _do_nothing,
|
||||
):
|
||||
"""
|
||||
Try to retrieve state with the given id from the Redis cache if yes then validates it
|
||||
If no then create a new one with randomly generated id
|
||||
Yield the state and write it back to redis once the user code block exits
|
||||
:param state_id: id of the state to retrieve
|
||||
:param init_state: user callback to init the newly created state
|
||||
If not passed then no init except for the id generation is done
|
||||
:param validate_state: user callback to validate the state if retrieved from cache
|
||||
Should throw an exception if the state is not valid. If not passed then no validation is done
|
||||
"""
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
|
||||
try:
|
||||
yield state
|
||||
finally:
|
||||
self.set_state(state)
|
||||
|
||||
@@ -19,7 +19,7 @@ from config.info import get_deployment_type
|
||||
from database.model import Company, User
|
||||
from database.model.queue import Queue
|
||||
from database.model.task.task import Task
|
||||
from utilities import safe_get
|
||||
from tools import safe_get
|
||||
from utilities.json import dumps
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from version import __version__ as current_version
|
||||
@@ -237,7 +237,6 @@ class StatisticsReporter:
|
||||
def _run_worker_stats_query(cls, company_id, es_req) -> dict:
|
||||
return worker_bll.es_client.search(
|
||||
index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*",
|
||||
doc_type="stat",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@@ -280,7 +279,7 @@ class StatisticsReporter:
|
||||
]
|
||||
return {
|
||||
group["_id"]: {k: v for k, v in group.items() if k != "_id"}
|
||||
for group in Task.aggregate(*pipeline)
|
||||
for group in Task.aggregate(pipeline)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -4,5 +4,4 @@ from .utils import (
|
||||
update_project_time,
|
||||
validate_status_change,
|
||||
split_by,
|
||||
ParameterKeyEscaper,
|
||||
)
|
||||
|
||||
229
server/bll/task/hyperparams.py
Normal file
229
server/bll/task/hyperparams.py
Normal file
@@ -0,0 +1,229 @@
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Dict
|
||||
|
||||
from boltons import iterutils
|
||||
|
||||
from apierrors import errors
|
||||
from apimodels.tasks import (
|
||||
HyperParamKey,
|
||||
HyperParamItem,
|
||||
ReplaceHyperparams,
|
||||
Configuration,
|
||||
)
|
||||
from bll.task import TaskBLL
|
||||
from config import config
|
||||
from database.model.task.task import ParamsItem, Task, ConfigurationItem, TaskStatus
|
||||
from utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
log = config.logger(__file__)
|
||||
task_bll = TaskBLL()
|
||||
|
||||
|
||||
class HyperParams:
|
||||
_properties_section = "properties"
|
||||
|
||||
@classmethod
|
||||
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
only = ("id", "hyperparams")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
task.id: {"hyperparams": cls._get_params_list(items=task.hyperparams)}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_params_list(
|
||||
cls, items: Dict[str, Dict[str, ParamsItem]]
|
||||
) -> Sequence[dict]:
|
||||
ret = list(chain.from_iterable(v.values() for v in items.values()))
|
||||
return [
|
||||
p.to_proper_dict() for p in sorted(ret, key=attrgetter("section", "name"))
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _normalize_params(cls, params: Sequence) -> bool:
|
||||
"""
|
||||
Lower case properties section and return True if it is the only section
|
||||
"""
|
||||
for p in params:
|
||||
if p.section.lower() == cls._properties_section:
|
||||
p.section = cls._properties_section
|
||||
|
||||
return all(p.section == cls._properties_section for p in params)
|
||||
|
||||
@classmethod
|
||||
def delete_params(
|
||||
cls, company_id: str, task_id: str, hyperparams=Sequence[HyperParamKey]
|
||||
) -> int:
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = cls._get_task_for_update(
|
||||
company=company_id, id=task_id, allow_all_statuses=properties_only
|
||||
)
|
||||
|
||||
with_param, without_param = iterutils.partition(
|
||||
hyperparams, key=lambda p: bool(p.name)
|
||||
)
|
||||
sections_to_delete = {p.section for p in without_param}
|
||||
delete_cmds = {
|
||||
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
|
||||
for section in sections_to_delete
|
||||
}
|
||||
|
||||
for item in with_param:
|
||||
section = ParameterKeyEscaper.escape(item.section)
|
||||
if item.section in sections_to_delete:
|
||||
raise errors.bad_request.FieldsConflict(
|
||||
"Cannot delete section field if the whole section was scheduled for deletion"
|
||||
)
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def edit_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamItem],
|
||||
replace_hyperparams: str,
|
||||
) -> int:
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = cls._get_task_for_update(
|
||||
company=company_id, id=task_id, allow_all_statuses=properties_only
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
hyperparams = cls._db_dicts_from_list(hyperparams)
|
||||
if replace_hyperparams == ReplaceHyperparams.all:
|
||||
update_cmds["set__hyperparams"] = hyperparams
|
||||
elif replace_hyperparams == ReplaceHyperparams.section:
|
||||
for section, value in hyperparams.items():
|
||||
update_cmds[f"set__hyperparams__{section}"] = value
|
||||
else:
|
||||
for section, section_params in hyperparams.items():
|
||||
for name, value in section_params.items():
|
||||
update_cmds[f"set__hyperparams__{section}__{name}"] = value
|
||||
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
|
||||
sections = iterutils.bucketize(items, key=attrgetter("section"))
|
||||
return {
|
||||
ParameterKeyEscaper.escape(section): {
|
||||
ParameterKeyEscaper.escape(param.name): ParamsItem(**param.to_struct())
|
||||
for param in params
|
||||
}
|
||||
for section, params in sections.items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_configurations(
|
||||
cls, company_id: str, task_ids: Sequence[str], names: Sequence[str]
|
||||
) -> Dict[str, dict]:
|
||||
only = ["id"]
|
||||
if names:
|
||||
only.extend(
|
||||
f"configuration.{ParameterKeyEscaper.escape(name)}" for name in names
|
||||
)
|
||||
else:
|
||||
only.append("configuration")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
task.id: {
|
||||
"configuration": [
|
||||
c.to_proper_dict()
|
||||
for c in sorted(task.configuration.values(), key=attrgetter("name"))
|
||||
]
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_configuration_names(
|
||||
cls, company_id: str, task_ids: Sequence[str]
|
||||
) -> Dict[str, list]:
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"_id": {"$in": task_ids},
|
||||
}
|
||||
},
|
||||
{"$project": {"items": {"$objectToArray": "$configuration"}}},
|
||||
{"$unwind": "$items"},
|
||||
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
|
||||
]
|
||||
|
||||
tasks = Task.aggregate(pipeline)
|
||||
|
||||
return {
|
||||
task["_id"]: {
|
||||
"names": sorted(
|
||||
ParameterKeyEscaper.unescape(name) for name in task["names"]
|
||||
)
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_configuration(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
configuration: Sequence[Configuration],
|
||||
replace_configuration: bool,
|
||||
) -> int:
|
||||
task = cls._get_task_for_update(company=company_id, id=task_id)
|
||||
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
|
||||
for c in configuration
|
||||
}
|
||||
if replace_configuration:
|
||||
update_cmds["set__configuration"] = configuration
|
||||
else:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{name}"] = value
|
||||
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def delete_configuration(
|
||||
cls, company_id: str, task_id: str, configuration=Sequence[str]
|
||||
) -> int:
|
||||
task = cls._get_task_for_update(company=company_id, id=task_id)
|
||||
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@staticmethod
|
||||
def _get_task_for_update(
|
||||
company: str, id: str, allow_all_statuses: bool = False
|
||||
) -> Task:
|
||||
task = Task.get_for_writing(company=company, id=id, _only=("id", "status"))
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=id)
|
||||
|
||||
if allow_all_statuses:
|
||||
return task
|
||||
|
||||
if task.status != TaskStatus.created:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
expected=TaskStatus.created, status=task.status
|
||||
)
|
||||
return task
|
||||
89
server/bll/task/non_responsive_tasks_watchdog.py
Normal file
89
server/bll/task/non_responsive_tasks_watchdog.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from datetime import timedelta, datetime
|
||||
from time import sleep
|
||||
|
||||
from apierrors import errors
|
||||
from bll.task import ChangeStatusRequest
|
||||
from config import config
|
||||
from database.model.task.task import TaskStatus, Task
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class NonResponsiveTasksWatchdog:
|
||||
threads = ThreadsManager()
|
||||
|
||||
class _Settings:
|
||||
"""
|
||||
Retrieves watchdog settings from the config file
|
||||
The properties are not cached so that the updates in
|
||||
the config file are reflected
|
||||
"""
|
||||
|
||||
_prefix = "services.tasks.non_responsive_tasks_watchdog"
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return config.get(f"{self._prefix}.enabled", True)
|
||||
|
||||
@property
|
||||
def watch_interval_sec(self):
|
||||
return config.get(f"{self._prefix}.watch_interval_sec", 900)
|
||||
|
||||
@property
|
||||
def threshold_sec(self):
|
||||
return config.get(f"{self._prefix}.threshold_sec", 7200)
|
||||
|
||||
settings = _Settings()
|
||||
|
||||
@classmethod
|
||||
@threads.register("non_responsive_tasks_watchdog", daemon=True)
|
||||
def start(cls):
|
||||
sleep(cls.settings.watch_interval_sec)
|
||||
while not ThreadsManager.terminating:
|
||||
watch_interval = cls.settings.watch_interval_sec
|
||||
if cls.settings.enabled:
|
||||
try:
|
||||
stopped = cls.cleanup_tasks(
|
||||
threshold_sec=cls.settings.threshold_sec
|
||||
)
|
||||
log.info(f"{stopped} non-responsive tasks stopped")
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed stopping tasks: {str(ex)}")
|
||||
sleep(watch_interval)
|
||||
|
||||
@classmethod
|
||||
def cleanup_tasks(cls, threshold_sec):
|
||||
relevant_status = (TaskStatus.in_progress,)
|
||||
threshold = timedelta(seconds=threshold_sec)
|
||||
ref_time = datetime.utcnow() - threshold
|
||||
log.info(
|
||||
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
|
||||
)
|
||||
|
||||
tasks = list(
|
||||
Task.objects(status__in=relevant_status, last_update__lt=ref_time).only(
|
||||
"id", "name", "status", "project", "last_update"
|
||||
)
|
||||
)
|
||||
log.info(f"{len(tasks)} non-responsive tasks found")
|
||||
if not tasks:
|
||||
return 0
|
||||
|
||||
err_count = 0
|
||||
for task in tasks:
|
||||
log.info(
|
||||
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
|
||||
)
|
||||
try:
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.stopped,
|
||||
status_reason="Forced stop (non-responsive)",
|
||||
status_message="Forced stop (non-responsive)",
|
||||
force=True,
|
||||
).execute()
|
||||
except errors.bad_request.FailedChangingTaskStatus:
|
||||
err_count += 1
|
||||
|
||||
return len(tasks) - err_count
|
||||
201
server/bll/task/param_utils.py
Normal file
201
server/bll/task/param_utils.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import itertools
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import dpath
|
||||
|
||||
from apierrors import errors
|
||||
from database.model.task.task import Task
|
||||
from tools import safe_get
|
||||
from utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
|
||||
hyperparams_default_section = "Args"
|
||||
hyperparams_legacy_type = "legacy"
|
||||
tf_define_section = "TF_DEFINE"
|
||||
|
||||
|
||||
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Return parameter section and name. The section is either TF_DEFINE or the default one
|
||||
"""
|
||||
if default_section is None:
|
||||
return None, full_name
|
||||
|
||||
section, _, name = full_name.partition("/")
|
||||
if section != tf_define_section:
|
||||
return default_section, full_name
|
||||
|
||||
if not name:
|
||||
raise errors.bad_request.ValidationError("Parameter name cannot be empty")
|
||||
return section, name
|
||||
|
||||
|
||||
def _get_full_param_name(param: dict) -> str:
|
||||
section = param.get("section")
|
||||
if section != tf_define_section:
|
||||
return param["name"]
|
||||
|
||||
return "/".join((section, param["name"]))
|
||||
|
||||
|
||||
def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
"""
|
||||
removed = 0
|
||||
if not data:
|
||||
return removed
|
||||
|
||||
if with_sections:
|
||||
for section, section_data in list(data.items()):
|
||||
removed += _remove_legacy_params(section_data)
|
||||
if not section_data:
|
||||
"""If section is empty after removing legacy params then delete it"""
|
||||
del data[section]
|
||||
else:
|
||||
for key, param in list(data.items()):
|
||||
if param.get("type") == hyperparams_legacy_type:
|
||||
removed += 1
|
||||
del data[key]
|
||||
|
||||
return removed
|
||||
|
||||
|
||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
"""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
if with_sections:
|
||||
return itertools.chain.from_iterable(
|
||||
_get_legacy_params(section_data) for section_data in data.values()
|
||||
)
|
||||
|
||||
return [
|
||||
param for param in data.values() if param.get("type") == hyperparams_legacy_type
|
||||
]
|
||||
|
||||
|
||||
def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
"""
|
||||
If legacy hyper params or configuration is passed then replace the corresponding section in the new structure
|
||||
Escape all the section and param names for hyper params and configuration to make it mongo sage
|
||||
"""
|
||||
for old_params_field, new_params_field, default_section in (
|
||||
("execution/parameters", "hyperparams", hyperparams_default_section),
|
||||
("execution/model_desc", "configuration", None),
|
||||
):
|
||||
legacy_params = safe_get(fields, old_params_field)
|
||||
if legacy_params is None:
|
||||
continue
|
||||
|
||||
if (
|
||||
not safe_get(fields, new_params_field)
|
||||
and previous_task
|
||||
and previous_task[new_params_field]
|
||||
):
|
||||
previous_data = previous_task.to_proper_dict().get(new_params_field)
|
||||
removed = _remove_legacy_params(
|
||||
previous_data, with_sections=default_section is not None
|
||||
)
|
||||
if not legacy_params and not removed:
|
||||
# if we only need to delete legacy fields from the db
|
||||
# but they are not there then there is no point to proceed
|
||||
continue
|
||||
|
||||
fields_update = {new_params_field: previous_data}
|
||||
params_unprepare_from_saved(fields_update)
|
||||
fields.update(fields_update)
|
||||
|
||||
for full_name, value in legacy_params.items():
|
||||
section, name = split_param_name(full_name, default_section)
|
||||
new_path = list(filter(None, (new_params_field, section, name)))
|
||||
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
dpath.new(fields, new_path, new_param)
|
||||
dpath.delete(fields, old_params_field)
|
||||
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(key): {
|
||||
ParameterKeyEscaper.escape(k): v for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, escaped_params)
|
||||
|
||||
|
||||
def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
"""
|
||||
Unescape all section and param names for hyper params and configuration
|
||||
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
|
||||
"""
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
if params:
|
||||
unescaped_params = {
|
||||
ParameterKeyEscaper.unescape(key): {
|
||||
ParameterKeyEscaper.unescape(k): v for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, unescaped_params)
|
||||
|
||||
if copy_to_legacy:
|
||||
for new_params_field, old_params_field, use_sections in (
|
||||
(f"hyperparams", "execution/parameters", True),
|
||||
(f"configuration", "execution/model_desc", False),
|
||||
):
|
||||
legacy_params = _get_legacy_params(
|
||||
safe_get(fields, new_params_field), with_sections=use_sections
|
||||
)
|
||||
if legacy_params:
|
||||
dpath.new(
|
||||
fields,
|
||||
old_params_field,
|
||||
{_get_full_param_name(p): p["value"] for p in legacy_params},
|
||||
)
|
||||
|
||||
|
||||
def _process_path(path: str):
|
||||
"""
|
||||
Frontend does a partial escaping on the path so the all '.' in section and key names are escaped
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
raise errors.bad_request.ValidationError("invalid task field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
)
|
||||
|
||||
|
||||
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
|
||||
for old_prefix, new_prefix in (
|
||||
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
|
||||
("execution.model_desc", f"configuration"),
|
||||
):
|
||||
path: str
|
||||
paths = [path.replace(old_prefix, new_prefix) for path in paths]
|
||||
|
||||
for prefix in (
|
||||
"hyperparams.",
|
||||
"-hyperparams.",
|
||||
"configuration.",
|
||||
"-configuration.",
|
||||
):
|
||||
paths = [
|
||||
_process_path(path) if path.startswith(prefix) else path for path in paths
|
||||
]
|
||||
return paths
|
||||
@@ -1,10 +1,11 @@
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from random import random
|
||||
from time import sleep
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
|
||||
|
||||
import dpath
|
||||
import pymongo.results
|
||||
import six
|
||||
from mongoengine import Q
|
||||
@@ -14,6 +15,7 @@ import database.utils as dbutils
|
||||
import es_factory
|
||||
from apierrors import errors
|
||||
from apimodels.tasks import Artifact as ApiArtifact
|
||||
from bll.organization import OrgBLL, Tags
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.model import Model
|
||||
@@ -27,25 +29,38 @@ from database.model.task.task import (
|
||||
TaskSystemTags,
|
||||
ArtifactModes,
|
||||
Artifact,
|
||||
external_task_types,
|
||||
)
|
||||
from database.utils import get_company_or_none_constraint, id as create_id
|
||||
from service_repo import APICall
|
||||
from timing_context import TimingContext
|
||||
from utilities.dicts import deep_merge
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
|
||||
from utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import ChangeStatusRequest, validate_status_change
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
|
||||
|
||||
class TaskBLL(object):
|
||||
threads = ThreadsManager("TaskBLL")
|
||||
|
||||
def __init__(self, events_es=None):
|
||||
self.events_es = (
|
||||
events_es if events_es is not None else es_factory.connect("events")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
|
||||
"""
|
||||
Return the list of unique task types used by company and public tasks
|
||||
If project ids passed then only tasks from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
res = Task.objects(query).distinct(field="type")
|
||||
return set(res).intersection(external_task_types)
|
||||
|
||||
@staticmethod
|
||||
def get_task_with_access(
|
||||
task_id, company_id, only=None, allow_public=False, requires_write_access=False
|
||||
@@ -70,25 +85,24 @@ class TaskBLL(object):
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(
|
||||
company_id,
|
||||
task_id,
|
||||
required_status=None,
|
||||
required_dataset=None,
|
||||
only_fields=None,
|
||||
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
|
||||
):
|
||||
if only_fields:
|
||||
if isinstance(only_fields, string_types):
|
||||
only_fields = [only_fields]
|
||||
else:
|
||||
only_fields = list(only_fields)
|
||||
only_fields = only_fields + ["status"]
|
||||
|
||||
with TimingContext("mongo", "task_by_id_all"):
|
||||
qs = Task.objects(id=task_id, company=company_id)
|
||||
if only_fields:
|
||||
qs = (
|
||||
qs.only(only_fields)
|
||||
if isinstance(only_fields, string_types)
|
||||
else qs.only(*only_fields)
|
||||
)
|
||||
qs = qs.only(
|
||||
"status", "input"
|
||||
) # make sure all fields we rely on here are also returned
|
||||
task = qs.first()
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id=task_id),
|
||||
allow_public=allow_public,
|
||||
override_projection=only_fields,
|
||||
return_dicts=False,
|
||||
)
|
||||
task = None if not tasks else tasks[0]
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
@@ -96,17 +110,12 @@ class TaskBLL(object):
|
||||
if required_status and not task.status == required_status:
|
||||
raise errors.bad_request.InvalidTaskStatus(expected=required_status)
|
||||
|
||||
if required_dataset and required_dataset not in (
|
||||
entry.dataset for entry in task.input.view.entries
|
||||
):
|
||||
raise errors.bad_request.InvalidId(
|
||||
"not in input view", dataset=required_dataset
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def assert_exists(company_id, task_ids, only=None, allow_public=False):
|
||||
def assert_exists(
|
||||
company_id, task_ids, only=None, allow_public=False, return_tasks=True
|
||||
) -> Optional[Sequence[Task]]:
|
||||
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
|
||||
with translate_errors_context(), TimingContext("mongo", "task_exists"):
|
||||
ids = set(task_ids)
|
||||
@@ -117,14 +126,13 @@ class TaskBLL(object):
|
||||
return_dicts=False,
|
||||
)
|
||||
if only:
|
||||
res = q.only(*only)
|
||||
count = len(res)
|
||||
else:
|
||||
count = q.count()
|
||||
res = q.first()
|
||||
if count != len(ids):
|
||||
q = q.only(*only)
|
||||
|
||||
if q.count() != len(ids):
|
||||
raise errors.bad_request.InvalidTaskId(ids=task_ids)
|
||||
return res
|
||||
|
||||
if return_tasks:
|
||||
return list(q)
|
||||
|
||||
@staticmethod
|
||||
def create(call: APICall, fields: dict):
|
||||
@@ -166,17 +174,32 @@ class TaskBLL(object):
|
||||
project: Optional[str] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
hyperparams: Optional[dict] = None,
|
||||
configuration: Optional[dict] = None,
|
||||
execution_overrides: Optional[dict] = None,
|
||||
validate_references: bool = False,
|
||||
) -> Task:
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id)
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
execution_model_overriden = False
|
||||
params_dict = {
|
||||
field: value
|
||||
for field, value in (
|
||||
("hyperparams", hyperparams),
|
||||
("configuration", configuration),
|
||||
)
|
||||
if value is not None
|
||||
}
|
||||
if execution_overrides:
|
||||
parameters = execution_overrides.get("parameters")
|
||||
if parameters is not None:
|
||||
execution_overrides["parameters"] = {
|
||||
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
|
||||
}
|
||||
params_dict["execution"] = {}
|
||||
for legacy_param in ("parameters", "configuration"):
|
||||
legacy_value = execution_overrides.pop(legacy_param, None)
|
||||
if legacy_value is not None:
|
||||
params_dict["execution"] = legacy_value
|
||||
execution_dict = deep_merge(execution_dict, execution_overrides)
|
||||
execution_model_overriden = execution_overrides.get("model") is not None
|
||||
params_prepare_for_save(params_dict, previous_task=task)
|
||||
|
||||
artifacts = execution_dict.get("artifacts")
|
||||
if artifacts:
|
||||
execution_dict["artifacts"] = [
|
||||
@@ -203,27 +226,59 @@ class TaskBLL(object):
|
||||
if task.output
|
||||
else None,
|
||||
execution=execution_dict,
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_model=validate_references or execution_model_overriden,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
cls.validate(new_task)
|
||||
new_task.save()
|
||||
|
||||
if task.project == new_task.project:
|
||||
updated_tags = tags
|
||||
updated_system_tags = system_tags
|
||||
else:
|
||||
updated_tags = new_task.tags
|
||||
updated_system_tags = new_task.system_tags
|
||||
org_bll.update_tags(
|
||||
company_id,
|
||||
Tags.Task,
|
||||
project=new_task.project,
|
||||
tags=updated_tags,
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
|
||||
return new_task
|
||||
|
||||
@classmethod
|
||||
def validate(cls, task: Task):
|
||||
assert isinstance(task, Task)
|
||||
|
||||
if task.parent and not Task.get(
|
||||
company=task.company, id=task.parent, _only=("id",), include_public=True
|
||||
def validate(
|
||||
cls,
|
||||
task: Task,
|
||||
validate_model=True,
|
||||
validate_parent=True,
|
||||
validate_project=True,
|
||||
):
|
||||
if (
|
||||
validate_parent
|
||||
and task.parent
|
||||
and not Task.get(
|
||||
company=task.company, id=task.parent, _only=("id",), include_public=True
|
||||
)
|
||||
):
|
||||
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
|
||||
|
||||
if task.project and not Project.get_for_writing(
|
||||
company=task.company, id=task.project
|
||||
if (
|
||||
validate_project
|
||||
and task.project
|
||||
and not Project.get_for_writing(company=task.company, id=task.project)
|
||||
):
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
|
||||
cls.validate_execution_model(task)
|
||||
if validate_model:
|
||||
cls.validate_execution_model(task)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(company_id, project_ids=None):
|
||||
@@ -263,7 +318,7 @@ class TaskBLL(object):
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = Task.aggregate(*pipeline)
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@staticmethod
|
||||
@@ -312,10 +367,12 @@ class TaskBLL(object):
|
||||
return "__".join((op, "last_metrics") + path)
|
||||
|
||||
for path, value in last_scalar_values:
|
||||
extra_updates[op_path("set", *path)] = value
|
||||
if path[-1] == "value":
|
||||
if path[-1] == "min_value":
|
||||
extra_updates[op_path("min", *path[:-1], "min_value")] = value
|
||||
elif path[-1] == "max_value":
|
||||
extra_updates[op_path("max", *path[:-1], "max_value")] = value
|
||||
else:
|
||||
extra_updates[op_path("set", *path)] = value
|
||||
|
||||
if last_events is not None:
|
||||
|
||||
@@ -327,7 +384,7 @@ class TaskBLL(object):
|
||||
|
||||
metric_stats = {
|
||||
dbutils.hash_field_name(metric_key): MetricEventStats(
|
||||
metric=metric_key, event_stats_by_type=events_per_type(metric_data),
|
||||
metric=metric_key, event_stats_by_type=events_per_type(metric_data)
|
||||
)
|
||||
for metric_key, metric_data in last_events.items()
|
||||
}
|
||||
@@ -575,81 +632,35 @@ class TaskBLL(object):
|
||||
|
||||
return [a.key for a in added], [a.key for a in updated]
|
||||
|
||||
@classmethod
|
||||
@threads.register("non_responsive_tasks_watchdog", daemon=True)
|
||||
def start_non_responsive_tasks_watchdog(cls):
|
||||
log = config.logger("non_responsive_tasks_watchdog")
|
||||
relevant_status = (TaskStatus.in_progress,)
|
||||
threshold = timedelta(
|
||||
seconds=config.get(
|
||||
"services.tasks.non_responsive_tasks_watchdog.threshold_sec", 7200
|
||||
)
|
||||
)
|
||||
watch_interval = config.get(
|
||||
"services.tasks.non_responsive_tasks_watchdog.watch_interval_sec", 900
|
||||
)
|
||||
sleep(watch_interval)
|
||||
while not ThreadsManager.terminating:
|
||||
try:
|
||||
|
||||
ref_time = datetime.utcnow() - threshold
|
||||
|
||||
log.info(
|
||||
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
|
||||
)
|
||||
|
||||
tasks = list(
|
||||
Task.objects(
|
||||
status__in=relevant_status, last_update__lt=ref_time
|
||||
).only("id", "name", "status", "project", "last_update")
|
||||
)
|
||||
|
||||
if tasks:
|
||||
|
||||
log.info(f"Stopping {len(tasks)} non-responsive tasks")
|
||||
|
||||
for task in tasks:
|
||||
log.info(
|
||||
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
|
||||
)
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.stopped,
|
||||
status_reason="Forced stop (non-responsive)",
|
||||
status_message="Forced stop (non-responsive)",
|
||||
force=True,
|
||||
).execute()
|
||||
|
||||
log.info(f"Done")
|
||||
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed stopping tasks: {str(ex)}")
|
||||
|
||||
sleep(watch_interval)
|
||||
|
||||
@staticmethod
|
||||
def get_aggregated_project_execution_parameters(
|
||||
def get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids: Sequence[str] = None,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[str]]:
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": company_id,
|
||||
"execution.parameters": {"$exists": True, "$gt": {}},
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
**({"project": {"$in": project_ids}} if project_ids else {}),
|
||||
}
|
||||
},
|
||||
{"$project": {"parameters": {"$objectToArray": "$execution.parameters"}}},
|
||||
{"$unwind": "$parameters"},
|
||||
{"$group": {"_id": "$parameters.k"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
@@ -666,7 +677,7 @@ class TaskBLL(object):
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = next(Task.aggregate(*pipeline), None)
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
@@ -675,7 +686,12 @@ class TaskBLL(object):
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
ParameterKeyEscaper.unescape(r["_id"])
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
dpath.get(r, "_id/section")
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import TypeVar, Callable, Tuple, Sequence
|
||||
|
||||
import attr
|
||||
import six
|
||||
from boltons.dictutils import OneToOne
|
||||
|
||||
from apierrors import errors
|
||||
from database.errors import translate_errors_context
|
||||
@@ -172,26 +171,3 @@ def split_by(
|
||||
[item for cond, item in applied if cond],
|
||||
[item for cond, item in applied if not cond],
|
||||
)
|
||||
|
||||
|
||||
class ParameterKeyEscaper:
|
||||
_mapping = OneToOne({".": "%2E", "$": "%24"})
|
||||
|
||||
@classmethod
|
||||
def escape(cls, value):
|
||||
""" Quote a parameter key """
|
||||
value = value.strip().replace("%", "%%")
|
||||
for c, r in cls._mapping.items():
|
||||
value = value.replace(c, r)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _unescape(cls, value):
|
||||
for c, r in cls._mapping.inv.items():
|
||||
value = value.replace(c, r)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def unescape(cls, value):
|
||||
""" Unquote a quoted parameter key """
|
||||
return "%".join(map(cls._unescape, value.split("%%")))
|
||||
|
||||
@@ -35,14 +35,21 @@ class SetFieldsResolver:
|
||||
SET_MODIFIERS = ("min", "max")
|
||||
|
||||
def __init__(self, set_fields: Dict[str, Any]):
|
||||
self.orig_fields = set_fields
|
||||
self.fields = {
|
||||
f: fname
|
||||
for f, modifier, dunder, fname in (
|
||||
(f,) + f.partition("__") for f in set_fields.keys()
|
||||
)
|
||||
if dunder and modifier in self.SET_MODIFIERS
|
||||
}
|
||||
self.orig_fields = {}
|
||||
self.fields = {}
|
||||
self.add_fields(**set_fields)
|
||||
|
||||
def add_fields(self, **set_fields: Any):
|
||||
self.orig_fields.update(set_fields)
|
||||
self.fields.update(
|
||||
{
|
||||
f: fname
|
||||
for f, modifier, dunder, fname in (
|
||||
(f,) + f.partition("__") for f in set_fields.keys()
|
||||
)
|
||||
if dunder and modifier in self.SET_MODIFIERS
|
||||
}
|
||||
)
|
||||
|
||||
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
|
||||
if name in self.fields and doc.get_field_value(self.fields[name]) is None:
|
||||
|
||||
@@ -21,6 +21,7 @@ from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.auth import User
|
||||
from database.model.company import Company
|
||||
from database.model.project import Project
|
||||
from database.model.queue import Queue
|
||||
from database.model.task.task import Task
|
||||
from redis_manager import redman
|
||||
@@ -33,8 +34,8 @@ log = config.logger(__file__)
|
||||
|
||||
class WorkerBLL:
|
||||
def __init__(self, es=None, redis=None):
|
||||
self.es_client = es if es is not None else es_factory.connect("workers")
|
||||
self.redis = redis if redis is not None else redman.connection("workers")
|
||||
self.es_client = es or es_factory.connect("workers")
|
||||
self.redis = redis or redman.connection("workers")
|
||||
self._stats = WorkerStats(self.es_client)
|
||||
|
||||
@property
|
||||
@@ -49,6 +50,7 @@ class WorkerBLL:
|
||||
ip: str = "",
|
||||
queues: Sequence[str] = None,
|
||||
timeout: int = 0,
|
||||
tags: Sequence[str] = None,
|
||||
) -> WorkerEntry:
|
||||
"""
|
||||
Register a worker
|
||||
@@ -58,6 +60,7 @@ class WorkerBLL:
|
||||
:param ip: the real ip of the worker
|
||||
:param queues: queues reported as being monitored by the worker
|
||||
:param timeout: registration expiration timeout in seconds
|
||||
:param tags: a list of tags for this worker
|
||||
:raise bad_request.InvalidUserId: in case the calling user or company does not exist
|
||||
:return: worker entry instance
|
||||
"""
|
||||
@@ -91,6 +94,7 @@ class WorkerBLL:
|
||||
register_time=now,
|
||||
register_timeout=timeout,
|
||||
last_activity_time=now,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
|
||||
@@ -113,12 +117,15 @@ class WorkerBLL:
|
||||
raise bad_request.WorkerNotRegistered(worker=worker)
|
||||
|
||||
def status_report(
|
||||
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest
|
||||
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Write worker status report
|
||||
:param company_id: worker's company ID
|
||||
:param user_id: user_id ID under which this worker is running
|
||||
:param ip: worker IP
|
||||
:param report: the report itself
|
||||
:param tags: tags for this worker
|
||||
:raise bad_request.InvalidTaskId: the reported task was not found
|
||||
:return: worker entry instance
|
||||
"""
|
||||
@@ -129,6 +136,9 @@ class WorkerBLL:
|
||||
now = datetime.utcnow()
|
||||
entry.last_activity_time = now
|
||||
|
||||
if tags is not None:
|
||||
entry.tags = tags
|
||||
|
||||
if report.machine_stats:
|
||||
self._log_stats_to_es(
|
||||
company_id=company_id,
|
||||
@@ -146,6 +156,7 @@ class WorkerBLL:
|
||||
|
||||
if not report.task:
|
||||
entry.task = None
|
||||
entry.project = None
|
||||
else:
|
||||
with translate_errors_context():
|
||||
query = dict(id=report.task, company=company_id)
|
||||
@@ -160,6 +171,12 @@ class WorkerBLL:
|
||||
raise bad_request.InvalidTaskId(**query)
|
||||
entry.task = IdNameEntry(id=task.id, name=task.name)
|
||||
|
||||
entry.project = None
|
||||
if task.project:
|
||||
project = Project.objects(id=task.project).only("name").first()
|
||||
if project:
|
||||
entry.project = IdNameEntry(id=project.id, name=project.name)
|
||||
|
||||
entry.last_report_time = now
|
||||
except APIError:
|
||||
raise
|
||||
@@ -223,7 +240,7 @@ class WorkerBLL:
|
||||
},
|
||||
]
|
||||
queues_info = {
|
||||
res["_id"]: res for res in Queue.objects.aggregate(*projection)
|
||||
res["_id"]: res for res in Queue.objects.aggregate(projection)
|
||||
}
|
||||
task_ids = task_ids.union(
|
||||
filter(
|
||||
@@ -369,7 +386,6 @@ class WorkerBLL:
|
||||
def make_doc(category, metric, variant, value) -> dict:
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_type="stat",
|
||||
_source=dict(
|
||||
timestamp=timestamp,
|
||||
worker=worker,
|
||||
|
||||
@@ -25,7 +25,6 @@ class WorkerStats:
|
||||
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self.worker_stats_prefix_for_company(company_id)}*",
|
||||
doc_type="stat",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@@ -53,7 +52,7 @@ class WorkerStats:
|
||||
|
||||
res = self._search_company_stats(company_id, es_req)
|
||||
|
||||
if not res["hits"]["total"]:
|
||||
if not res["hits"]["total"]["value"]:
|
||||
raise bad_request.WorkerStatsNotFound(
|
||||
f"No statistic metrics found for the company {company_id} and workers {worker_ids}"
|
||||
)
|
||||
@@ -87,7 +86,7 @@ class WorkerStats:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{request.interval}s",
|
||||
"fixed_interval": f"{request.interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
@@ -216,7 +215,7 @@ class WorkerStats:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{interval}s",
|
||||
"fixed_interval": f"{interval}s",
|
||||
},
|
||||
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
@@ -15,7 +16,7 @@ from pyparsing import (
|
||||
|
||||
DEFAULT_EXTRA_CONFIG_PATH = "/opt/trains/config"
|
||||
EXTRA_CONFIG_PATH_ENV_KEY = "TRAINS_CONFIG_DIR"
|
||||
EXTRA_CONFIG_PATH_SEP = ":"
|
||||
EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ';'
|
||||
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_SEP = "__"
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX = f"TRAINS{EXTRA_CONFIG_VALUES_ENV_KEY_SEP}"
|
||||
@@ -57,7 +58,7 @@ class BasicConfig:
|
||||
|
||||
keys = sorted(k for k in os.environ if k.startswith(prefix))
|
||||
for key in keys:
|
||||
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".")
|
||||
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".").lower()
|
||||
result = ConfigTree.merge_configs(
|
||||
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
|
||||
)
|
||||
@@ -77,7 +78,7 @@ class BasicConfig:
|
||||
if not path.is_dir() and str(path) != DEFAULT_EXTRA_CONFIG_PATH
|
||||
]
|
||||
if invalid:
|
||||
print(f"WARNING: Invalid paths in {key} env var: {' '.join(invalid)}")
|
||||
print(f"WARNING: Invalid paths in {key} env var: {' '.join(map(str, invalid))}")
|
||||
return [path for path in paths if path.is_dir()]
|
||||
|
||||
def _load(self, verbose=True):
|
||||
|
||||
@@ -26,6 +26,17 @@
|
||||
check_max_version: false
|
||||
}
|
||||
|
||||
pre_populate {
|
||||
enabled: false
|
||||
zip_files: ["/path/to/export.zip"]
|
||||
fail_on_error: false
|
||||
# artifacts_path: "/mnt/fileserver"
|
||||
}
|
||||
|
||||
# time in seconds to take an exclusive lock to init es and mongodb
|
||||
# not including the pre_populate
|
||||
db_init_timout: 120
|
||||
|
||||
mongo {
|
||||
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
|
||||
# but not declared in a data model
|
||||
@@ -34,11 +45,16 @@
|
||||
aggregate {
|
||||
allow_disk_use: true
|
||||
}
|
||||
}
|
||||
|
||||
pre_populate {
|
||||
enabled: false
|
||||
zip_file: "/path/to/export.zip"
|
||||
fail_on_error: false
|
||||
elastic {
|
||||
probing {
|
||||
# settings for inital probing of elastic connection
|
||||
max_retries: 4
|
||||
timeout: 30
|
||||
}
|
||||
upgrade_monitoring {
|
||||
v16_migration_verification: true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
elastic {
|
||||
events {
|
||||
hosts: [{host: "127.0.0.1", port: 9200}]
|
||||
hosts: [{host: "127.0.0.1", port: 9211}]
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
max_retries: 5
|
||||
max_retries: 3
|
||||
retry_on_timeout: true
|
||||
}
|
||||
index_version: "1"
|
||||
}
|
||||
|
||||
workers {
|
||||
hosts: [{host:"127.0.0.1", port:9200}]
|
||||
hosts: [{host:"127.0.0.1", port:9211}]
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
max_retries: 5
|
||||
max_retries: 3
|
||||
retry_on_timeout: true
|
||||
}
|
||||
index_version: "1"
|
||||
|
||||
@@ -13,17 +13,21 @@
|
||||
credentials {
|
||||
# system credentials as they appear in the auth DB, used for intra-service communications
|
||||
apiserver {
|
||||
role: "system"
|
||||
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
|
||||
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
|
||||
}
|
||||
webserver {
|
||||
role: "system"
|
||||
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
|
||||
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
|
||||
revoke_in_fixed_mode: true
|
||||
}
|
||||
tests {
|
||||
role: "user"
|
||||
display_name: "Default User"
|
||||
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
|
||||
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
16
server/config/default/services/auth.conf
Normal file
16
server/config/default/services/auth.conf
Normal file
@@ -0,0 +1,16 @@
|
||||
fixed_users {
|
||||
guest {
|
||||
enabled: false
|
||||
|
||||
default_company: "025315a9321f49f8be07f5ac48fbcf92"
|
||||
|
||||
name: "Guest"
|
||||
username: "guest"
|
||||
password: "guest"
|
||||
|
||||
# Allow access only to the following endpoints when using user/pass credentials
|
||||
allow_endpoints: [
|
||||
"auth.login"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -6,4 +6,8 @@ ignore_iteration {
|
||||
|
||||
# max number of concurrent queries to ES when calculating events metrics
|
||||
# should not exceed the amount of concurrent connections set in the ES driver
|
||||
max_metrics_concurrency: 4
|
||||
max_metrics_concurrency: 4
|
||||
|
||||
events_retrieval {
|
||||
state_expiration_sec: 3600
|
||||
}
|
||||
|
||||
3
server/config/default/services/organization.conf
Normal file
3
server/config/default/services/organization.conf
Normal file
@@ -0,0 +1,3 @@
|
||||
tags_cache {
|
||||
expiration_seconds: 3600
|
||||
}
|
||||
8
server/config/default/services/projects.conf
Normal file
8
server/config/default/services/projects.conf
Normal file
@@ -0,0 +1,8 @@
|
||||
# Order of featured projects, by name or ID
|
||||
featured_order: [
|
||||
# {id: "<project-id>"}
|
||||
# OR
|
||||
# {name: "<project-name>"}
|
||||
# OR
|
||||
# {name_regex: "<python-regex>"}
|
||||
]
|
||||
@@ -1,4 +1,6 @@
|
||||
non_responsive_tasks_watchdog {
|
||||
enabled: true
|
||||
|
||||
# In-progress tasks older than this value in seconds will be stopped by the watchdog
|
||||
threshold_sec: 7200
|
||||
|
||||
|
||||
@@ -41,3 +41,7 @@ def get_deployment_type() -> str:
|
||||
|
||||
def get_default_company():
|
||||
return config.get("apiserver.default_company")
|
||||
|
||||
|
||||
missed_es_upgrade = False
|
||||
es_connection_error = False
|
||||
|
||||
@@ -79,6 +79,10 @@ def get_entries():
|
||||
return _entries
|
||||
|
||||
|
||||
def get_hosts():
|
||||
return [entry.host for entry in get_entries()]
|
||||
|
||||
|
||||
def get_aliases():
|
||||
return [entry.alias for entry in get_entries()]
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ from mongoengine import (
|
||||
DictField,
|
||||
DynamicField,
|
||||
)
|
||||
from mongoengine.fields import key_not_string, key_starts_with_dollar
|
||||
|
||||
NoneType = type(None)
|
||||
|
||||
|
||||
class LengthRangeListField(ListField):
|
||||
@@ -125,17 +128,39 @@ def contains_empty_key(d):
|
||||
return True
|
||||
|
||||
|
||||
class SafeMapField(MapField):
|
||||
class DictValidationMixin:
|
||||
"""
|
||||
DictField validation in MongoEngine requires default alias and permissions to access DB version:
|
||||
https://github.com/MongoEngine/mongoengine/issues/2239
|
||||
This is a stripped down implementation that does not require any of the above and implies Mongo ver 3.6+
|
||||
"""
|
||||
|
||||
def _safe_validate(self: DictField, value):
|
||||
if not isinstance(value, dict):
|
||||
self.error("Only dictionaries may be used in a DictField")
|
||||
|
||||
if key_not_string(value):
|
||||
msg = "Invalid dictionary key - documents must have only string keys"
|
||||
self.error(msg)
|
||||
|
||||
if key_starts_with_dollar(value):
|
||||
self.error(
|
||||
'Invalid dictionary key name - keys may not startswith "$" characters'
|
||||
)
|
||||
super(DictField, self).validate(value)
|
||||
|
||||
|
||||
class SafeMapField(MapField, DictValidationMixin):
|
||||
def validate(self, value):
|
||||
super(SafeMapField, self).validate(value)
|
||||
self._safe_validate(value)
|
||||
|
||||
if contains_empty_key(value):
|
||||
self.error("Empty keys are not allowed in a MapField")
|
||||
|
||||
|
||||
class SafeDictField(DictField):
|
||||
class SafeDictField(DictField, DictValidationMixin):
|
||||
def validate(self, value):
|
||||
super(SafeDictField, self).validate(value)
|
||||
self._safe_validate(value)
|
||||
|
||||
if contains_empty_key(value):
|
||||
self.error("Empty keys are not allowed in a DictField")
|
||||
@@ -146,6 +171,7 @@ class SafeSortedListField(SortedListField):
|
||||
SortedListField that does not raise an error in case items are not comparable
|
||||
(in which case they will be sorted by their string representation)
|
||||
"""
|
||||
|
||||
def to_mongo(self, *args, **kwargs):
|
||||
try:
|
||||
return super(SafeSortedListField, self).to_mongo(*args, **kwargs)
|
||||
@@ -155,7 +181,10 @@ class SafeSortedListField(SortedListField):
|
||||
def _safe_to_mongo(self, value, use_db_field=True, fields=None):
|
||||
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
|
||||
if self._ordering is not None:
|
||||
def key(v): return str(itemgetter(self._ordering)(v))
|
||||
|
||||
def key(v):
|
||||
return str(itemgetter(self._ordering)(v))
|
||||
|
||||
else:
|
||||
key = str
|
||||
return sorted(value, key=key, reverse=self._order_reverse)
|
||||
|
||||
@@ -32,6 +32,8 @@ class Role(object):
|
||||
""" Company user """
|
||||
annotator = "annotator"
|
||||
""" Annotator with limited access"""
|
||||
guest = "guest"
|
||||
""" Guest user. Read Only."""
|
||||
|
||||
@classmethod
|
||||
def get_system_roles(cls) -> set:
|
||||
@@ -43,6 +45,7 @@ class Role(object):
|
||||
|
||||
|
||||
class Credentials(EmbeddedDocument):
|
||||
meta = {"strict": False}
|
||||
key = StringField(required=True)
|
||||
secret = StringField(required=True)
|
||||
last_used = DateTimeField()
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple
|
||||
|
||||
from boltons.iterutils import first
|
||||
from boltons.iterutils import first, bucketize, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apierrors import errors
|
||||
from apierrors.base import BaseError
|
||||
from config import config
|
||||
from database.errors import MakeGetAllQueryError
|
||||
from database.projection import project_dict, ProjectionHelper
|
||||
@@ -34,7 +35,12 @@ class AuthDocument(Document):
|
||||
|
||||
|
||||
class ProperDictMixin(object):
|
||||
def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict:
|
||||
def to_proper_dict(
|
||||
self: Union["ProperDictMixin", Document],
|
||||
strip_private=True,
|
||||
only=None,
|
||||
extra_dict=None,
|
||||
) -> dict:
|
||||
return self.properize_dict(
|
||||
self.to_mongo(use_db_field=False).to_dict(),
|
||||
strip_private=strip_private,
|
||||
@@ -71,6 +77,8 @@ class GetMixin(PropsMixin):
|
||||
}
|
||||
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
|
||||
|
||||
_field_collation_overrides = {}
|
||||
|
||||
class QueryParameterOptions(object):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -91,11 +99,48 @@ class GetMixin(PropsMixin):
|
||||
self.list_fields = list_fields
|
||||
self.pattern_fields = pattern_fields
|
||||
|
||||
class ListFieldBucketHelper:
|
||||
op_prefix = "__$"
|
||||
legacy_exclude_prefix = "-"
|
||||
|
||||
_default = "in"
|
||||
_ops = {"not": "nin"}
|
||||
_next = _default
|
||||
|
||||
def __init__(self, legacy=False):
|
||||
self._legacy = legacy
|
||||
|
||||
def key(self, v):
|
||||
if v is None:
|
||||
self._next = self._default
|
||||
return self._default
|
||||
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
|
||||
self._next = self._default
|
||||
return self._ops["not"]
|
||||
elif v.startswith(self.op_prefix):
|
||||
self._next = self._ops.get(v[len(self.op_prefix) :], self._default)
|
||||
return None
|
||||
|
||||
next_ = self._next
|
||||
self._next = self._default
|
||||
return next_
|
||||
|
||||
def value_transform(self, v):
|
||||
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
|
||||
return v[len(self.legacy_exclude_prefix) :]
|
||||
return v
|
||||
|
||||
get_all_query_options = QueryParameterOptions()
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
cls, company, id, *, _only=None, include_public=False, **kwargs
|
||||
cls: Union["GetMixin", Document],
|
||||
company,
|
||||
id,
|
||||
*,
|
||||
_only=None,
|
||||
include_public=False,
|
||||
**kwargs,
|
||||
) -> "GetMixin":
|
||||
q = cls.objects(
|
||||
cls._prepare_perm_query(company, allow_public=include_public)
|
||||
@@ -162,17 +207,7 @@ class GetMixin(PropsMixin):
|
||||
for field in tuple(opts.list_fields or ()):
|
||||
data = parameters.pop(field, None)
|
||||
if data:
|
||||
if not isinstance(data, (list, tuple)):
|
||||
raise MakeGetAllQueryError("expected list", field)
|
||||
exclude = [t for t in data if t.startswith("-")]
|
||||
include = list(set(data).difference(exclude))
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
if include:
|
||||
dict_query[f"{mongoengine_field}__in"] = include
|
||||
if exclude:
|
||||
dict_query[f"{mongoengine_field}__nin"] = [
|
||||
t[1:] for t in exclude
|
||||
]
|
||||
query &= cls.get_list_field_query(field, data)
|
||||
|
||||
for field in opts.fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
@@ -216,6 +251,47 @@ class GetMixin(PropsMixin):
|
||||
|
||||
return query & RegexQ(**dict_query)
|
||||
|
||||
@classmethod
|
||||
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
||||
"""
|
||||
Get a proper mongoengine Q object that represents an "or" query for the provided values
|
||||
with respect to the given list field, with support for "none of empty" in case a None value
|
||||
is included.
|
||||
|
||||
- Exclusion can be specified by a leading "-" for each value (API versions <2.8)
|
||||
or by a preceding "__$not" value (operator)
|
||||
"""
|
||||
if not isinstance(data, (list, tuple)):
|
||||
raise MakeGetAllQueryError("expected list", field)
|
||||
|
||||
# TODO: backwards compatibility only for older API versions
|
||||
helper = cls.ListFieldBucketHelper(legacy=True)
|
||||
actions = bucketize(
|
||||
data, key=helper.key, value_transform=helper.value_transform
|
||||
)
|
||||
|
||||
allow_empty = None in actions.get("in", {})
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
|
||||
q = RegexQ()
|
||||
for action in filter(None, actions):
|
||||
q &= RegexQ(
|
||||
**{
|
||||
f"{mongoengine_field}__{action}": list(
|
||||
set(filter(None, actions[action]))
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
if not allow_empty:
|
||||
return q
|
||||
|
||||
return (
|
||||
q
|
||||
| Q(**{f"{mongoengine_field}__exists": False})
|
||||
| Q(**{mongoengine_field: []})
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _prepare_perm_query(cls, company, allow_public=False):
|
||||
if allow_public:
|
||||
@@ -272,6 +348,20 @@ class GetMixin(PropsMixin):
|
||||
return []
|
||||
return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
|
||||
|
||||
@classmethod
|
||||
def split_projection(
|
||||
cls, projection: Sequence[str]
|
||||
) -> Tuple[Collection[str], Collection[str]]:
|
||||
"""Return include and exclude lists based on passed projection and class definition"""
|
||||
if projection:
|
||||
include, exclude = partition(
|
||||
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
|
||||
)
|
||||
else:
|
||||
include, exclude = [], []
|
||||
exclude = {x.lstrip(ProjectionHelper.exclusion_prefix) for x in exclude}
|
||||
return include, set(cls.get_exclude_fields()).union(exclude).difference(include)
|
||||
|
||||
@classmethod
|
||||
def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
|
||||
parameters.pop("only_fields", None)
|
||||
@@ -409,7 +499,27 @@ class GetMixin(PropsMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_many_no_company(cls, query, parameters=None, override_projection=None):
|
||||
def get_many_public(
|
||||
cls, query: Q = None, projection: Collection[str] = None,
|
||||
):
|
||||
"""
|
||||
Fetch all public documents matching a provided query.
|
||||
:param query: Optional query object (mongoengine.Q).
|
||||
:param projection: A list of projection fields.
|
||||
:return: A list of documents matching the query.
|
||||
"""
|
||||
q = get_company_or_none_constraint()
|
||||
_query = (q & query) if query else q
|
||||
|
||||
return cls._get_many_no_company(query=_query, override_projection=projection)
|
||||
|
||||
@classmethod
|
||||
def _get_many_no_company(
|
||||
cls: Union["GetMixin", Document],
|
||||
query: Q,
|
||||
parameters=None,
|
||||
override_projection=None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query.
|
||||
This is a company-less version for internal uses. We assume the caller has either added any necessary
|
||||
@@ -429,7 +539,9 @@ class GetMixin(PropsMixin):
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
only = cls.get_projection(parameters, override_projection)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
qs = cls.objects(query)
|
||||
if search_text:
|
||||
@@ -437,13 +549,14 @@ class GetMixin(PropsMixin):
|
||||
if order_by:
|
||||
# add ordering
|
||||
qs = qs.order_by(*order_by)
|
||||
if only:
|
||||
|
||||
if include:
|
||||
# add projection
|
||||
qs = qs.only(*only)
|
||||
else:
|
||||
exclude = set(cls.get_exclude_fields()).difference(only)
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
qs = qs.only(*include)
|
||||
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
|
||||
if page is not None and page_size:
|
||||
# add paging
|
||||
qs = qs.skip(page * page_size).limit(page_size)
|
||||
@@ -460,6 +573,8 @@ class GetMixin(PropsMixin):
|
||||
"""
|
||||
Fetch all documents matching a provided query. For the first order by field
|
||||
the None values are sorted in the end regardless of the sorting order.
|
||||
If the first order field is a user defined parameter (either from execution.parameters,
|
||||
or from last_metrics) then the collation is set that sorts strings in numeric order where possible.
|
||||
This is a company-less version for internal uses. We assume the caller has either added any necessary
|
||||
constraints to the query or that no constraints are required.
|
||||
|
||||
@@ -477,7 +592,9 @@ class GetMixin(PropsMixin):
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
only = cls.get_projection(parameters, override_projection)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
query_sets = [cls.objects(query)]
|
||||
if order_by:
|
||||
@@ -500,20 +617,29 @@ class GetMixin(PropsMixin):
|
||||
query_sets = [cls.objects(non_empty), cls.objects(empty)]
|
||||
|
||||
query_sets = [qs.order_by(*order_by) for qs in query_sets]
|
||||
if order_field:
|
||||
collation_override = first(
|
||||
v
|
||||
for k, v in cls._field_collation_overrides.items()
|
||||
if order_field.startswith(k)
|
||||
)
|
||||
if collation_override:
|
||||
query_sets = [
|
||||
qs.collation(collation=collation_override) for qs in query_sets
|
||||
]
|
||||
|
||||
if search_text:
|
||||
query_sets = [qs.search_text(search_text) for qs in query_sets]
|
||||
|
||||
if only:
|
||||
if include:
|
||||
# add projection
|
||||
query_sets = [qs.only(*only) for qs in query_sets]
|
||||
else:
|
||||
exclude = set(cls.get_exclude_fields())
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
query_sets = [qs.only(*include) for qs in query_sets]
|
||||
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
|
||||
if page is None or not page_size:
|
||||
return [obj.to_proper_dict(only=only) for qs in query_sets for obj in qs]
|
||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||
|
||||
# add paging
|
||||
ret = []
|
||||
@@ -524,7 +650,8 @@ class GetMixin(PropsMixin):
|
||||
start -= qs_size
|
||||
continue
|
||||
ret.extend(
|
||||
obj.to_proper_dict(only=only) for obj in qs.skip(start).limit(page_size)
|
||||
obj.to_proper_dict(only=include)
|
||||
for obj in qs.skip(start).limit(page_size)
|
||||
)
|
||||
if len(ret) >= page_size:
|
||||
break
|
||||
@@ -593,7 +720,13 @@ class UpdateMixin(object):
|
||||
return update_dict
|
||||
|
||||
@classmethod
|
||||
def safe_update(cls, company_id, id, partial_update_dict, injected_update=None):
|
||||
def safe_update(
|
||||
cls: Union["UpdateMixin", Document],
|
||||
company_id,
|
||||
id,
|
||||
partial_update_dict,
|
||||
injected_update=None,
|
||||
):
|
||||
update_dict = cls.get_safe_update_dict(partial_update_dict)
|
||||
if not update_dict:
|
||||
return 0, {}
|
||||
@@ -610,7 +743,10 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
|
||||
@classmethod
|
||||
def aggregate(
|
||||
cls: Document, *pipeline: dict, allow_disk_use=None, **kwargs
|
||||
cls: Union["DbModelMixin", Document],
|
||||
pipeline: Sequence[dict],
|
||||
allow_disk_use=None,
|
||||
**kwargs,
|
||||
) -> CommandCursor:
|
||||
"""
|
||||
Aggregate objects of this document class according to the provided pipeline.
|
||||
@@ -625,7 +761,32 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
if allow_disk_use is not None
|
||||
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
|
||||
)
|
||||
return cls.objects.aggregate(*pipeline, **kwargs)
|
||||
return cls.objects.aggregate(pipeline, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def set_public(
|
||||
cls: Type[Document],
|
||||
company_id: str,
|
||||
ids: Sequence[str],
|
||||
invalid_cls: Type[BaseError],
|
||||
enabled: bool = True,
|
||||
):
|
||||
if enabled:
|
||||
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
|
||||
update = dict(set__company_origin=company_id, unset__company=1)
|
||||
else:
|
||||
items = list(
|
||||
cls.objects(
|
||||
id__in=ids, company__in=(None, ""), company_origin=company_id
|
||||
).only("id")
|
||||
)
|
||||
update = dict(set__company=company_id, unset__company_origin=1)
|
||||
|
||||
if len(items) < len(ids):
|
||||
missing = tuple(set(ids).difference(i.id for i in items))
|
||||
raise invalid_cls(ids=missing)
|
||||
|
||||
return {"updated": cls.objects(id__in=ids).update(**update)}
|
||||
|
||||
|
||||
def validate_id(cls, company, **kwargs):
|
||||
@@ -647,5 +808,5 @@ def validate_id(cls, company, **kwargs):
|
||||
id_to_name.setdefault(obj_id, []).append(name)
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Invalid {} ids".format(cls.__name__.lower()),
|
||||
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
|
||||
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]},
|
||||
)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField
|
||||
from mongoengine import Document, StringField, DateTimeField, BooleanField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeDictField
|
||||
from database.fields import StrippedStringField, SafeDictField, SafeSortedListField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import GetMixin
|
||||
from database.model.model_labels import ModelLabels
|
||||
from database.model.company import Company
|
||||
from database.model.project import Project
|
||||
@@ -12,46 +13,63 @@ from database.model.user import User
|
||||
|
||||
class Model(DbModelMixin, Document):
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
'indexes': [
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
"parent",
|
||||
"project",
|
||||
"task",
|
||||
("company", "framework"),
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
{
|
||||
'name': '%s.model.main_text_index' % Database.backend,
|
||||
'fields': [
|
||||
'$name',
|
||||
'$id',
|
||||
'$comment',
|
||||
'$parent',
|
||||
'$task',
|
||||
'$project',
|
||||
],
|
||||
'default_language': 'english',
|
||||
'weights': {
|
||||
'name': 10,
|
||||
'id': 10,
|
||||
'comment': 10,
|
||||
'parent': 5,
|
||||
'task': 3,
|
||||
'project': 3,
|
||||
}
|
||||
}
|
||||
"name": "%s.model.main_text_index" % Database.backend,
|
||||
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
|
||||
"default_language": "english",
|
||||
"weights": {
|
||||
"name": 10,
|
||||
"id": 10,
|
||||
"comment": 10,
|
||||
"parent": 5,
|
||||
"task": 3,
|
||||
"project": 3,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "comment"),
|
||||
fields=("ready",),
|
||||
list_fields=(
|
||||
"tags",
|
||||
"system_tags",
|
||||
"framework",
|
||||
"uri",
|
||||
"id",
|
||||
"user",
|
||||
"project",
|
||||
"task",
|
||||
"parent",
|
||||
),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(user_set_allowed=True, min_length=3)
|
||||
parent = StringField(reference_field='Model', required=False)
|
||||
parent = StringField(reference_field="Model", required=False)
|
||||
user = StringField(required=True, reference_field=User)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
task = StringField(reference_field=Task)
|
||||
comment = StringField(user_set_allowed=True)
|
||||
tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
uri = StrippedStringField(default='', user_set_allowed=True)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
uri = StrippedStringField(default="", user_set_allowed=True)
|
||||
framework = StringField()
|
||||
design = SafeDictField()
|
||||
labels = ModelLabels()
|
||||
ready = BooleanField(required=True)
|
||||
ui_cache = SafeDictField(default=dict, user_set_allowed=True, exclude_by_default=True)
|
||||
ui_cache = SafeDictField(
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from mongoengine import MapField, IntField
|
||||
from database.fields import NoneType, UnionField, SafeMapField
|
||||
|
||||
|
||||
class ModelLabels(MapField):
|
||||
class ModelLabels(SafeMapField):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs)
|
||||
super(ModelLabels, self).__init__(
|
||||
field=UnionField(types=(int, NoneType)), *args, **kwargs
|
||||
)
|
||||
|
||||
def validate(self, value):
|
||||
super(ModelLabels, self).validate(value)
|
||||
if value and (len(set(value.values())) < len(value)):
|
||||
non_empty_values = list(filter(None, value.values()))
|
||||
if non_empty_values and len(set(non_empty_values)) < len(non_empty_values):
|
||||
self.error("Same label id appears more than once in model labels")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from mongoengine import StringField, DateTimeField, ListField
|
||||
from mongoengine import StringField, DateTimeField, IntField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
from database.model import AttributedDocument
|
||||
from database.model.base import GetMixin
|
||||
|
||||
@@ -17,12 +17,13 @@ class Project(AttributedDocument):
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
("company", "name"),
|
||||
{
|
||||
"name": "%s.project.main_text_index" % Database.backend,
|
||||
"fields": ["$name", "$id", "$description"],
|
||||
"default_language": "english",
|
||||
"weights": {"name": 10, "id": 10, "description": 10},
|
||||
}
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
@@ -35,7 +36,11 @@ class Project(AttributedDocument):
|
||||
)
|
||||
description = StringField(required=True)
|
||||
created = DateTimeField(required=True)
|
||||
tags = ListField(StringField(required=True))
|
||||
system_tags = ListField(StringField(required=True))
|
||||
tags = SafeSortedListField(StringField(required=True))
|
||||
system_tags = SafeSortedListField(StringField(required=True))
|
||||
default_output_destination = StrippedStringField()
|
||||
last_update = DateTimeField()
|
||||
featured = IntField(default=9999)
|
||||
logo_url = StringField()
|
||||
logo_blob = StringField(exclude_by_default=True)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
|
||||
@@ -4,11 +4,10 @@ from mongoengine import (
|
||||
StringField,
|
||||
DateTimeField,
|
||||
EmbeddedDocumentListField,
|
||||
ListField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import ProperDictMixin, GetMixin
|
||||
from database.model.company import Company
|
||||
@@ -41,7 +40,7 @@ class Queue(DbModelMixin, Document):
|
||||
)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
created = DateTimeField(required=True)
|
||||
tags = ListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
system_tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||
last_update = DateTimeField()
|
||||
|
||||
@@ -7,6 +7,10 @@ from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
|
||||
|
||||
class SettingKeys:
|
||||
server__uuid = "server.uuid"
|
||||
|
||||
|
||||
class Settings(DbModelMixin, Document):
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
@@ -47,7 +51,7 @@ class Settings(DbModelMixin, Document):
|
||||
""" Adds a new key/value settings. Fails if key already exists. """
|
||||
key = key.strip(sep)
|
||||
try:
|
||||
res = Settings(key=key, value=value).save(force_insert=True)
|
||||
res = cls(key=key, value=value).save(force_insert=True)
|
||||
return bool(res)
|
||||
except NotUniqueError:
|
||||
return False
|
||||
|
||||
@@ -18,7 +18,7 @@ from database.fields import (
|
||||
SafeSortedListField,
|
||||
)
|
||||
from database.model import AttributedDocument
|
||||
from database.model.base import ProperDictMixin
|
||||
from database.model.base import ProperDictMixin, GetMixin
|
||||
from database.model.model_labels import ModelLabels
|
||||
from database.model.project import Project
|
||||
from database.utils import get_options
|
||||
@@ -49,13 +49,13 @@ class TaskSystemTags(object):
|
||||
development = "development"
|
||||
|
||||
|
||||
class Script(EmbeddedDocument):
|
||||
class Script(EmbeddedDocument, ProperDictMixin):
|
||||
binary = StringField(default="python")
|
||||
repository = StringField(required=True)
|
||||
repository = StringField(default="")
|
||||
tag = StringField()
|
||||
branch = StringField()
|
||||
version_num = StringField()
|
||||
entry_point = StringField(required=True)
|
||||
entry_point = StringField(default="")
|
||||
working_dir = StringField()
|
||||
requirements = SafeDictField()
|
||||
diff = StringField()
|
||||
@@ -84,7 +84,23 @@ class Artifact(EmbeddedDocument):
|
||||
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
|
||||
|
||||
|
||||
class ParamsItem(EmbeddedDocument, ProperDictMixin):
|
||||
section = StringField(required=True)
|
||||
name = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
|
||||
name = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class Execution(EmbeddedDocument, ProperDictMixin):
|
||||
meta = {"strict": strict}
|
||||
test_split = IntField(default=0)
|
||||
parameters = SafeDictField(default=dict)
|
||||
model = StringField(reference_field="Model")
|
||||
@@ -100,9 +116,29 @@ class Execution(EmbeddedDocument, ProperDictMixin):
|
||||
class TaskType(object):
|
||||
training = "training"
|
||||
testing = "testing"
|
||||
inference = "inference"
|
||||
data_processing = "data_processing"
|
||||
application = "application"
|
||||
monitor = "monitor"
|
||||
controller = "controller"
|
||||
optimizer = "optimizer"
|
||||
service = "service"
|
||||
qc = "qc"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
external_task_types = set(get_options(TaskType))
|
||||
|
||||
|
||||
class Task(AttributedDocument):
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {
|
||||
"execution.parameters.": _numeric_locale,
|
||||
"last_metrics.": _numeric_locale,
|
||||
"hyperparams.": _numeric_locale,
|
||||
"configuration.": _numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
@@ -110,6 +146,13 @@ class Task(AttributedDocument):
|
||||
"created",
|
||||
"started",
|
||||
"completed",
|
||||
"parent",
|
||||
"project",
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
("company", "type", "system_tags", "status"),
|
||||
("company", "project", "type", "system_tags", "status"),
|
||||
("status", "last_update"), # for maintenance tasks
|
||||
{
|
||||
"name": "%s.task.main_text_index" % Database.backend,
|
||||
"fields": [
|
||||
@@ -134,6 +177,12 @@ class Task(AttributedDocument):
|
||||
},
|
||||
],
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
|
||||
datetime_fields=("status_changed",),
|
||||
pattern_fields=("name", "comment"),
|
||||
fields=("parent",),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(
|
||||
@@ -152,14 +201,19 @@ class Task(AttributedDocument):
|
||||
published = DateTimeField()
|
||||
parent = StringField()
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
output = EmbeddedDocumentField(Output, default=Output)
|
||||
output: Output = EmbeddedDocumentField(Output, default=Output)
|
||||
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
|
||||
tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
script = EmbeddedDocumentField(Script)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
script: Script = EmbeddedDocumentField(Script, default=Script)
|
||||
last_worker = StringField()
|
||||
last_worker_report = DateTimeField()
|
||||
last_update = DateTimeField()
|
||||
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
duration = IntField() # task duration in seconds
|
||||
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
|
||||
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
|
||||
runtime = SafeDictField(default=dict)
|
||||
|
||||
@@ -2,14 +2,16 @@ from mongoengine import Document, StringField, DynamicField
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import GetMixin
|
||||
from database.model.company import Company
|
||||
|
||||
|
||||
class User(DbModelMixin, Document):
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(list_fields=("id",))
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
|
||||
@@ -45,7 +45,7 @@ def project_dict(data, projection, separator=SEP):
|
||||
)
|
||||
|
||||
dst[path_part] = [
|
||||
copy_path(path_parts[depth + 1:], s, d)
|
||||
copy_path(path_parts[depth + 1 :], s, d)
|
||||
for s, d in zip(src_part, dst[path_part])
|
||||
]
|
||||
|
||||
@@ -96,6 +96,7 @@ class _ProxyManager:
|
||||
|
||||
class ProjectionHelper(object):
|
||||
pool = ThreadPoolExecutor()
|
||||
exclusion_prefix = "-"
|
||||
|
||||
@property
|
||||
def doc_projection(self):
|
||||
@@ -128,20 +129,28 @@ class ProjectionHelper(object):
|
||||
[]
|
||||
) # Projection information for reference fields (used in join queries)
|
||||
for field in projection:
|
||||
field_ = field.lstrip(self.exclusion_prefix)
|
||||
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
|
||||
if not field.startswith(ref_field):
|
||||
if not field_.startswith(ref_field):
|
||||
# Doesn't start with a reference field
|
||||
continue
|
||||
if field == ref_field:
|
||||
if field_ == ref_field:
|
||||
# Field is exactly a reference field. In this case we won't perform any inner projection (for that,
|
||||
# use '<reference field name>.*')
|
||||
continue
|
||||
subfield = field[len(ref_field):]
|
||||
subfield = field_[len(ref_field) :]
|
||||
if not subfield.startswith(SEP):
|
||||
# Starts with something that looks like a reference field, but isn't
|
||||
continue
|
||||
|
||||
ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
|
||||
ref_projection_info.append(
|
||||
(
|
||||
ref_field,
|
||||
ref_field_cls,
|
||||
("" if field_[0] == field[0] else self.exclusion_prefix)
|
||||
+ subfield[1:],
|
||||
)
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Not a reference field, just add to the top-level projection
|
||||
@@ -149,7 +158,7 @@ class ProjectionHelper(object):
|
||||
orig_field = field
|
||||
if field.endswith(".*"):
|
||||
field = field[:-2]
|
||||
if not field:
|
||||
if not field.lstrip(self.exclusion_prefix):
|
||||
raise errors.bad_request.InvalidFields(
|
||||
field=orig_field, object=doc_cls.__name__
|
||||
)
|
||||
@@ -199,7 +208,7 @@ class ProjectionHelper(object):
|
||||
# Make sure this doesn't contain any reference field we'll join anyway
|
||||
# (i.e. in case only_fields=[project, project.name])
|
||||
doc_projection = normalize_cls_projection(
|
||||
doc_cls, doc_projection.difference(ref_projection).union({"id"})
|
||||
doc_cls, doc_projection.difference(ref_projection)
|
||||
)
|
||||
|
||||
# Make sure that in case one or more field is a subfield of another field, we only use the the top-level field.
|
||||
@@ -218,7 +227,10 @@ class ProjectionHelper(object):
|
||||
|
||||
# Make sure we didn't get any invalid projection fields for this class
|
||||
invalid_fields = [
|
||||
f for f in doc_projection if f.split(SEP)[0] not in doc_cls.get_fields()
|
||||
f
|
||||
for f in doc_projection
|
||||
if f.partition(SEP)[0].lstrip(self.exclusion_prefix)
|
||||
not in doc_cls.get_fields()
|
||||
]
|
||||
if invalid_fields:
|
||||
raise errors.bad_request.InvalidFields(
|
||||
@@ -234,6 +246,13 @@ class ProjectionHelper(object):
|
||||
doc_projection.add(field)
|
||||
doc_projection = list(doc_projection)
|
||||
|
||||
# If there are include fields (not only exclude) then add an id field
|
||||
if (
|
||||
not all(p.startswith(self.exclusion_prefix) for p in doc_projection)
|
||||
and "id" not in doc_projection
|
||||
):
|
||||
doc_projection.append("id")
|
||||
|
||||
self._doc_projection = doc_projection
|
||||
self._ref_projection = ref_projection
|
||||
|
||||
@@ -314,6 +333,7 @@ class ProjectionHelper(object):
|
||||
]
|
||||
|
||||
if items:
|
||||
|
||||
def do_projection(item):
|
||||
ref_field_name, data, ids = item
|
||||
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
import copy
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
from mongoengine import Q
|
||||
from mongoengine.queryset.visitor import QueryCompilerVisitor, SimplificationVisitor, QCombination
|
||||
from mongoengine.queryset.visitor import (
|
||||
QueryCompilerVisitor,
|
||||
SimplificationVisitor,
|
||||
QCombination,
|
||||
QNode,
|
||||
)
|
||||
|
||||
|
||||
class RegexWrapper(object):
|
||||
@@ -17,17 +23,16 @@ class RegexWrapper(object):
|
||||
|
||||
|
||||
class RegexMixin(object):
|
||||
|
||||
def to_query(self, document):
|
||||
def to_query(self: Union["RegexMixin", QNode], document):
|
||||
query = self.accept(SimplificationVisitor())
|
||||
query = query.accept(RegexQueryCompilerVisitor(document))
|
||||
return query
|
||||
|
||||
def _combine(self, other, operation):
|
||||
def _combine(self: Union["RegexMixin", QNode], other, operation):
|
||||
"""Combine this node with another node into a QCombination
|
||||
object.
|
||||
"""
|
||||
if getattr(other, 'empty', True):
|
||||
if getattr(other, "empty", True):
|
||||
return self
|
||||
|
||||
if self.empty:
|
||||
|
||||
@@ -95,26 +95,18 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
|
||||
res[field] = None
|
||||
continue
|
||||
if desc:
|
||||
if callable(desc):
|
||||
if issubclass(desc, Document):
|
||||
if not desc.objects(id=value).only("id"):
|
||||
raise ParseCallError(
|
||||
"expecting %s id" % desc.__name__, id=value, field=field
|
||||
)
|
||||
elif callable(desc):
|
||||
try:
|
||||
desc(value)
|
||||
except TypeError:
|
||||
raise ParseCallError(f"expecting {desc.__name__}", field=field)
|
||||
except Exception as ex:
|
||||
raise ParseCallError(str(ex), field=field)
|
||||
else:
|
||||
if issubclass(desc, (list, tuple, dict)) and not isinstance(
|
||||
value, desc
|
||||
):
|
||||
raise ParseCallError(
|
||||
"expecting %s" % desc.__name__, field=field
|
||||
)
|
||||
if issubclass(desc, Document) and not desc.objects(id=value).only(
|
||||
"id"
|
||||
):
|
||||
raise ParseCallError(
|
||||
"expecting %s id" % desc.__name__, id=value, field=field
|
||||
)
|
||||
res[field] = value
|
||||
return res
|
||||
|
||||
|
||||
@@ -4,53 +4,54 @@ Apply elasticsearch mappings to given hosts.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
HERE = Path(__file__).resolve().parent
|
||||
|
||||
session = requests.Session()
|
||||
adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5))
|
||||
session.mount('http://', adapter)
|
||||
|
||||
def apply_mappings_to_cluster(
|
||||
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
|
||||
):
|
||||
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
|
||||
|
||||
def apply_mappings_to_host(host: str):
|
||||
def _send_mapping(f):
|
||||
def _send_template(f):
|
||||
with f.open() as json_data:
|
||||
data = json.load(json_data)
|
||||
es_server = host
|
||||
url = f"{es_server}/_template/{f.stem}"
|
||||
|
||||
session.delete(url)
|
||||
r = session.post(
|
||||
url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=json.dumps(data),
|
||||
)
|
||||
return {"mapping": f.stem, "result": r.text}
|
||||
template_name = f.stem
|
||||
res = es.indices.put_template(template_name, body=data)
|
||||
return {"mapping": template_name, "result": res}
|
||||
|
||||
p = HERE / "mappings"
|
||||
return [
|
||||
_send_mapping(f) for f in p.iterdir() if f.is_file() and f.suffix == ".json"
|
||||
]
|
||||
if key:
|
||||
files = (p / key).glob("*.json")
|
||||
else:
|
||||
files = p.glob("**/*.json")
|
||||
|
||||
es = Elasticsearch(hosts=hosts, **(es_args or {}))
|
||||
return [_send_template(f) for f in files]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
parser.add_argument("hosts", nargs="+")
|
||||
parser.add_argument("--key", help="host key, e.g. events, datasets etc.")
|
||||
parser.add_argument(
|
||||
"--hosts",
|
||||
nargs="+",
|
||||
help="list of es hosts from the same cluster, where each host is http[s]://[user:password@]host:port",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
for host in parse_args().hosts:
|
||||
print(">>>>> Applying mapping to " + host)
|
||||
res = apply_mappings_to_host(host)
|
||||
print(res)
|
||||
args = parse_args()
|
||||
print(">>>>> Applying mapping to " + str(args.hosts))
|
||||
res = apply_mappings_to_cluster(args.hosts, args.key)
|
||||
print(res)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from furl import furl
|
||||
from time import sleep
|
||||
|
||||
from elasticsearch import Elasticsearch, exceptions
|
||||
|
||||
import es_factory
|
||||
from config import config
|
||||
from elastic.apply_mappings import apply_mappings_to_host
|
||||
from es_factory import get_cluster_config
|
||||
from elastic.apply_mappings import apply_mappings_to_cluster
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -15,13 +17,48 @@ class MissingElasticConfiguration(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def init_es_data():
|
||||
hosts_config = get_cluster_config("events").get("hosts")
|
||||
if not hosts_config:
|
||||
raise MissingElasticConfiguration("for cluster 'events'")
|
||||
class ElasticConnectionError(Exception):
|
||||
"""
|
||||
Exception when could not connect to elastic during init
|
||||
"""
|
||||
|
||||
for conf in hosts_config:
|
||||
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
|
||||
log.info(f"Applying mappings to host: {host}")
|
||||
res = apply_mappings_to_host(host)
|
||||
pass
|
||||
|
||||
|
||||
def check_elastic_empty() -> bool:
|
||||
"""
|
||||
Check for elasticsearch connection
|
||||
Use probing settings and not the default es cluster ones
|
||||
so that we can handle correctly the connection rejects due to ES not fully started yet
|
||||
:return:
|
||||
"""
|
||||
cluster_conf = es_factory.get_cluster_config("events")
|
||||
max_retries = config.get("apiserver.elastic.probing.max_retries", 4)
|
||||
timeout = config.get("apiserver.elastic.probing.timeout", 30)
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
|
||||
return not es.indices.get_template(name="events*")
|
||||
except exceptions.NotFoundError as ex:
|
||||
log.error(ex)
|
||||
return True
|
||||
except exceptions.ConnectionError:
|
||||
if retry >= max_retries - 1:
|
||||
raise ElasticConnectionError()
|
||||
log.warn(
|
||||
f"Could not connect to es server. Retry {retry+1} of {max_retries}. Waiting for {timeout}sec"
|
||||
)
|
||||
sleep(timeout)
|
||||
|
||||
|
||||
def init_es_data():
|
||||
for name in es_factory.get_all_cluster_names():
|
||||
cluster_conf = es_factory.get_cluster_config(name)
|
||||
hosts_config = cluster_conf.get("hosts")
|
||||
if not hosts_config:
|
||||
raise MissingElasticConfiguration(f"for cluster '{name}'")
|
||||
|
||||
log.info(f"Applying mappings to ES host: {hosts_config}")
|
||||
args = cluster_conf.get("args", {})
|
||||
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
|
||||
log.info(res)
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
{
|
||||
"template": "events-*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"_routing": {
|
||||
"required": true
|
||||
},
|
||||
"properties": {
|
||||
"@timestamp": { "type": "date" },
|
||||
"task": { "type": "keyword" },
|
||||
"type": { "type": "keyword" },
|
||||
"worker": { "type": "keyword" },
|
||||
"timestamp": { "type": "date" },
|
||||
"iter": { "type": "long" },
|
||||
"metric": { "type": "keyword" },
|
||||
"variant": { "type": "keyword" },
|
||||
"value": { "type": "float" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
40
server/elastic/mappings/events/events.json
Normal file
40
server/elastic/mappings/events/events.json
Normal file
@@ -0,0 +1,40 @@
|
||||
{
|
||||
"index_patterns": "events-*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"@timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"task": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"type": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"worker": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"iter": {
|
||||
"type": "long"
|
||||
},
|
||||
"metric": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"variant": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
15
server/elastic/mappings/events/events_log.json
Normal file
15
server/elastic/mappings/events/events_log.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"index_patterns": "events-log-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"msg": {
|
||||
"type": "text",
|
||||
"index": false
|
||||
},
|
||||
"level": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
12
server/elastic/mappings/events/events_plot.json
Normal file
12
server/elastic/mappings/events/events_plot.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"index_patterns": "events-plot-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"plot_str": {
|
||||
"type": "text",
|
||||
"index": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"index_patterns": "events-training_debug_image-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"url": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
{
|
||||
"template": "events-log-*",
|
||||
"order" : 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"msg": { "type":"text", "index": false },
|
||||
"level": { "type":"keyword" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
{
|
||||
"template": "events-plot-*",
|
||||
"order" : 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"plot_str": { "type":"text", "index": false }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
{
|
||||
"template": "events-training_debug_image-*",
|
||||
"order" : 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"key": { "type": "keyword" },
|
||||
"url": { "type": "keyword" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
{
|
||||
"template": "queue_metrics_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"metrics": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"queue": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"average_waiting_time": {
|
||||
"type": "float"
|
||||
},
|
||||
"queue_length": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
{
|
||||
"template": "worker_stats_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"stat": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": { "type": "date" },
|
||||
"worker": { "type": "keyword" },
|
||||
"category": { "type": "keyword" },
|
||||
"metric": { "type": "keyword" },
|
||||
"variant": { "type": "keyword" },
|
||||
"value": { "type": "float" },
|
||||
"unit": { "type": "keyword" },
|
||||
"task": { "type": "keyword" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
25
server/elastic/mappings/workers/queue_metrics.json
Normal file
25
server/elastic/mappings/workers/queue_metrics.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"index_patterns": "queue_metrics_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"queue": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"average_waiting_time": {
|
||||
"type": "float"
|
||||
},
|
||||
"queue_length": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
37
server/elastic/mappings/workers/worker_stats.json
Normal file
37
server/elastic/mappings/workers/worker_stats.json
Normal file
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"index_patterns": "worker_stats_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"worker": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"category": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"metric": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"variant": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
},
|
||||
"unit": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"task": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -65,6 +65,10 @@ def connect(cluster_name):
|
||||
return _instances[cluster_name]
|
||||
|
||||
|
||||
def get_all_cluster_names():
|
||||
return list(config.get("hosts.elastic"))
|
||||
|
||||
|
||||
def get_cluster_config(cluster_name):
|
||||
"""
|
||||
Returns cluster config for the specified cluster path
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from pathlib import Path
|
||||
from typing import Sequence, Union
|
||||
|
||||
from config import config
|
||||
from config.info import get_default_company
|
||||
from database.model.auth import Role
|
||||
from service_repo.auth.fixed_user import FixedUser
|
||||
from .migration import _apply_migrations
|
||||
from .migration import _apply_migrations, check_mongo_empty, get_last_server_version
|
||||
from .pre_populate import PrePopulate
|
||||
from .user import ensure_fixed_user, _ensure_auth_user, _ensure_backend_user
|
||||
from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
|
||||
@@ -11,59 +13,76 @@ from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
|
||||
log = config.logger(__package__)
|
||||
|
||||
|
||||
def _pre_populate(company_id: str, zip_file: str):
|
||||
if not zip_file or not Path(zip_file).is_file():
|
||||
msg = f"Invalid pre-populate zip file: {zip_file}"
|
||||
if config.get("apiserver.pre_populate.fail_on_error", False):
|
||||
log.error(msg)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
log.warning(msg)
|
||||
else:
|
||||
log.info(f"Pre-populating using {zip_file}")
|
||||
|
||||
PrePopulate.import_from_zip(
|
||||
zip_file,
|
||||
artifacts_path=config.get("apiserver.pre_populate.artifacts_path", None),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_zip_files(zip_files: Union[Sequence[str], str]) -> Sequence[str]:
|
||||
if isinstance(zip_files, str):
|
||||
zip_files = [zip_files]
|
||||
for p in map(Path, zip_files):
|
||||
if p.is_file():
|
||||
yield p
|
||||
if p.is_dir():
|
||||
yield from p.glob("*.zip")
|
||||
log.warning(f"Invalid pre-populate entry {str(p)}, skipping")
|
||||
|
||||
|
||||
def pre_populate_data():
|
||||
for zip_file in _resolve_zip_files(config.get("apiserver.pre_populate.zip_files")):
|
||||
_pre_populate(company_id=get_default_company(), zip_file=zip_file)
|
||||
|
||||
PrePopulate.update_featured_projects_order()
|
||||
|
||||
|
||||
def init_mongo_data():
|
||||
try:
|
||||
empty_dbs = _apply_migrations(log)
|
||||
_apply_migrations(log)
|
||||
|
||||
_ensure_uuid()
|
||||
|
||||
company_id = _ensure_company(log)
|
||||
company_id = _ensure_company(get_default_company(), "trains", log)
|
||||
|
||||
_ensure_default_queue(company_id)
|
||||
|
||||
if empty_dbs and config.get("apiserver.mongo.pre_populate.enabled", False):
|
||||
zip_file = config.get("apiserver.mongo.pre_populate.zip_file")
|
||||
if not zip_file or not Path(zip_file).is_file():
|
||||
msg = f"Failed pre-populating database: invalid zip file {zip_file}"
|
||||
if config.get("apiserver.mongo.pre_populate.fail_on_error", False):
|
||||
log.error(msg)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
log.warning(msg)
|
||||
else:
|
||||
fixed_mode = FixedUser.enabled()
|
||||
|
||||
user_id = _ensure_backend_user(
|
||||
"__allegroai__", company_id, "Allegro.ai"
|
||||
)
|
||||
for user, credentials in config.get("secure.credentials", {}).items():
|
||||
user_data = {
|
||||
"name": user,
|
||||
"role": credentials.role,
|
||||
"email": f"{user}@example.com",
|
||||
"key": credentials.user_key,
|
||||
"secret": credentials.user_secret,
|
||||
}
|
||||
revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False)
|
||||
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke)
|
||||
if credentials.role == Role.user:
|
||||
_ensure_backend_user(user_id, company_id, credentials.display_name)
|
||||
|
||||
PrePopulate.import_from_zip(zip_file, user_id=user_id)
|
||||
|
||||
users = [
|
||||
{
|
||||
"name": "apiserver",
|
||||
"role": Role.system,
|
||||
"email": "apiserver@example.com",
|
||||
},
|
||||
{
|
||||
"name": "webserver",
|
||||
"role": Role.system,
|
||||
"email": "webserver@example.com",
|
||||
},
|
||||
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
|
||||
]
|
||||
|
||||
for user in users:
|
||||
credentials = config.get(f"secure.credentials.{user['name']}")
|
||||
user["key"] = credentials.user_key
|
||||
user["secret"] = credentials.user_secret
|
||||
_ensure_auth_user(user, company_id, log=log)
|
||||
|
||||
if FixedUser.enabled():
|
||||
if fixed_mode:
|
||||
log.info("Fixed users mode is enabled")
|
||||
FixedUser.validate()
|
||||
|
||||
if FixedUser.guest_enabled():
|
||||
_ensure_company(FixedUser.get_guest_user().company, "guests", log)
|
||||
|
||||
for user in FixedUser.from_config():
|
||||
try:
|
||||
ensure_fixed_user(user, company_id, log=log)
|
||||
ensure_fixed_user(user, log=log)
|
||||
except Exception as ex:
|
||||
log.error(f"Failed creating fixed user {user.name}: {ex}")
|
||||
except Exception as ex:
|
||||
|
||||
@@ -13,7 +13,26 @@ from database.model.version import Version as DatabaseVersion
|
||||
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
|
||||
|
||||
|
||||
def _apply_migrations(log: Logger) -> bool:
|
||||
def check_mongo_empty() -> bool:
|
||||
return not all(
|
||||
get_db(alias).collection_names()
|
||||
for alias in database.utils.get_options(Database)
|
||||
)
|
||||
|
||||
|
||||
def get_last_server_version() -> Version:
|
||||
try:
|
||||
previous_versions = sorted(
|
||||
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
|
||||
reverse=True,
|
||||
)
|
||||
except ValueError as ex:
|
||||
raise ValueError(f"Invalid database version number encountered: {ex}")
|
||||
|
||||
return previous_versions[0] if previous_versions else Version("0.0.0")
|
||||
|
||||
|
||||
def _apply_migrations(log: Logger):
|
||||
"""
|
||||
Apply migrations as found in the migration dir.
|
||||
Returns a boolean indicating whether the database was empty prior to migration.
|
||||
@@ -25,20 +44,8 @@ def _apply_migrations(log: Logger) -> bool:
|
||||
if not migration_dir.is_dir():
|
||||
raise ValueError(f"Invalid migration dir {migration_dir}")
|
||||
|
||||
empty_dbs = not any(
|
||||
get_db(alias).collection_names()
|
||||
for alias in database.utils.get_options(Database)
|
||||
)
|
||||
|
||||
try:
|
||||
previous_versions = sorted(
|
||||
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
|
||||
reverse=True,
|
||||
)
|
||||
except ValueError as ex:
|
||||
raise ValueError(f"Invalid database version number encountered: {ex}")
|
||||
|
||||
last_version = previous_versions[0] if previous_versions else Version("0.0.0")
|
||||
empty_dbs = check_mongo_empty()
|
||||
last_version = get_last_server_version()
|
||||
|
||||
try:
|
||||
new_scripts = {
|
||||
@@ -82,5 +89,3 @@ def _apply_migrations(log: Logger) -> bool:
|
||||
).save()
|
||||
|
||||
log.info("Finished mongodb migrations")
|
||||
|
||||
return empty_dbs
|
||||
|
||||
@@ -1,31 +1,384 @@
|
||||
import hashlib
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from os.path import splitext
|
||||
from typing import List, Optional, Any, Type, Set, Dict
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Optional,
|
||||
Any,
|
||||
Type,
|
||||
Set,
|
||||
Dict,
|
||||
Sequence,
|
||||
Tuple,
|
||||
BinaryIO,
|
||||
Union,
|
||||
Mapping,
|
||||
)
|
||||
from urllib.parse import unquote, urlparse
|
||||
from zipfile import ZipFile, ZIP_BZIP2
|
||||
|
||||
import dpath
|
||||
import mongoengine
|
||||
from tqdm import tqdm
|
||||
from boltons.iterutils import chunked_iter
|
||||
from furl import furl
|
||||
from mongoengine import Q
|
||||
|
||||
from bll.event import EventBLL
|
||||
from bll.task.param_utils import (
|
||||
split_param_name,
|
||||
hyperparams_default_section,
|
||||
hyperparams_legacy_type,
|
||||
)
|
||||
from config import config
|
||||
from config.info import get_default_company
|
||||
from database.model import EntityVisibility
|
||||
from database.model.model import Model
|
||||
from database.model.project import Project
|
||||
from database.model.task.task import Task, ArtifactModes, TaskStatus
|
||||
from database.utils import get_options
|
||||
from tools import safe_get
|
||||
from utilities import json
|
||||
from .user import _ensure_backend_user
|
||||
|
||||
|
||||
class PrePopulate:
|
||||
@classmethod
|
||||
def export_to_zip(
|
||||
cls, filename: str, experiments: List[str] = None, projects: List[str] = None
|
||||
):
|
||||
with ZipFile(filename, mode="w", compression=ZIP_BZIP2) as zfile:
|
||||
cls._export(zfile, experiments, projects)
|
||||
event_bll = EventBLL()
|
||||
events_file_suffix = "_events"
|
||||
export_tag_prefix = "Exported:"
|
||||
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
|
||||
metadata_filename = "metadata.json"
|
||||
zip_args = dict(mode="w", compression=ZIP_BZIP2)
|
||||
artifacts_ext = ".artifacts"
|
||||
|
||||
class JsonLinesWriter:
|
||||
def __init__(self, file: BinaryIO):
|
||||
self.file = file
|
||||
self.empty = True
|
||||
|
||||
def __enter__(self):
|
||||
self._write("[")
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
self._write("\n]")
|
||||
|
||||
def _write(self, data: str):
|
||||
self.file.write(data.encode("utf-8"))
|
||||
|
||||
def write(self, line: str):
|
||||
if not self.empty:
|
||||
self._write(",")
|
||||
self._write("\n" + line)
|
||||
self.empty = False
|
||||
|
||||
@staticmethod
|
||||
def _get_last_update_time(entity) -> datetime:
|
||||
return getattr(entity, "last_update", None) or getattr(entity, "created")
|
||||
|
||||
@classmethod
|
||||
def import_from_zip(cls, filename: str, user_id: str = None):
|
||||
def _check_for_update(
|
||||
cls, map_file: Path, entities: dict, metadata_hash: str
|
||||
) -> Tuple[bool, Sequence[str]]:
|
||||
if not map_file.is_file():
|
||||
return True, []
|
||||
|
||||
files = []
|
||||
try:
|
||||
map_data = json.loads(map_file.read_text())
|
||||
files = map_data.get("files", [])
|
||||
for file in files:
|
||||
if not Path(file).is_file():
|
||||
return True, files
|
||||
|
||||
new_times = {
|
||||
item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc)
|
||||
for item in chain.from_iterable(entities.values())
|
||||
}
|
||||
old_times = map_data.get("entities", {})
|
||||
|
||||
if set(new_times.keys()) != set(old_times.keys()):
|
||||
return True, files
|
||||
|
||||
for id_, new_timestamp in new_times.items():
|
||||
if new_timestamp != old_times[id_]:
|
||||
return True, files
|
||||
|
||||
if metadata_hash != map_data.get("metadata_hash", ""):
|
||||
return True, files
|
||||
|
||||
except Exception as ex:
|
||||
print("Error reading map file. " + str(ex))
|
||||
return True, files
|
||||
|
||||
return False, files
|
||||
|
||||
@classmethod
|
||||
def _write_update_file(
|
||||
cls,
|
||||
map_file: Path,
|
||||
entities: dict,
|
||||
created_files: Sequence[str],
|
||||
metadata_hash: str,
|
||||
):
|
||||
map_file.write_text(
|
||||
json.dumps(
|
||||
dict(
|
||||
files=created_files,
|
||||
entities={
|
||||
entity.id: cls._get_last_update_time(entity)
|
||||
for entity in chain.from_iterable(entities.values())
|
||||
},
|
||||
metadata_hash=metadata_hash,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]:
|
||||
def is_fileserver_link(a: str) -> bool:
|
||||
a = a.lower()
|
||||
if a.startswith("https://files."):
|
||||
return True
|
||||
if a.startswith("http"):
|
||||
parsed = urlparse(a)
|
||||
if parsed.scheme in {"http", "https"} and parsed.netloc.endswith(
|
||||
"8081"
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
fileserver_links = [a for a in artifacts if is_fileserver_link(a)]
|
||||
print(
|
||||
f"Found {len(fileserver_links)} files on the fileserver from {len(artifacts)} total"
|
||||
)
|
||||
|
||||
return fileserver_links
|
||||
|
||||
@classmethod
|
||||
def export_to_zip(
|
||||
cls,
|
||||
filename: str,
|
||||
experiments: Sequence[str] = None,
|
||||
projects: Sequence[str] = None,
|
||||
artifacts_path: str = None,
|
||||
task_statuses: Sequence[str] = None,
|
||||
tag_exported_entities: bool = False,
|
||||
metadata: Mapping[str, Any] = None,
|
||||
) -> Sequence[str]:
|
||||
if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
|
||||
raise ValueError("Invalid task statuses")
|
||||
|
||||
file = Path(filename)
|
||||
entities = cls._resolve_entities(
|
||||
experiments=experiments, projects=projects, task_statuses=task_statuses
|
||||
)
|
||||
|
||||
hash_ = hashlib.md5()
|
||||
if metadata:
|
||||
meta_str = json.dumps(metadata)
|
||||
hash_.update(meta_str.encode())
|
||||
metadata_hash = hash_.hexdigest()
|
||||
else:
|
||||
meta_str, metadata_hash = "", ""
|
||||
|
||||
map_file = file.with_suffix(".map")
|
||||
updated, old_files = cls._check_for_update(
|
||||
map_file, entities=entities, metadata_hash=metadata_hash
|
||||
)
|
||||
if not updated:
|
||||
print(f"There are no updates from the last export")
|
||||
return old_files
|
||||
|
||||
for old in old_files:
|
||||
old_path = Path(old)
|
||||
if old_path.is_file():
|
||||
old_path.unlink()
|
||||
|
||||
with ZipFile(file, **cls.zip_args) as zfile:
|
||||
if metadata:
|
||||
zfile.writestr(cls.metadata_filename, meta_str)
|
||||
artifacts = cls._export(
|
||||
zfile,
|
||||
entities=entities,
|
||||
hash_=hash_,
|
||||
tag_entities=tag_exported_entities,
|
||||
)
|
||||
|
||||
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
|
||||
file.replace(file_with_hash)
|
||||
created_files = [str(file_with_hash)]
|
||||
|
||||
artifacts = cls._filter_artifacts(artifacts)
|
||||
if artifacts and artifacts_path and os.path.isdir(artifacts_path):
|
||||
artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
|
||||
with ZipFile(artifacts_file, **cls.zip_args) as zfile:
|
||||
cls._export_artifacts(zfile, artifacts, artifacts_path)
|
||||
created_files.append(str(artifacts_file))
|
||||
|
||||
cls._write_update_file(
|
||||
map_file,
|
||||
entities=entities,
|
||||
created_files=created_files,
|
||||
metadata_hash=metadata_hash,
|
||||
)
|
||||
|
||||
return created_files
|
||||
|
||||
@classmethod
|
||||
def import_from_zip(
|
||||
cls,
|
||||
filename: str,
|
||||
artifacts_path: str,
|
||||
company_id: Optional[str] = None,
|
||||
user_id: str = "",
|
||||
user_name: str = "",
|
||||
):
|
||||
metadata = None
|
||||
|
||||
with ZipFile(filename) as zfile:
|
||||
cls._import(zfile, user_id)
|
||||
try:
|
||||
with zfile.open(cls.metadata_filename) as f:
|
||||
metadata = json.loads(f.read())
|
||||
|
||||
meta_public = metadata.get("public")
|
||||
if company_id is None and meta_public is not None:
|
||||
company_id = "" if meta_public else get_default_company()
|
||||
|
||||
if not user_id:
|
||||
meta_user_id = metadata.get("user_id", "")
|
||||
meta_user_name = metadata.get("user_name", "")
|
||||
user_id, user_name = meta_user_id, meta_user_name
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not user_id:
|
||||
user_id, user_name = "__allegroai__", "Allegro.ai"
|
||||
|
||||
# Make sure we won't end up with an invalid company ID
|
||||
if company_id is None:
|
||||
company_id = ""
|
||||
|
||||
# Always use a public user for pre-populated data
|
||||
user_id = _ensure_backend_user(
|
||||
user_id=user_id, user_name=user_name, company_id="",
|
||||
)
|
||||
|
||||
cls._import(zfile, company_id, user_id, metadata)
|
||||
|
||||
if artifacts_path and os.path.isdir(artifacts_path):
|
||||
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
|
||||
if artifacts_file.is_file():
|
||||
print(f"Unzipping artifacts into {artifacts_path}")
|
||||
with ZipFile(artifacts_file) as zfile:
|
||||
zfile.extractall(artifacts_path)
|
||||
|
||||
@classmethod
|
||||
def upgrade_zip(cls, filename) -> Sequence:
|
||||
hash_ = hashlib.md5()
|
||||
task_file = cls._get_base_filename(Task) + ".json"
|
||||
temp_file = Path("temp.zip")
|
||||
file = Path(filename)
|
||||
with ZipFile(file) as reader, ZipFile(temp_file, **cls.zip_args) as writer:
|
||||
for file_info in reader.filelist:
|
||||
if file_info.orig_filename == task_file:
|
||||
with reader.open(file_info) as f:
|
||||
content = cls._upgrade_tasks(f)
|
||||
else:
|
||||
content = reader.read(file_info)
|
||||
writer.writestr(file_info, content)
|
||||
hash_.update(content)
|
||||
|
||||
base_file_name, _, old_hash = file.stem.rpartition("_")
|
||||
new_hash = hash_.hexdigest()
|
||||
if old_hash == new_hash:
|
||||
print(f"The file {filename} was not updated")
|
||||
temp_file.unlink()
|
||||
return []
|
||||
|
||||
new_file = file.with_name(f"{base_file_name}_{new_hash}{file.suffix}")
|
||||
temp_file.replace(new_file)
|
||||
upadated = [str(new_file)]
|
||||
|
||||
artifacts_file = file.with_suffix(cls.artifacts_ext)
|
||||
if artifacts_file.is_file():
|
||||
new_artifacts = new_file.with_suffix(cls.artifacts_ext)
|
||||
artifacts_file.replace(new_artifacts)
|
||||
upadated.append(str(new_artifacts))
|
||||
|
||||
print(f"File {str(file)} replaced with {str(new_file)}")
|
||||
file.unlink()
|
||||
|
||||
return upadated
|
||||
|
||||
@staticmethod
|
||||
def _upgrade_task_data(task_data: dict):
|
||||
for old_param_field, new_param_field, default_section in (
|
||||
("execution/parameters", "hyperparams", hyperparams_default_section),
|
||||
("execution/model_desc", "configuration", None),
|
||||
):
|
||||
legacy = safe_get(task_data, old_param_field)
|
||||
if not legacy:
|
||||
continue
|
||||
for full_name, value in legacy.items():
|
||||
section, name = split_param_name(full_name, default_section)
|
||||
new_path = list(filter(None, (new_param_field, section, name)))
|
||||
if not safe_get(task_data, new_path):
|
||||
new_param = dict(
|
||||
name=name, type=hyperparams_legacy_type, value=str(value)
|
||||
)
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
dpath.new(task_data, new_path, new_param)
|
||||
dpath.delete(task_data, old_param_field)
|
||||
|
||||
@classmethod
|
||||
def _upgrade_tasks(cls, f: BinaryIO) -> bytes:
|
||||
"""
|
||||
Build content array that contains fixed tasks from the passed file
|
||||
For each task the old execution.parameters and model.design are
|
||||
converted to the new structure.
|
||||
The fix is done on Task objects (not the dictionary) so that
|
||||
the fields are serialized back in the same order as they were in the original file
|
||||
"""
|
||||
with BytesIO() as temp:
|
||||
with cls.JsonLinesWriter(temp) as w:
|
||||
for line in cls.json_lines(f):
|
||||
task_data = Task.from_json(line).to_proper_dict()
|
||||
cls._upgrade_task_data(task_data)
|
||||
new_task = Task(**task_data)
|
||||
w.write(new_task.to_json())
|
||||
return temp.getvalue()
|
||||
|
||||
@classmethod
|
||||
def update_featured_projects_order(cls):
|
||||
featured_order = config.get("services.projects.featured_order", [])
|
||||
|
||||
def get_index(p: Project):
|
||||
for index, entry in enumerate(featured_order):
|
||||
if (
|
||||
entry.get("id", None) == p.id
|
||||
or entry.get("name", None) == p.name
|
||||
or ("name_regex" in entry and re.match(entry["name_regex"], p.name))
|
||||
):
|
||||
return index
|
||||
return 999
|
||||
|
||||
for project in Project.get_many_public(projection=["id", "name"]):
|
||||
featured_index = get_index(project)
|
||||
Project.objects(id=project.id).update(featured=featured_index)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_type(
|
||||
cls: Type[mongoengine.Document], ids: Optional[List[str]]
|
||||
) -> List[Any]:
|
||||
cls: Type[mongoengine.Document], ids: Optional[Sequence[str]]
|
||||
) -> Sequence[Any]:
|
||||
ids = set(ids)
|
||||
items = list(cls.objects(id__in=list(ids)))
|
||||
resolved = {i.id for i in items}
|
||||
@@ -43,20 +396,24 @@ class PrePopulate:
|
||||
|
||||
@classmethod
|
||||
def _resolve_entities(
|
||||
cls, experiments: List[str] = None, projects: List[str] = None
|
||||
cls,
|
||||
experiments: Sequence[str] = None,
|
||||
projects: Sequence[str] = None,
|
||||
task_statuses: Sequence[str] = None,
|
||||
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
|
||||
from database.model.project import Project
|
||||
from database.model.task.task import Task
|
||||
|
||||
entities = defaultdict(set)
|
||||
|
||||
if projects:
|
||||
print("Reading projects...")
|
||||
entities[Project].update(cls._resolve_type(Project, projects))
|
||||
print("--> Reading project experiments...")
|
||||
objs = Task.objects(
|
||||
project__in=list(set(filter(None, (p.id for p in entities[Project]))))
|
||||
query = Q(
|
||||
project__in=list(set(filter(None, (p.id for p in entities[Project])))),
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
)
|
||||
if task_statuses:
|
||||
query &= Q(status__in=list(set(task_statuses)))
|
||||
objs = Task.objects(query)
|
||||
entities[Task].update(o for o in objs if o.id not in (experiments or []))
|
||||
|
||||
if experiments:
|
||||
@@ -69,85 +426,289 @@ class PrePopulate:
|
||||
project_ids = {p.id for p in entities[Project]}
|
||||
entities[Project].update(o for o in objs if o.id not in project_ids)
|
||||
|
||||
model_ids = {
|
||||
model_id
|
||||
for task in entities[Task]
|
||||
for model_id in (task.output.model, task.execution.model)
|
||||
if model_id
|
||||
}
|
||||
if model_ids:
|
||||
print("Reading models...")
|
||||
entities[Model] = set(Model.objects(id__in=list(model_ids)))
|
||||
|
||||
return entities
|
||||
|
||||
@classmethod
|
||||
def _cleanup_task(cls, task):
|
||||
from database.model.task.task import TaskStatus
|
||||
def _filter_out_export_tags(cls, tags: Sequence[str]) -> Sequence[str]:
|
||||
if not tags:
|
||||
return tags
|
||||
return [tag for tag in tags if not tag.startswith(cls.export_tag_prefix)]
|
||||
|
||||
task.completed = None
|
||||
task.started = None
|
||||
if task.execution:
|
||||
task.execution.model = None
|
||||
task.execution.model_desc = None
|
||||
task.execution.model_labels = None
|
||||
if task.output:
|
||||
task.output.model = None
|
||||
@classmethod
|
||||
def _cleanup_model(cls, model: Model):
|
||||
model.company = ""
|
||||
model.user = ""
|
||||
model.tags = cls._filter_out_export_tags(model.tags)
|
||||
|
||||
task.status = TaskStatus.created
|
||||
@classmethod
|
||||
def _cleanup_task(cls, task: Task):
|
||||
task.comment = "Auto generated by Allegro.ai"
|
||||
task.created = datetime.utcnow()
|
||||
task.last_iteration = 0
|
||||
task.last_update = task.created
|
||||
task.status_changed = task.created
|
||||
task.status_message = ""
|
||||
task.status_reason = ""
|
||||
task.user = ""
|
||||
task.company = ""
|
||||
task.tags = cls._filter_out_export_tags(task.tags)
|
||||
if task.output:
|
||||
task.output.destination = None
|
||||
|
||||
@classmethod
|
||||
def _cleanup_project(cls, project: Project):
|
||||
project.user = ""
|
||||
project.company = ""
|
||||
project.tags = cls._filter_out_export_tags(project.tags)
|
||||
|
||||
@classmethod
|
||||
def _cleanup_entity(cls, entity_cls, entity):
|
||||
from database.model.task.task import Task
|
||||
if entity_cls == Task:
|
||||
cls._cleanup_task(entity)
|
||||
elif entity_cls == Model:
|
||||
cls._cleanup_model(entity)
|
||||
elif entity == Project:
|
||||
cls._cleanup_project(entity)
|
||||
|
||||
@classmethod
|
||||
def _add_tag(cls, items: Sequence[Union[Project, Task, Model]], tag: str):
|
||||
try:
|
||||
for item in items:
|
||||
item.update(upsert=False, tags=sorted(item.tags + [tag]))
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _export_task_events(
|
||||
cls, task: Task, base_filename: str, writer: ZipFile, hash_
|
||||
) -> Sequence[str]:
|
||||
artifacts = []
|
||||
filename = f"{base_filename}_{task.id}{cls.events_file_suffix}.json"
|
||||
print(f"Writing task events into {writer.filename}:{filename}")
|
||||
with BytesIO() as f:
|
||||
with cls.JsonLinesWriter(f) as w:
|
||||
scroll_id = None
|
||||
while True:
|
||||
res = cls.event_bll.get_task_events(
|
||||
task.company, task.id, scroll_id=scroll_id
|
||||
)
|
||||
if not res.events:
|
||||
break
|
||||
scroll_id = res.next_scroll_id
|
||||
for event in res.events:
|
||||
if event.get("type") == "training_debug_image":
|
||||
url = cls._get_fixed_url(event.get("url"))
|
||||
if url:
|
||||
event["url"] = url
|
||||
artifacts.append(url)
|
||||
w.write(json.dumps(event))
|
||||
data = f.getvalue()
|
||||
hash_.update(data)
|
||||
writer.writestr(filename, data)
|
||||
|
||||
return artifacts
|
||||
|
||||
@staticmethod
|
||||
def _get_fixed_url(url: Optional[str]) -> Optional[str]:
|
||||
if not (url and url.lower().startswith("s3://")):
|
||||
return url
|
||||
try:
|
||||
fixed = furl(url)
|
||||
fixed.scheme = "https"
|
||||
fixed.host += ".s3.amazonaws.com"
|
||||
return fixed.url
|
||||
except Exception as ex:
|
||||
print(f"Failed processing link {url}. " + str(ex))
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def _export_entity_related_data(
|
||||
cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
|
||||
):
|
||||
if entity_cls == Task:
|
||||
return [
|
||||
*cls._get_task_output_artifacts(entity),
|
||||
*cls._export_task_events(entity, base_filename, writer, hash_),
|
||||
]
|
||||
|
||||
if entity_cls == Model:
|
||||
entity.uri = cls._get_fixed_url(entity.uri)
|
||||
return [entity.uri] if entity.uri else []
|
||||
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _get_task_output_artifacts(cls, task: Task) -> Sequence[str]:
|
||||
if not task.execution.artifacts:
|
||||
return []
|
||||
|
||||
for a in task.execution.artifacts:
|
||||
if a.mode == ArtifactModes.output:
|
||||
a.uri = cls._get_fixed_url(a.uri)
|
||||
|
||||
return [
|
||||
a.uri
|
||||
for a in task.execution.artifacts
|
||||
if a.mode == ArtifactModes.output and a.uri
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _export_artifacts(
|
||||
cls, writer: ZipFile, artifacts: Sequence[str], artifacts_path: str
|
||||
):
|
||||
unique_paths = set(unquote(str(furl(artifact).path)) for artifact in artifacts)
|
||||
print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
|
||||
for path in unique_paths:
|
||||
path = path.lstrip("/")
|
||||
full_path = os.path.join(artifacts_path, path)
|
||||
if os.path.isfile(full_path):
|
||||
writer.write(full_path, path)
|
||||
else:
|
||||
print(f"Artifact {full_path} not found")
|
||||
|
||||
@staticmethod
|
||||
def _get_base_filename(cls_: type):
|
||||
return f"{cls_.__module__}.{cls_.__name__}"
|
||||
|
||||
@classmethod
|
||||
def _export(
|
||||
cls, writer: ZipFile, experiments: List[str] = None, projects: List[str] = None
|
||||
):
|
||||
entities = cls._resolve_entities(experiments, projects)
|
||||
|
||||
for cls_, items in entities.items():
|
||||
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
|
||||
) -> Sequence[str]:
|
||||
"""
|
||||
Export the requested experiments, projects and models and return the list of artifact files
|
||||
Always do the export on sorted items since the order of items influence hash
|
||||
"""
|
||||
artifacts = []
|
||||
now = datetime.utcnow()
|
||||
for cls_ in sorted(entities, key=attrgetter("__name__")):
|
||||
items = sorted(entities[cls_], key=attrgetter("id"))
|
||||
if not items:
|
||||
continue
|
||||
filename = f"{cls_.__module__}.{cls_.__name__}.json"
|
||||
base_filename = cls._get_base_filename(cls_)
|
||||
for item in items:
|
||||
artifacts.extend(
|
||||
cls._export_entity_related_data(
|
||||
cls_, item, base_filename, writer, hash_
|
||||
)
|
||||
)
|
||||
filename = base_filename + ".json"
|
||||
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
|
||||
with writer.open(filename, "w") as f:
|
||||
f.write("[\n".encode("utf-8"))
|
||||
last = len(items) - 1
|
||||
for i, item in enumerate(items):
|
||||
cls._cleanup_entity(cls_, item)
|
||||
f.write(item.to_json().encode("utf-8"))
|
||||
if i != last:
|
||||
f.write(",".encode("utf-8"))
|
||||
f.write("\n".encode("utf-8"))
|
||||
f.write("]\n".encode("utf-8"))
|
||||
with BytesIO() as f:
|
||||
with cls.JsonLinesWriter(f) as w:
|
||||
for item in items:
|
||||
cls._cleanup_entity(cls_, item)
|
||||
w.write(item.to_json())
|
||||
data = f.getvalue()
|
||||
hash_.update(data)
|
||||
writer.writestr(filename, data)
|
||||
|
||||
if tag_entities:
|
||||
cls._add_tag(items, now.strftime(cls.export_tag))
|
||||
|
||||
return artifacts
|
||||
|
||||
@staticmethod
|
||||
def _import(reader: ZipFile, user_id: str = None):
|
||||
for file_info in reader.filelist:
|
||||
full_name = splitext(file_info.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
module_name, _, class_name = full_name.rpartition(".")
|
||||
module = importlib.import_module(module_name)
|
||||
cls_: Type[mongoengine.Document] = getattr(module, class_name)
|
||||
def json_lines(file: BinaryIO):
|
||||
for line in file:
|
||||
clean = (
|
||||
line.decode("utf-8")
|
||||
.rstrip("\r\n")
|
||||
.strip()
|
||||
.lstrip("[")
|
||||
.rstrip(",]")
|
||||
.strip()
|
||||
)
|
||||
if not clean:
|
||||
continue
|
||||
yield clean
|
||||
|
||||
with reader.open(file_info) as f:
|
||||
for item in tqdm(
|
||||
f.readlines(),
|
||||
desc=f"Writing {cls_.__name__.lower()}s into database",
|
||||
unit="doc",
|
||||
):
|
||||
item = (
|
||||
item.decode("utf-8")
|
||||
.strip()
|
||||
.lstrip("[")
|
||||
.rstrip("]")
|
||||
.rstrip(",")
|
||||
.strip()
|
||||
)
|
||||
if not item:
|
||||
continue
|
||||
doc = cls_.from_json(item)
|
||||
if user_id is not None and hasattr(doc, "user"):
|
||||
doc.user = user_id
|
||||
doc.save(force_insert=True)
|
||||
@classmethod
|
||||
def _import(
|
||||
cls,
|
||||
reader: ZipFile,
|
||||
company_id: str = "",
|
||||
user_id: str = None,
|
||||
metadata: Mapping[str, Any] = None,
|
||||
):
|
||||
"""
|
||||
Import entities and events from the zip file
|
||||
Start from entities since event import will require the tasks already in DB
|
||||
"""
|
||||
event_file_ending = cls.events_file_suffix + ".json"
|
||||
entity_files = (
|
||||
fi
|
||||
for fi in reader.filelist
|
||||
if not fi.orig_filename.endswith(event_file_ending)
|
||||
and fi.orig_filename != cls.metadata_filename
|
||||
)
|
||||
event_files = (
|
||||
fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending)
|
||||
)
|
||||
for files, reader_func in (
|
||||
(entity_files, partial(cls._import_entity, metadata=metadata or {})),
|
||||
(event_files, cls._import_events),
|
||||
):
|
||||
for file_info in files:
|
||||
with reader.open(file_info) as f:
|
||||
full_name = splitext(file_info.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
reader_func(f, full_name, company_id, user_id)
|
||||
|
||||
@classmethod
|
||||
def _import_entity(
|
||||
cls,
|
||||
f: BinaryIO,
|
||||
full_name: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
metadata: Mapping[str, Any],
|
||||
):
|
||||
module_name, _, class_name = full_name.rpartition(".")
|
||||
module = importlib.import_module(module_name)
|
||||
cls_: Type[mongoengine.Document] = getattr(module, class_name)
|
||||
print(f"Writing {cls_.__name__.lower()}s into database")
|
||||
|
||||
override_project_count = 0
|
||||
for item in cls.json_lines(f):
|
||||
doc = cls_.from_json(item, created=True)
|
||||
if hasattr(doc, "user"):
|
||||
doc.user = user_id
|
||||
if hasattr(doc, "company"):
|
||||
doc.company = company_id
|
||||
if isinstance(doc, Project):
|
||||
override_project_name = metadata.get("project_name", None)
|
||||
if override_project_name:
|
||||
if override_project_count:
|
||||
override_project_name = (
|
||||
f"{override_project_name} {override_project_count + 1}"
|
||||
)
|
||||
override_project_count += 1
|
||||
doc.name = override_project_name
|
||||
|
||||
doc.logo_url = metadata.get("logo_url", None)
|
||||
doc.logo_blob = metadata.get("logo_blob", None)
|
||||
|
||||
cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
|
||||
set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
|
||||
)
|
||||
|
||||
doc.save()
|
||||
|
||||
if isinstance(doc, Task):
|
||||
cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
|
||||
|
||||
@classmethod
|
||||
def _import_events(cls, f: BinaryIO, full_name: str, company_id: str, _):
|
||||
_, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
|
||||
print(f"Writing events for task {task_id} into database")
|
||||
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
|
||||
events = [json.loads(item) for item in events_chunk]
|
||||
cls.event_bll.add_events(
|
||||
company_id, events=events, worker="", allow_locked_tasks=True
|
||||
)
|
||||
|
||||
@@ -9,28 +9,34 @@ from database.model.user import User
|
||||
from service_repo.auth.fixed_user import FixedUser
|
||||
|
||||
|
||||
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger):
|
||||
ensure_credentials = {"key", "secret"}.issubset(user_data)
|
||||
if ensure_credentials:
|
||||
user = AuthUser.objects(
|
||||
credentials__match=Credentials(
|
||||
key=user_data["key"], secret=user_data["secret"]
|
||||
)
|
||||
).first()
|
||||
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: bool = False):
|
||||
key, secret = user_data.get("key"), user_data.get("secret")
|
||||
if not (key and secret):
|
||||
credentials = None
|
||||
else:
|
||||
creds = Credentials(key=key, secret=secret)
|
||||
|
||||
user = AuthUser.objects(credentials__match=creds).first()
|
||||
if user:
|
||||
if revoke:
|
||||
user.credentials = []
|
||||
user.save()
|
||||
return user.id
|
||||
|
||||
credentials = [] if revoke else [creds]
|
||||
|
||||
user_id = user_data.get("id", f"__{user_data['name']}__")
|
||||
|
||||
log.info(f"Creating user: {user_data['name']}")
|
||||
|
||||
user = AuthUser(
|
||||
id=user_data.get("id", f"__{user_data['name']}__"),
|
||||
id=user_id,
|
||||
name=user_data["name"],
|
||||
company=company_id,
|
||||
role=user_data["role"],
|
||||
email=user_data["email"],
|
||||
created=datetime.utcnow(),
|
||||
credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])]
|
||||
if ensure_credentials
|
||||
else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
user.save()
|
||||
@@ -52,23 +58,15 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
|
||||
return user_id
|
||||
|
||||
|
||||
def ensure_fixed_user(user: FixedUser, company_id: str, log: Logger):
|
||||
if User.objects(id=user.user_id).first():
|
||||
def ensure_fixed_user(user: FixedUser, log: Logger):
|
||||
if User.objects(company=user.company, id=user.user_id).first():
|
||||
return
|
||||
|
||||
data = attr.asdict(user)
|
||||
data["id"] = user.user_id
|
||||
data["email"] = f"{user.user_id}@example.com"
|
||||
data["role"] = Role.user
|
||||
data["role"] = Role.guest if user.is_guest else Role.user
|
||||
|
||||
_ensure_auth_user(user_data=data, company_id=company_id, log=log)
|
||||
_ensure_auth_user(user_data=data, company_id=user.company, log=log)
|
||||
|
||||
given_name, _, family_name = user.name.partition(" ")
|
||||
|
||||
User(
|
||||
id=user.user_id,
|
||||
company=company_id,
|
||||
name=user.name,
|
||||
given_name=given_name,
|
||||
family_name=family_name,
|
||||
).save()
|
||||
return _ensure_backend_user(user.user_id, user.company, user.name)
|
||||
|
||||
@@ -3,21 +3,18 @@ from uuid import uuid4
|
||||
|
||||
from bll.queue import QueueBLL
|
||||
from config import config
|
||||
from config.info import get_default_company
|
||||
from database.model.company import Company
|
||||
from database.model.queue import Queue
|
||||
from database.model.settings import Settings
|
||||
from database.model.settings import Settings, SettingKeys
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
def _ensure_company(log: Logger):
|
||||
company_id = get_default_company()
|
||||
def _ensure_company(company_id, company_name, log: Logger):
|
||||
company = Company.objects(id=company_id).only("id").first()
|
||||
if company:
|
||||
return company_id
|
||||
|
||||
company_name = "trains"
|
||||
log.info(f"Creating company: {company_name}")
|
||||
company = Company(id=company_id, name=company_name)
|
||||
company.save()
|
||||
@@ -37,4 +34,4 @@ def _ensure_default_queue(company):
|
||||
|
||||
|
||||
def _ensure_uuid():
|
||||
Settings.add_value("server.uuid", str(uuid4()))
|
||||
Settings.add_value(SettingKeys.server__uuid, str(uuid4()))
|
||||
|
||||
58
server/mongo/migrations/0.15.0.py
Normal file
58
server/mongo/migrations/0.15.0.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from collections import Collection
|
||||
from typing import Sequence
|
||||
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
|
||||
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
|
||||
for collection_name in db.list_collection_names():
|
||||
if collection_name not in names:
|
||||
continue
|
||||
collection: Collection = db[collection_name]
|
||||
collection.drop_indexes()
|
||||
|
||||
|
||||
def migrate_auth(db: Database):
|
||||
"""
|
||||
Remove the old indices from the collections since
|
||||
they may come out of sync with the latest changes
|
||||
in the code and mongo libraries update
|
||||
"""
|
||||
_drop_all_indices_from_collections(db, ["user"])
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
"""
|
||||
1. Sort tags and system tags
|
||||
2. Remove the old indices from the collections since
|
||||
they may come out of sync with the latest changes
|
||||
in the code and mongo libraries update
|
||||
"""
|
||||
|
||||
fields = ("tags", "system_tags")
|
||||
query = {"$or": [{field: {"$exists": True, "$ne": []}} for field in fields]}
|
||||
for collection_name in ("task", "model", "project", "queue"):
|
||||
collection = db[collection_name]
|
||||
for doc in collection.find(filter=query, projection=fields):
|
||||
update = {
|
||||
field: sorted(doc[field])
|
||||
for field in fields
|
||||
if doc.get(field)
|
||||
}
|
||||
if update:
|
||||
collection.update_one({"_id": doc["_id"]}, {"$set": update})
|
||||
|
||||
_drop_all_indices_from_collections(
|
||||
db,
|
||||
[
|
||||
"company",
|
||||
"model",
|
||||
"project",
|
||||
"queue",
|
||||
"settings",
|
||||
"task",
|
||||
"task__trash",
|
||||
"user",
|
||||
"versions",
|
||||
],
|
||||
)
|
||||
36
server/mongo/migrations/0.16.0.py
Normal file
36
server/mongo/migrations/0.16.0.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
from bll.task.param_utils import (
|
||||
hyperparams_legacy_type,
|
||||
hyperparams_default_section,
|
||||
split_param_name,
|
||||
)
|
||||
from tools import safe_get
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
hyperparam_fields = ("execution.parameters", "hyperparams")
|
||||
configuration_fields = ("execution.model_desc", "configuration")
|
||||
collection: Collection = db["task"]
|
||||
for doc in collection.find(projection=hyperparam_fields + configuration_fields):
|
||||
set_commands = {}
|
||||
for (old_field, new_field), default_section in zip(
|
||||
(hyperparam_fields, configuration_fields),
|
||||
(hyperparams_default_section, None),
|
||||
):
|
||||
legacy = safe_get(doc, old_field, separator=".")
|
||||
if not legacy:
|
||||
continue
|
||||
for full_name, value in legacy.items():
|
||||
section, name = split_param_name(full_name, default_section)
|
||||
new_path = list(filter(None, (new_field, section, name)))
|
||||
# if safe_get(doc, new_path) is not None:
|
||||
# continue
|
||||
new_value = dict(
|
||||
name=name, type=hyperparams_legacy_type, value=str(value)
|
||||
)
|
||||
if section is not None:
|
||||
new_value["section"] = section
|
||||
set_commands[".".join(new_path)] = new_value
|
||||
if set_commands:
|
||||
collection.update_one({"_id": doc["_id"]}, {"$set": set_commands})
|
||||
@@ -1,7 +1,8 @@
|
||||
attrs>=19.1.0
|
||||
boltons>=19.1.0
|
||||
boto3==1.14.13
|
||||
dpath>=1.4.2,<2.0
|
||||
elasticsearch>=5.0.0,<6.0.0
|
||||
elasticsearch>=7.0.0,<8.0.0
|
||||
fastjsonschema>=2.8
|
||||
Flask-Compress>=1.4.0
|
||||
Flask-Cors>=3.0.5
|
||||
@@ -14,17 +15,17 @@ Jinja2==2.10
|
||||
jsonmodels>=2.3
|
||||
jsonschema>=2.6.0
|
||||
luqum>=0.7.2
|
||||
mongoengine==0.16.2
|
||||
mongoengine==0.19.1
|
||||
nested_dict>=1.61
|
||||
psutil>=5.6.5
|
||||
pyhocon>=0.3.35
|
||||
pyjwt>=1.3.0
|
||||
pymongo==3.6.1 # 3.7 has a bug multiple users logged in
|
||||
pymongo==3.10.1
|
||||
python-rapidjson>=0.6.3
|
||||
redis>=2.10.5
|
||||
related>=0.7.2
|
||||
requests>=2.13.0
|
||||
semantic_version>=2.8.0,<3
|
||||
semantic_version>=2.8.3,<3
|
||||
six
|
||||
tqdm
|
||||
validators>=0.12.4
|
||||
@@ -328,6 +328,11 @@ fixed_users_mode {
|
||||
description: "Fixed users mode enabled"
|
||||
type: boolean
|
||||
}
|
||||
server_errors {
|
||||
description: "Server initialization errors"
|
||||
type: object
|
||||
additionalProperties: True
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
16
server/schema/services/debug.conf
Normal file
16
server/schema/services/debug.conf
Normal file
@@ -0,0 +1,16 @@
|
||||
_description: "debugging utilities"
|
||||
ping {
|
||||
authorize: false
|
||||
"2.9" {
|
||||
description: "Ping server"
|
||||
request {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties: {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -258,6 +258,7 @@
|
||||
properties {
|
||||
added { type: integer }
|
||||
errors { type: integer }
|
||||
errors_info { type: object }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -362,7 +363,7 @@
|
||||
}
|
||||
navigate_earlier {
|
||||
type: boolean
|
||||
description: "If set then events are retreived from later iterations to earlier ones. Otherwise from earlier iterations to the later. The default is True"
|
||||
description: "If set then events are retreived from latest iterations to earliest ones. Otherwise from earliest iterations to the latest. The default is True"
|
||||
}
|
||||
refresh {
|
||||
type: boolean
|
||||
@@ -529,6 +530,56 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.9" {
|
||||
description: "Get 'log' events for this task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
batch_size {
|
||||
type: integer
|
||||
description: "The amount of log events to return"
|
||||
}
|
||||
navigate_earlier {
|
||||
type: boolean
|
||||
description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order, unless order='asc'). Otherwise from the earliest to the latest ones (in timestamp ascending order, unless order='desc'). The default is True"
|
||||
}
|
||||
from_timestamp {
|
||||
type: number
|
||||
description: "Epoch time in UTC ms to use as the navigation start. Optional. If not provided, reference timestamp is determined by the 'navigate_earlier' parameter (if true, reference timestamp is the last timestamp and if false, reference timestamp is the first timestamp)"
|
||||
}
|
||||
order {
|
||||
type: string
|
||||
description: "If set, changes the order in which log events are returned based on the value of 'navigate_earlier'"
|
||||
enum: [asc, desc]
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
events {
|
||||
type: array
|
||||
items { type: object }
|
||||
description: "Log items list"
|
||||
}
|
||||
returned {
|
||||
type: integer
|
||||
description: "Number of log events returned"
|
||||
}
|
||||
total {
|
||||
type: number
|
||||
description: "Total number of log events available for this query"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_events {
|
||||
"2.1" {
|
||||
@@ -802,7 +853,7 @@
|
||||
description: "Task ID"
|
||||
}
|
||||
samples {
|
||||
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
|
||||
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 6000."
|
||||
type: integer
|
||||
}
|
||||
key {
|
||||
@@ -840,7 +891,7 @@
|
||||
]
|
||||
properties {
|
||||
tasks {
|
||||
description: "List of task Task IDs"
|
||||
description: "List of task Task IDs. Maximum amount of tasks is 10"
|
||||
type: array
|
||||
items {
|
||||
type: string
|
||||
@@ -848,7 +899,7 @@
|
||||
}
|
||||
}
|
||||
samples {
|
||||
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
|
||||
description: "The amount of histogram points to return. Optional, the default value is 6000"
|
||||
type: integer
|
||||
}
|
||||
key {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
48
server/schema/services/organization.conf
Normal file
48
server/schema/services/organization.conf
Normal file
@@ -0,0 +1,48 @@
|
||||
_description: "This service provides organization level operations"
|
||||
|
||||
get_tags {
|
||||
"2.8" {
|
||||
description: "Get all the user and system tags used for the company tasks and models"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
include_system {
|
||||
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
filter {
|
||||
description: "Filter on entities to collect tags from"
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__Snot' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of unique tag values"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -196,6 +196,52 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
tags_request {
|
||||
type: object
|
||||
properties {
|
||||
include_system {
|
||||
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
projects {
|
||||
description: "The list of projects under which the tags are searched. If not passed or empty then all the projects are searched"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
filter {
|
||||
description: "Filter on entities to collect tags from"
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__Snot' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tags_response {
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of unique tag values"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
create {
|
||||
@@ -359,6 +405,11 @@ get_all_ex {
|
||||
enum: [ active, archived ]
|
||||
default: active
|
||||
}
|
||||
non_public {
|
||||
description: "Return only non-public projects"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -481,8 +532,8 @@ get_unique_metric_variants {
|
||||
}
|
||||
}
|
||||
get_hyper_parameters {
|
||||
"2.2" {
|
||||
description: """Get a list of all hyper parameter names used in tasks within the given project."""
|
||||
"2.9" {
|
||||
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
@@ -506,9 +557,9 @@ get_hyper_parameters {
|
||||
type: object
|
||||
properties {
|
||||
parameters {
|
||||
description: "A list of hyper parameter names"
|
||||
description: "A list of parameter sections and names"
|
||||
type: array
|
||||
items { type: string }
|
||||
items {type: object}
|
||||
}
|
||||
remaining {
|
||||
description: "Remaining results"
|
||||
@@ -522,3 +573,69 @@ get_hyper_parameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_task_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the tasks under the specified projects"
|
||||
request = ${_definitions.tags_request}
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
|
||||
get_model_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the models under the specified projects"
|
||||
request = ${_definitions.tags_request}
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
|
||||
make_public {
|
||||
"2.9" {
|
||||
description: """Convert company projects to public"""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
description: "Ids of the projects to convert"
|
||||
type: array
|
||||
items { type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of projects updated"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
make_private {
|
||||
"2.9" {
|
||||
description: """Convert public projects to private"""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
description: "Ids of the projects to convert. Only the projects originated by the company can be converted"
|
||||
type: array
|
||||
items { type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of projects updated"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -69,6 +69,17 @@ info {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.8": ${info."2.1"} {
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
uid {
|
||||
description: "Server UID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
endpoints {
|
||||
"2.1" {
|
||||
|
||||
@@ -254,6 +254,15 @@ _definitions {
|
||||
enum: [
|
||||
training
|
||||
testing
|
||||
inference
|
||||
data_processing
|
||||
application
|
||||
monitor
|
||||
controller
|
||||
optimizer
|
||||
service
|
||||
qc
|
||||
custom
|
||||
]
|
||||
}
|
||||
last_metrics_event {
|
||||
@@ -288,7 +297,80 @@ _definitions {
|
||||
"$ref": "#/definitions/last_metrics_event"
|
||||
}
|
||||
}
|
||||
|
||||
params_item {
|
||||
type: object
|
||||
properties {
|
||||
section {
|
||||
description: "Section that the parameter belongs to"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Name of the parameter. The combination of section and name should be unique"
|
||||
type: string
|
||||
}
|
||||
value {
|
||||
description: "Value of the parameter"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "Type of the parameter. Optional"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "The parameter description. Optional"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
configuration_item {
|
||||
type: object
|
||||
properties {
|
||||
name {
|
||||
description: "Name of the parameter. Should be unique"
|
||||
type: string
|
||||
}
|
||||
value {
|
||||
description: "Value of the parameter"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "Type of the parameter. Optional"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "The parameter description. Optional"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
param_key {
|
||||
type: object
|
||||
properties {
|
||||
section {
|
||||
description: "Section that the parameter belongs to"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Name of the parameter. If the name is ommitted then the corresponding operation is performed on the whole section"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
section_params {
|
||||
description: "Task section params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/params_item"
|
||||
}
|
||||
}
|
||||
replace_hyperparams_enum {
|
||||
type: string
|
||||
enum: [
|
||||
none,
|
||||
section,
|
||||
all
|
||||
]
|
||||
}
|
||||
task {
|
||||
type: object
|
||||
properties {
|
||||
@@ -409,9 +491,24 @@ _definitions {
|
||||
"$ref": "#/definitions/last_metrics_variants"
|
||||
}
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_by_id {
|
||||
"2.1" {
|
||||
description: "Gets task information"
|
||||
@@ -475,7 +572,11 @@ get_all {
|
||||
minimum: 1
|
||||
}
|
||||
order_by {
|
||||
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
|
||||
description: """List of field names to order by. When search_text is used,
|
||||
'@text_score' can be used as a field representing the text score of returned documents.
|
||||
Use '-' prefix to specify descending order. Optional, recommended when using page.
|
||||
If the first order field is a hyper parameter or metric then string values are ordered
|
||||
according to numeric ordering rules where applicable"""
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
@@ -550,6 +651,31 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
get_types {
|
||||
"2.8" {
|
||||
description: "Get the list of task types used in the specified projects"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
projects {
|
||||
description: "The list of projects which tasks will be analyzed. If not passed or empty then all the company and public tasks will be analyzed"
|
||||
type: array
|
||||
items: {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
types {
|
||||
description: "Unique list of the task types used in the requested projects"
|
||||
type: array
|
||||
items: {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
clone {
|
||||
"2.5" {
|
||||
description: "Clone an existing task"
|
||||
@@ -587,10 +713,28 @@ clone {
|
||||
description: "The project of the cloned task. If not provided then taken from the original task"
|
||||
type: string
|
||||
}
|
||||
new_task_hyperparams {
|
||||
description: "The hyper params for the new task. If not provided then taken from the original task"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
new_task_configuration {
|
||||
description: "The configuration for the new task. If not provided then taken from the original task"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
execution_overrides {
|
||||
description: "The execution params for the cloned task. The params not specified are taken from the original task"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
validate_references {
|
||||
description: "If set to 'false' then the task fields that are copied from the original task are not validated. The default is false."
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -656,6 +800,20 @@ create {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -717,6 +875,20 @@ validate {
|
||||
description: "Task execution params"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
script {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
@@ -867,6 +1039,20 @@ edit {
|
||||
description: "Task execution params"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
script {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
@@ -901,6 +1087,11 @@ reset {
|
||||
properties.force = ${_references.force_arg} {
|
||||
description: "If not true, call fails if the task status is 'completed'"
|
||||
}
|
||||
properties.clear_all {
|
||||
description: "Clear script and execution sections completely"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
type: object
|
||||
@@ -1394,4 +1585,263 @@ add_or_update_artifacts {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
make_public {
|
||||
"2.9" {
|
||||
description: """Convert company tasks to public"""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
description: "Ids of the tasks to convert"
|
||||
type: array
|
||||
items { type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
make_private {
|
||||
"2.9" {
|
||||
description: """Convert public tasks to private"""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
description: "Ids of the tasks to convert. Only the tasks originated by the company can be converted"
|
||||
type: array
|
||||
items { type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_hyper_params {
|
||||
"2.9": {
|
||||
description: "Get the list of task hyper parameters"
|
||||
request {
|
||||
type: object
|
||||
required: [tasks]
|
||||
properties {
|
||||
tasks {
|
||||
description: "Task IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
params {
|
||||
type: object
|
||||
description: "Hyper parameters (keyed by task ID)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
edit_hyper_params {
|
||||
"2.9" {
|
||||
description: "Add or update task hyper parameters"
|
||||
request {
|
||||
type: object
|
||||
required: [ task, hyperparams ]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper parameters. The new ones will be added and the already existing ones will be updated"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/params_item"}
|
||||
}
|
||||
replace_hyperparams {
|
||||
description: """Can be set to one of the following:
|
||||
'all' - all the hyper parameters will be replaced with the provided ones
|
||||
'section' - the sections that present in the new parameters will be replaced with the provided parameters
|
||||
'none' (the default value) - only the specific parameters will be updated or added"""
|
||||
"$ref": "#/definitions/replace_hyperparams_enum"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Indicates if the task was updated successfully"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_hyper_params {
|
||||
"2.9": {
|
||||
description: "Delete task hyper parameters"
|
||||
request {
|
||||
type: object
|
||||
required: [ task, hyperparams ]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
hyperparams {
|
||||
description: "List of hyper parameters to delete. In case a parameter with an empty name is passed all the section will be deleted"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/param_key" }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
deleted {
|
||||
description: "Indicates if the task was updated successfully"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_configurations {
|
||||
"2.9": {
|
||||
description: "Get the list of task configurations"
|
||||
request {
|
||||
type: object
|
||||
required: [tasks]
|
||||
properties {
|
||||
tasks {
|
||||
description: "Task IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
names {
|
||||
description: "Names of the configuration items to retreive. If not passed or empty then all the configurations will be retreived."
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
configurations {
|
||||
type: object
|
||||
description: "Configurations (keyed by task ID)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_configuration_names {
|
||||
"2.9": {
|
||||
description: "Get the list of task configuration items names"
|
||||
request {
|
||||
type: object
|
||||
required: [tasks]
|
||||
properties {
|
||||
tasks {
|
||||
description: "Task IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
configurations {
|
||||
type: object
|
||||
description: "Names of task configuration items (keyed by task ID)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
edit_configuration {
|
||||
"2.9" {
|
||||
description: "Add or update task configuration"
|
||||
request {
|
||||
type: object
|
||||
required: [ task, configuration ]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration items. The new ones will be added and the already existing ones will be updated"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/configuration_item"}
|
||||
}
|
||||
replace_configuration {
|
||||
description: "If set then the all the configuration items will be replaced with the provided ones. Otherwise only the provided configuration items will be updated or added"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Indicates if the task was updated successfully"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_configuration {
|
||||
"2.9": {
|
||||
description: "Delete task configuration items"
|
||||
request {
|
||||
type: object
|
||||
required: [ task, configuration ]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
configuration {
|
||||
description: "List of configuration itemss to delete"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
deleted {
|
||||
description: "Indicates if the task was updated successfully"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -145,6 +145,19 @@ get_all_ex {
|
||||
internal: true
|
||||
"2.1": ${get_all."2.1"} {
|
||||
}
|
||||
"2.8": ${get_all."2.1"} {
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
active_in_projects {
|
||||
description: "List of project IDs. If provided, return only users that were active in these projects. If empty list is provided, return users that were active in all projects"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
get_all {
|
||||
|
||||
@@ -135,6 +135,10 @@
|
||||
description: "Task currently being run by the worker"
|
||||
"$ref": "#/definitions/current_task_entry"
|
||||
}
|
||||
project {
|
||||
description: "Project in which currently executing task resides"
|
||||
"$ref": "#/definitions/id_name_entry"
|
||||
}
|
||||
queue {
|
||||
description: "Queue from which running task was taken"
|
||||
"$ref": "#/definitions/queue_entry"
|
||||
@@ -144,6 +148,11 @@
|
||||
type: array
|
||||
items { "$ref": "#/definitions/queue_entry" }
|
||||
}
|
||||
tags {
|
||||
description: "User tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,11 +160,11 @@
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Worker ID"
|
||||
description: "ID"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Worker name"
|
||||
description: "Name"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
@@ -301,6 +310,11 @@
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
tags {
|
||||
description: "User tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -363,6 +377,11 @@
|
||||
description: "The machine statistics."
|
||||
"$ref": "#/definitions/machine_stats"
|
||||
}
|
||||
tags {
|
||||
description: "New user tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
import atexit
|
||||
from argparse import ArgumentParser
|
||||
from hashlib import md5
|
||||
|
||||
from flask import Flask, request, Response
|
||||
from flask_compress import Compress
|
||||
from flask_cors import CORS
|
||||
from semantic_version import Version
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
import database
|
||||
from apierrors.base import BaseError
|
||||
from bll.statistics.stats_reporter import StatisticsReporter
|
||||
from config import config
|
||||
from elastic.initialize import init_es_data
|
||||
from mongo.initialize import init_mongo_data
|
||||
from config import config, info
|
||||
from elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
|
||||
from mongo.initialize import (
|
||||
init_mongo_data,
|
||||
pre_populate_data,
|
||||
check_mongo_empty,
|
||||
get_last_server_version,
|
||||
)
|
||||
from service_repo import ServiceRepo, APICall
|
||||
from service_repo.auth import AuthType
|
||||
from service_repo.errors import PathParsingError
|
||||
from sync import distributed_lock
|
||||
from timing_context import TimingContext
|
||||
from updates import check_updates_thread
|
||||
from utilities import json
|
||||
@@ -33,8 +41,42 @@ app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get("apiserver.pretty_json")
|
||||
|
||||
database.initialize()
|
||||
|
||||
init_es_data()
|
||||
init_mongo_data()
|
||||
# build a key that uniquely identifies specific mongo instance
|
||||
hosts_string = ";".join(sorted(database.get_hosts()))
|
||||
key = "db_init_" + md5(hosts_string.encode()).hexdigest()
|
||||
with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 120)):
|
||||
upgrade_monitoring = config.get(
|
||||
"apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
|
||||
)
|
||||
try:
|
||||
empty_es = check_elastic_empty()
|
||||
except ElasticConnectionError as err:
|
||||
if not upgrade_monitoring:
|
||||
raise
|
||||
log.error(err)
|
||||
info.es_connection_error = True
|
||||
|
||||
empty_db = check_mongo_empty()
|
||||
if upgrade_monitoring:
|
||||
if not empty_db and (info.es_connection_error or empty_es):
|
||||
if get_last_server_version() < Version("0.16.0"):
|
||||
log.info(f"ES database seems not migrated")
|
||||
info.missed_es_upgrade = True
|
||||
proceed_with_init = not (info.es_connection_error or info.missed_es_upgrade)
|
||||
else:
|
||||
proceed_with_init = True
|
||||
|
||||
if proceed_with_init:
|
||||
init_es_data()
|
||||
init_mongo_data()
|
||||
|
||||
if (
|
||||
proceed_with_init
|
||||
and empty_db
|
||||
and config.get("apiserver.pre_populate.enabled", False)
|
||||
):
|
||||
pre_populate_data()
|
||||
|
||||
|
||||
ServiceRepo.load("services")
|
||||
log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
|
||||
|
||||
@@ -8,7 +8,7 @@ from .endpoint import EndpointFunc, Endpoint
|
||||
from .service_repo import ServiceRepo
|
||||
|
||||
|
||||
__all__ = ["endpoint"]
|
||||
__all__ = ["APICall", "endpoint"]
|
||||
|
||||
|
||||
LegacyEndpointFunc = Callable[[APICall], None]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user