mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
29 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b93591ec32 | ||
|
|
0abfd8da0d | ||
|
|
a9cc4e36c6 | ||
|
|
fe1c963eec | ||
|
|
111d80e88d | ||
|
|
6718862dbe | ||
|
|
0fe1bf8a61 | ||
|
|
10f326eda9 | ||
|
|
cd0d6c1a3d | ||
|
|
3205f2df97 | ||
|
|
5bdbcfcd8d | ||
|
|
a2e2052b30 | ||
|
|
0146ded4f4 | ||
|
|
dccf9dd8f8 | ||
|
|
7816b402bb | ||
|
|
cd4ce30f7c | ||
|
|
8c7e230898 | ||
|
|
42ba696518 | ||
|
|
3f84e60a1f | ||
|
|
baba8b5b73 | ||
|
|
77397c4f21 | ||
|
|
8678091d8f | ||
|
|
aa22170ab4 | ||
|
|
901ec37290 | ||
|
|
21f2ea8b17 | ||
|
|
8219e3d4e2 | ||
|
|
3ed71a61d5 | ||
|
|
18a88a8e8f | ||
|
|
318a72987c |
28
README.md
28
README.md
@@ -11,6 +11,12 @@
|
||||
|
||||
## :rocket: Trains-Agent Services is now included, for more information see [services](https://github.com/allegroai/trains-server#services)
|
||||
|
||||
## v0.16 Upgrade Notice
|
||||
|
||||
In v0.16, the Elasticsearch subsystem of Trains Server has been upgraded from version 5.6 to version 7.6. This change necessitates the migration of the database contents to accommodate the change in index structure across the different versions.
|
||||
|
||||
Follow [this procedure](https://allegro.ai/docs/deploying_trains/trains_server_es7_migration/) to migrate existing data.
|
||||
|
||||
## Introduction
|
||||
|
||||
The **trains-server** is the backend service infrastructure for [Trains](https://github.com/allegroai/trains).
|
||||
@@ -64,15 +70,15 @@ For example, to see if port `8080` is in use:
|
||||
|
||||
Launch **trains-server** in any of the following formats:
|
||||
|
||||
- Pre-built [AWS EC2 AMI](https://github.com/allegroai/trains-server/blob/master/docs/install_aws.md)
|
||||
- Pre-built [GCP Custom Image](https://github.com/allegroai/trains-server/blob/master/docs/install_gcp.md)
|
||||
- Pre-built [AWS EC2 AMI](https://allegro.ai/docs/deploying_trains/trains_server_aws_ec2_ami/)
|
||||
- Pre-built [GCP Custom Image](https://allegro.ai/docs/deploying_trains/trains_server_gcp/)
|
||||
- Pre-built Docker Image
|
||||
- [Linux](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
|
||||
- [macOS](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
|
||||
- [Windows 10](https://github.com/allegroai/trains-server/blob/master/docs/install_win.md)
|
||||
- [Linux](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
|
||||
- [macOS](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
|
||||
- [Windows 10](https://allegro.ai/docs/deploying_trains/trains_server_win/)
|
||||
- Kubernetes
|
||||
- [Kubernetes Helm](https://github.com/allegroai/trains-server-helm#prerequisites)
|
||||
- Manual [Kubernetes installation](https://github.com/allegroai/trains-server-k8s#prerequisites)
|
||||
- [Kubernetes Helm](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes_helm/)
|
||||
- Manual [Kubernetes installation](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes/)
|
||||
|
||||
## Connecting Trains to your trains-server
|
||||
|
||||
@@ -124,8 +130,8 @@ Do not enqueue training / inference tasks into the `services` queue, as it will
|
||||
|
||||
**trains-server** provides a few additional useful features, which can be manually enabled:
|
||||
|
||||
* [Web login authentication](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#web-auth)
|
||||
* [Non-responsive experiments watchdog](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#watchdog-the-non-responsive-task-watchdog-settings)
|
||||
* [Web login authentication](https://allegro.ai/docs/faq/faq/#web-auth)
|
||||
* [Non-responsive experiments watchdog](https://allegro.ai/docs/faq/faq/#watchdog)
|
||||
|
||||
## Restarting trains-server
|
||||
|
||||
@@ -191,12 +197,12 @@ To upgrade your existing **trains-server** deployment:
|
||||
docker-compose -f docker-compose.yml up
|
||||
```
|
||||
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#common-docker-upgrade-errors).**
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/docs/faq/faq/#common-docker-upgrade-errors).**
|
||||
|
||||
|
||||
## Community & Support
|
||||
|
||||
If you have any questions, look to the Trains server [FAQ](https://github.com/allegroai/trains-server/blob/master/docs/faq.md), or
|
||||
If you have any questions, look to the Trains [FAQ](https://allegro.ai/docs/faq/faq/), or
|
||||
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/trains) with '**trains**' tag.
|
||||
|
||||
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/trains-server/issues).
|
||||
|
||||
@@ -40,15 +40,11 @@ services:
|
||||
cluster.name: trains
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
discovery.zen.minimum_master_nodes: "1"
|
||||
discovery.type: "single-node"
|
||||
http.compression_level: "7"
|
||||
node.ingest: "true"
|
||||
node.name: trains
|
||||
reindex.remote.whitelist: '*.*'
|
||||
script.inline: "true"
|
||||
script.painless.regex.enabled: "true"
|
||||
script.update: "true"
|
||||
thread_pool.bulk.queue_size: "2000"
|
||||
thread_pool.search.queue_size: "10000"
|
||||
xpack.monitoring.enabled: "false"
|
||||
xpack.security.enabled: "false"
|
||||
ulimits:
|
||||
@@ -58,10 +54,10 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/trains/data/elastic:/usr/share/elasticsearch/data
|
||||
- /opt/trains/data/elastic_7:/usr/share/elasticsearch/data
|
||||
ports:
|
||||
- "9200:9200"
|
||||
mongo:
|
||||
|
||||
@@ -40,15 +40,11 @@ services:
|
||||
cluster.name: trains
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
discovery.zen.minimum_master_nodes: "1"
|
||||
discovery.type: "single-node"
|
||||
http.compression_level: "7"
|
||||
node.ingest: "true"
|
||||
node.name: trains
|
||||
reindex.remote.whitelist: '*.*'
|
||||
script.inline: "true"
|
||||
script.painless.regex.enabled: "true"
|
||||
script.update: "true"
|
||||
thread_pool.bulk.queue_size: "2000"
|
||||
thread_pool.search.queue_size: "10000"
|
||||
xpack.monitoring.enabled: "false"
|
||||
xpack.security.enabled: "false"
|
||||
ulimits:
|
||||
@@ -58,10 +54,10 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- c:/opt/trains/data/elastic:/usr/share/elasticsearch/data
|
||||
- c:/opt/trains/data/elastic_7:/usr/share/elasticsearch/data
|
||||
ports:
|
||||
- "9200:9200"
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ services:
|
||||
volumes:
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
- /opt/trains/config:/opt/trains/config
|
||||
- /opt/trains/data/fileserver:/mnt/fileserver
|
||||
depends_on:
|
||||
- redis
|
||||
- mongo
|
||||
@@ -23,8 +24,9 @@ services:
|
||||
TRAINS_REDIS_SERVICE_HOST: redis
|
||||
TRAINS_REDIS_SERVICE_PORT: 6379
|
||||
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-linux}
|
||||
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
|
||||
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
|
||||
TRAINS__apiserver__pre_populate__enabled: "true"
|
||||
TRAINS__apiserver__pre_populate__zip_files: "/opt/trains/db-pre-populate"
|
||||
TRAINS__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
|
||||
ports:
|
||||
- "8008:8008"
|
||||
networks:
|
||||
@@ -40,15 +42,11 @@ services:
|
||||
cluster.name: trains
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
discovery.zen.minimum_master_nodes: "1"
|
||||
discovery.type: "single-node"
|
||||
http.compression_level: "7"
|
||||
node.ingest: "true"
|
||||
node.name: trains
|
||||
reindex.remote.whitelist: '*.*'
|
||||
script.inline: "true"
|
||||
script.painless.regex.enabled: "true"
|
||||
script.update: "true"
|
||||
thread_pool.bulk.queue_size: "2000"
|
||||
thread_pool.search.queue_size: "10000"
|
||||
xpack.monitoring.enabled: "false"
|
||||
xpack.security.enabled: "false"
|
||||
ulimits:
|
||||
@@ -58,10 +56,10 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/trains/data/elastic:/usr/share/elasticsearch/data
|
||||
- /opt/trains/data/elastic_7:/usr/share/elasticsearch/data
|
||||
ports:
|
||||
- "9200:9200"
|
||||
|
||||
|
||||
@@ -54,41 +54,41 @@ The following sections contain lists of AMI Image IDs, per region, for each rele
|
||||
|
||||
For easier upgrades, the following AMIs automatically update to the latest release every reboot:
|
||||
|
||||
* **eu-north-1** : ami-0f63429f8e5d57315
|
||||
* **ap-south-1** : ami-058a2a70b7fb8ec87
|
||||
* **eu-west-3** : ami-0fc9f9e8e986f39c4
|
||||
* **eu-west-2** : ami-0b0bc1ff2f0239bd9
|
||||
* **eu-west-1** : ami-0056ec5d22b0fac91
|
||||
* **ap-northeast-2** : ami-0898c9aa7f580fec7
|
||||
* **ap-northeast-1** : ami-011036ddcc9398871
|
||||
* **sa-east-1** : ami-04feeded12192438c
|
||||
* **ca-central-1** : ami-02c717776c9e75025
|
||||
* **ap-southeast-1** : ami-05b5866e7029bb9f1
|
||||
* **ap-southeast-2** : ami-0384bd2b69467fff8
|
||||
* **eu-central-1** : ami-01f15be85297d6f06
|
||||
* **us-east-2** : ami-094070ca8aa110180
|
||||
* **us-west-1** : ami-0d08ec5bc29eddb29
|
||||
* **us-west-2** : ami-04715cceedaf6eae7
|
||||
* **us-east-1** : ami-071dbaa1847585c4c
|
||||
* **eu-north-1** : ami-0f30c84b905d354b9
|
||||
* **ap-south-1** : ami-050e7acec52c8c74e
|
||||
* **eu-west-3** : ami-03911c5b5bc77ef75
|
||||
* **eu-west-2** : ami-0a5ed8aa2573ccc70
|
||||
* **eu-west-1** : ami-0a53c65e922ec0611
|
||||
* **ap-northeast-2** : ami-08cd017a37b8e8aab
|
||||
* **ap-northeast-1** : ami-056b3ca1ad5af9322
|
||||
* **sa-east-1** : ami-01ddc9325bafb400c
|
||||
* **ca-central-1** : ami-0fc3cbbd982b18b45
|
||||
* **ap-southeast-1** : ami-04c7a358df7002ef5
|
||||
* **ap-southeast-2** : ami-0eeaf54231b4ae22a
|
||||
* **eu-central-1** : ami-00b8e44041f8175fd
|
||||
* **us-east-2** : ami-0ac7deebb3f738f6d
|
||||
* **us-west-1** : ami-06bc07deb8b8c44d6
|
||||
* **us-west-2** : ami-01ba85ffe79a422f1
|
||||
* **us-east-1** : ami-04cf5a66cb4928ac3
|
||||
|
||||
### v0.15.1 (static update)
|
||||
|
||||
* **eu-north-1** : ami-0bb36c4dbe61f8c46
|
||||
* **ap-south-1** : ami-0ac93ff85a5c770f9
|
||||
* **eu-west-3** : ami-015ebfa846b8de5bb
|
||||
* **eu-west-2** : ami-082aacd59408713d9
|
||||
* **eu-west-1** : ami-066aad8c6b9b9991b
|
||||
* **ap-northeast-2** : ami-0cb47f1c8591c799d
|
||||
* **ap-northeast-1** : ami-005131d3037da9d2a
|
||||
* **sa-east-1** : ami-0f7fdc4e19c8444a3
|
||||
* **ca-central-1** : ami-07c234dad3ece2d78
|
||||
* **ap-southeast-1** : ami-0d8e0475d7d4897e4
|
||||
* **ap-southeast-2** : ami-053e3f25dee0424b9
|
||||
* **eu-central-1** : ami-00d25558c5242708e
|
||||
* **us-east-2** : ami-0bd45f800dfbde456
|
||||
* **us-west-1** : ami-05e79bf1704721148
|
||||
* **us-west-2** : ami-037c328649048409b
|
||||
* **us-east-1** : ami-0a3cafe46bf085200
|
||||
* **eu-north-1** : ami-0cd314e267426d1b7
|
||||
* **ap-south-1** : ami-086182cbe29151f96
|
||||
* **eu-west-3** : ami-0062366012182815b
|
||||
* **eu-west-2** : ami-022b8f2e32a9d18d0
|
||||
* **eu-west-1** : ami-0d8cf60446e09aa3d
|
||||
* **ap-northeast-2** : ami-0d4c168a815b56889
|
||||
* **ap-northeast-1** : ami-0daf7887db1053ae4
|
||||
* **sa-east-1** : ami-020a759a3ba4ff22b
|
||||
* **ca-central-1** : ami-0c10b5e04b707f3e3
|
||||
* **ap-southeast-1** : ami-0f61bb3529a165fcd
|
||||
* **ap-southeast-2** : ami-032dcdc82749c66c5
|
||||
* **eu-central-1** : ami-08f364f32d2eb3bae
|
||||
* **us-east-2** : ami-0b7efc3591803eba4
|
||||
* **us-west-1** : ami-08b2df27b0ada6faf
|
||||
* **us-west-2** : ami-0693029c4bad28816
|
||||
* **us-east-1** : ami-0200954fa9c2819ff
|
||||
|
||||
### v0.15.0 (static update)
|
||||
|
||||
|
||||
@@ -3,13 +3,16 @@
|
||||
To easily deploy Trains Server on GCP, use one of our pre-built GCP Custom Images.
|
||||
We provide Custom Images for each released version of Trains Server, see [Released versions](#released-versions) below.
|
||||
|
||||
Once your GCP instance is up and running using our Custom Image, [configure the Trains client](https://github.com/allegroai/trains/blob/master/README.md#configuration) to use your **trains-server**.
|
||||
Once your GCP instance is up and running using our Custom Image, [configure the Trains client](https://github.com/allegroai/trains/blob/master/README.md#configuration) to use your **trains-server**.
|
||||
|
||||
#### Default Trains Server Service ports
|
||||
The service port numbers on our Trains Server GCP Custom Image are:
|
||||
|
||||
- Web application: `8080`
|
||||
- API Server: `8008`
|
||||
- File Server: `8081`
|
||||
|
||||
#### Default Trains Server Storage paths
|
||||
The persistent storage configuration:
|
||||
|
||||
- MongoDB: `/opt/trains/data/mongo/`
|
||||
@@ -49,6 +52,15 @@ The minimum recommended requirements for Trains Server are:
|
||||
|
||||
To upgrade **trains-server** on an existing GCP instance based on one of these Custom Images, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
|
||||
|
||||
## Network and Security
|
||||
|
||||
Please make sure your instance is properly secured.
|
||||
|
||||
If not specifically set, a GCP instance will use default firewall rules that allow public access to various ports.
|
||||
If your instance is open for public access, we recommend you follow best practices for access management, including:
|
||||
- Allow access only to the specific ports used by Trains Server (see [Default Trains Server Service ports](#default-trains-server-service-ports)). Remember to allow access to port `443` if `https` access is configured for your instance.
|
||||
- Configure Trains Server to use fixed user names and passwords (see [Can I add web login authentication to trains-server?](./faq.md#web-auth))
|
||||
|
||||
## Released versions
|
||||
|
||||
The following sections contain lists of Custom Image URLs (exported in different formats) for each released **trains-server** version.
|
||||
@@ -59,5 +71,6 @@ The following sections contain lists of Custom Image URLs (exported in different
|
||||
|
||||
### All released images
|
||||
|
||||
- v0.15.1 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-15-1.tar.gz
|
||||
- v0.15.0 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-15-0.tar.gz
|
||||
- v0.14.1 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-14-1.tar.gz
|
||||
@@ -1,6 +1,6 @@
|
||||
# Launching the **trains-server** Docker in Linux or macOS
|
||||
|
||||
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
|
||||
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
|
||||
|
||||
For Linux users:
|
||||
|
||||
@@ -16,20 +16,20 @@ To launch **trains-server** on Linux or macOS:
|
||||
|
||||
1. Verify the Docker CE installation. Execute the command:
|
||||
|
||||
sudo docker run hello-world
|
||||
|
||||
docker run hello-world
|
||||
|
||||
The expected is output is:
|
||||
|
||||
Hello from Docker!
|
||||
This message shows that your installation appears to be working correctly.
|
||||
To generate this message, Docker took the following steps:
|
||||
|
||||
|
||||
1. The Docker client contacted the Docker daemon.
|
||||
2. The Docker daemon pulled the "hello-world" image from the Docker Hub. (amd64)
|
||||
3. The Docker daemon created a new container from that image which runs the executable that produces the output you are currently reading.
|
||||
4. The Docker daemon streamed that output to the Docker client, which sent it to your terminal.
|
||||
|
||||
1. For Linux only, install `docker-compose`. Execute the following commands (for more information, see [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
|
||||
1. For Linux only, install `docker-compose`. Execute the following commands (for more information, see [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
|
||||
|
||||
sudo curl -L "https://github.com/docker/compose/releases/download/1.24.1/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
|
||||
sudo chmod +x /usr/local/bin/docker-compose
|
||||
@@ -42,12 +42,12 @@ To launch **trains-server** on Linux or macOS:
|
||||
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
|
||||
sudo sysctl -w vm.max_map_count=262144
|
||||
sudo service docker restart
|
||||
|
||||
|
||||
macOS:
|
||||
|
||||
|
||||
screen ~/Library/Containers/com.docker.docker/Data/vms/0/tty
|
||||
sysctl -w vm.max_map_count=262144
|
||||
|
||||
|
||||
|
||||
1. Remove any previous installation of **trains-server**.
|
||||
|
||||
@@ -64,28 +64,28 @@ To launch **trains-server** on Linux or macOS:
|
||||
sudo mkdir -p /opt/trains/logs
|
||||
sudo mkdir -p /opt/trains/config
|
||||
sudo mkdir -p /opt/trains/data/fileserver
|
||||
|
||||
|
||||
1. For macOS only, open the Docker app, select **Preferences**, and then on the **File Sharing** tab, add `/opt/trains`.
|
||||
|
||||
|
||||
1. Grant access to the Dockers.
|
||||
|
||||
Linux:
|
||||
|
||||
sudo chown -R 1000:1000 /opt/trains
|
||||
|
||||
|
||||
macOS:
|
||||
|
||||
|
||||
sudo chown -R $(whoami):staff /opt/trains
|
||||
|
||||
1. Download the **trains-server** docker-compose YAML file.
|
||||
|
||||
cd /opt/trains
|
||||
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
|
||||
|
||||
|
||||
1. Run `docker-compose` with the downloaded configuration file.
|
||||
|
||||
sudo docker-compose -f docker-compose.yml up
|
||||
|
||||
docker-compose -f docker-compose.yml up
|
||||
|
||||
Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
|
||||
|
||||
* Web server on port `8080`
|
||||
@@ -94,4 +94,4 @@ To launch **trains-server** on Linux or macOS:
|
||||
|
||||
## Next Step
|
||||
|
||||
Configure the [Trains client for trains-server](https://github.com/allegroai/trains/blob/master/README.md#configuration).
|
||||
Configure the [Trains client for trains-server](https://github.com/allegroai/trains/blob/master/README.md#configuration).
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "2.8.0"
|
||||
__version__ = "2.9.0"
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from jsonmodels import models, fields
|
||||
from jsonmodels.validators import Length
|
||||
from mongoengine.base import BaseDocument
|
||||
|
||||
from apimodels import DictField
|
||||
from apimodels import DictField, ListField
|
||||
|
||||
|
||||
class MongoengineFieldsDict(DictField):
|
||||
@@ -12,14 +13,14 @@ class MongoengineFieldsDict(DictField):
|
||||
"""
|
||||
|
||||
mongoengine_update_operators = (
|
||||
'inc',
|
||||
'dec',
|
||||
'push',
|
||||
'push_all',
|
||||
'pop',
|
||||
'pull',
|
||||
'pull_all',
|
||||
'add_to_set',
|
||||
"inc",
|
||||
"dec",
|
||||
"push",
|
||||
"push_all",
|
||||
"pop",
|
||||
"pull",
|
||||
"pull_all",
|
||||
"add_to_set",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -30,16 +31,16 @@ class MongoengineFieldsDict(DictField):
|
||||
|
||||
@classmethod
|
||||
def _normalize_mongo_field_path(cls, path, value):
|
||||
parts = path.split('__')
|
||||
parts = path.split("__")
|
||||
if len(parts) > 1:
|
||||
if parts[0] == 'set':
|
||||
if parts[0] == "set":
|
||||
parts = parts[1:]
|
||||
elif parts[0] == 'unset':
|
||||
elif parts[0] == "unset":
|
||||
parts = parts[1:]
|
||||
value = None
|
||||
elif parts[0] in cls.mongoengine_update_operators:
|
||||
return None, None
|
||||
return '.'.join(parts), cls._normalize_mongo_value(value)
|
||||
return ".".join(parts), cls._normalize_mongo_value(value)
|
||||
|
||||
def parse_value(self, value):
|
||||
value = super(MongoengineFieldsDict, self).parse_value(value)
|
||||
@@ -62,3 +63,7 @@ class PagedRequest(models.Base):
|
||||
|
||||
class IdResponse(models.Base):
|
||||
id = fields.StringField(required=True)
|
||||
|
||||
|
||||
class MakePublicRequest(models.Base):
|
||||
ids = ListField(items_types=str, validators=[Length(minimum_value=1)])
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
from typing import Sequence
|
||||
from enum import auto
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length
|
||||
from jsonmodels.validators import Length, Min, Max
|
||||
|
||||
from apimodels import ListField, IntField, ActualEnumField
|
||||
from bll.event.event_metrics import EventType
|
||||
from bll.event.scalar_key import ScalarKeyEnum
|
||||
from config import config
|
||||
from utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class HistogramRequestBase(Base):
|
||||
samples: int = IntField(default=10000)
|
||||
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
|
||||
|
||||
@@ -21,7 +24,15 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
items_types=str,
|
||||
validators=[
|
||||
Length(
|
||||
minimum_value=1,
|
||||
maximum_value=config.get(
|
||||
"services.tasks.multi_task_histogram_limit", 10
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -40,12 +51,17 @@ class DebugImagesRequest(Base):
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class LogOrderEnum(StringEnum):
|
||||
asc = auto()
|
||||
desc = auto()
|
||||
|
||||
|
||||
class LogEventsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
batch_size: int = IntField(default=500)
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
from_timestamp: Optional[int] = IntField()
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum)
|
||||
|
||||
|
||||
class IterationEvents(Base):
|
||||
|
||||
@@ -6,6 +6,10 @@ from apimodels.base import UpdateResponse
|
||||
from apimodels.tasks import PublishResponse as TaskPublishResponse
|
||||
|
||||
|
||||
class GetFrameworksRequest(models.Base):
|
||||
projects = fields.ListField(items_types=[str])
|
||||
|
||||
|
||||
class CreateModelRequest(models.Base):
|
||||
name = fields.StringField(required=True)
|
||||
uri = fields.StringField(required=True)
|
||||
|
||||
@@ -13,11 +13,5 @@ class GetHyperParamReq(ProjectReq):
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
|
||||
class GetHyperParamResp(models.Base):
|
||||
parameters = fields.ListField(str)
|
||||
remaining = fields.IntField()
|
||||
total = fields.IntField()
|
||||
|
||||
|
||||
class ProjectTagsRequest(TagsRequest):
|
||||
projects = ListField(str)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from jsonmodels import models
|
||||
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
|
||||
from jsonmodels.validators import Enum
|
||||
from jsonmodels.validators import Enum, Length
|
||||
|
||||
from apimodels import DictField, ListField
|
||||
from apimodels.base import UpdateResponse
|
||||
@@ -103,6 +105,8 @@ class CloneRequest(TaskRequest):
|
||||
new_task_system_tags = ListField([str])
|
||||
new_task_parent = StringField()
|
||||
new_task_project = StringField()
|
||||
new_hyperparams = DictField()
|
||||
new_configuration = DictField()
|
||||
execution_overrides = DictField()
|
||||
validate_references = BoolField(default=False)
|
||||
|
||||
@@ -118,3 +122,72 @@ class AddOrUpdateArtifactsResponse(models.Base):
|
||||
|
||||
class ResetRequest(UpdateRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
|
||||
|
||||
class MultiTaskRequest(models.Base):
|
||||
tasks = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class GetHyperParamsRequest(MultiTaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class HyperParamItem(models.Base):
|
||||
section = StringField(required=True, validators=Length(minimum_value=1))
|
||||
name = StringField(required=True, validators=Length(minimum_value=1))
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class ReplaceHyperparams(object):
|
||||
none = "none"
|
||||
section = "section"
|
||||
all = "all"
|
||||
|
||||
|
||||
class EditHyperParamsRequest(TaskRequest):
|
||||
hyperparams: Sequence[HyperParamItem] = ListField(
|
||||
[HyperParamItem], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_hyperparams = StringField(
|
||||
validators=Enum(*get_options(ReplaceHyperparams)),
|
||||
default=ReplaceHyperparams.none,
|
||||
)
|
||||
|
||||
|
||||
class HyperParamKey(models.Base):
|
||||
section = StringField(required=True, validators=Length(minimum_value=1))
|
||||
name = StringField(nullable=True)
|
||||
|
||||
|
||||
class DeleteHyperParamsRequest(TaskRequest):
|
||||
hyperparams: Sequence[HyperParamKey] = ListField(
|
||||
[HyperParamKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
|
||||
|
||||
class GetConfigurationsRequest(MultiTaskRequest):
|
||||
names = ListField([str])
|
||||
|
||||
|
||||
class GetConfigurationNamesRequest(MultiTaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class Configuration(models.Base):
|
||||
name = StringField(required=True, validators=Length(minimum_value=1))
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class EditConfigurationRequest(TaskRequest):
|
||||
configuration: Sequence[Configuration] = ListField(
|
||||
[Configuration], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_configuration = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteConfigurationRequest(TaskRequest):
|
||||
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
@@ -19,6 +19,7 @@ DEFAULT_TIMEOUT = 10 * 60
|
||||
|
||||
class WorkerRequest(Base):
|
||||
worker = StringField(required=True)
|
||||
tags = ListField(str)
|
||||
|
||||
|
||||
class RegisterRequest(WorkerRequest):
|
||||
@@ -67,12 +68,14 @@ class WorkerEntry(Base, JsonSerializableMixin):
|
||||
company = EmbeddedField(IdNameEntry)
|
||||
ip = StringField()
|
||||
task = EmbeddedField(IdNameEntry)
|
||||
project = EmbeddedField(IdNameEntry)
|
||||
queue = StringField() # queue from which current task was taken
|
||||
queues = ListField(str) # list of queues this worker listens to
|
||||
register_time = DateTimeField(required=True)
|
||||
register_timeout = IntField(required=True)
|
||||
last_activity_time = DateTimeField(required=True)
|
||||
last_report_time = DateTimeField()
|
||||
tags = ListField(str)
|
||||
|
||||
|
||||
class CurrentTaskEntry(IdNameEntry):
|
||||
|
||||
@@ -208,7 +208,11 @@ class DebugImagesIterator:
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}]
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"terms": {"metric": metrics}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"aggs": {
|
||||
@@ -251,7 +255,7 @@ class DebugImagesIterator:
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@@ -298,6 +302,7 @@ class DebugImagesIterator:
|
||||
must_conditions = [
|
||||
{"term": {"task": metric.task}},
|
||||
{"term": {"metric": metric.name}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
must_not_conditions = []
|
||||
|
||||
@@ -368,7 +373,7 @@ class DebugImagesIterator:
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iter_count,
|
||||
"order": {"_term": "desc" if navigate_earlier else "asc"},
|
||||
"order": {"_key": "desc" if navigate_earlier else "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
@@ -387,7 +392,7 @@ class DebugImagesIterator:
|
||||
},
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=metric.task)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return metric.task, metric.name, []
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections import defaultdict
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple
|
||||
from typing import Sequence, Set, Tuple, Optional
|
||||
|
||||
import six
|
||||
from elasticsearch import helpers
|
||||
@@ -22,6 +22,7 @@ from database.errors import translate_errors_context
|
||||
from database.model.task.task import Task, TaskStatus
|
||||
from redis_manager import redman
|
||||
from timing_context import TimingContext
|
||||
from tools import safe_get
|
||||
from utilities.dicts import flatten_nested_items
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
@@ -31,6 +32,7 @@ LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
|
||||
|
||||
class EventBLL(object):
|
||||
id_fields = ("task", "iter", "metric", "variant", "key")
|
||||
empty_scroll = "FFFF"
|
||||
|
||||
def __init__(self, events_es=None, redis=None):
|
||||
self.es = events_es or es_factory.connect("events")
|
||||
@@ -40,7 +42,7 @@ class EventBLL(object):
|
||||
)
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
|
||||
self.log_events_iterator = LogEventsIterator(es=self.es, redis=self.redis)
|
||||
self.log_events_iterator = LogEventsIterator(es=self.es)
|
||||
|
||||
@property
|
||||
def metrics(self) -> EventMetrics:
|
||||
@@ -134,7 +136,6 @@ class EventBLL(object):
|
||||
es_action = {
|
||||
"_op_type": "index", # overwrite if exists with same ID
|
||||
"_index": index_name,
|
||||
"_type": "event",
|
||||
"_source": event,
|
||||
}
|
||||
|
||||
@@ -144,7 +145,6 @@ class EventBLL(object):
|
||||
else:
|
||||
es_action["_id"] = dbutils.id()
|
||||
|
||||
es_action["_routing"] = task_id
|
||||
task_ids.add(task_id)
|
||||
if (
|
||||
iter is not None
|
||||
@@ -322,6 +322,9 @@ class EventBLL(object):
|
||||
batch_size=10000,
|
||||
scroll_id=None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "task_log_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
@@ -342,14 +345,9 @@ class EventBLL(object):
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, scroll="1h", routing=task_id
|
||||
)
|
||||
|
||||
events = [hit["_source"] for hit in es_res["hits"]["hits"]]
|
||||
next_scroll_id = es_res["_scroll_id"]
|
||||
total_events = es_res["hits"]["total"]
|
||||
es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return events, next_scroll_id, total_events
|
||||
|
||||
def get_last_iterations_per_event_metric_variant(
|
||||
@@ -377,7 +375,7 @@ class EventBLL(object):
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": num_last_iterations,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -393,7 +391,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "task_last_iter_metric_variant"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@@ -413,6 +411,9 @@ class EventBLL(object):
|
||||
size: int = 500,
|
||||
scroll_id: str = None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
@@ -422,13 +423,11 @@ class EventBLL(object):
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
query = {"bool": defaultdict(list)}
|
||||
|
||||
must = []
|
||||
if last_iterations_per_plot is None:
|
||||
must = query["bool"]["must"]
|
||||
must.append({"terms": {"task": tasks}})
|
||||
else:
|
||||
should = query["bool"]["should"]
|
||||
should = []
|
||||
for i, task_id in enumerate(tasks):
|
||||
last_iters = self.get_last_iterations_per_event_metric_variant(
|
||||
es_index, task_id, last_iterations_per_plot, event_type
|
||||
@@ -451,32 +450,41 @@ class EventBLL(object):
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
|
||||
|
||||
routing = ",".join(tasks)
|
||||
es_req = {
|
||||
"sort": sort,
|
||||
"size": min(size, 10000),
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_plots"):
|
||||
es_res = self.es.search(
|
||||
index=es_index,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
routing=routing,
|
||||
scroll="1h",
|
||||
index=es_index, body=es_req, ignore=404, scroll="1h",
|
||||
)
|
||||
|
||||
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
|
||||
# scroll id may be missing when queering a totally empty DB
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
total_events = es_res["hits"]["total"]
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
|
||||
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
|
||||
"""
|
||||
Return events and next scroll id from the scrolled query
|
||||
Release the scroll once it is exhausted
|
||||
"""
|
||||
total_events = safe_get(es_res, "hits/total/value", default=0)
|
||||
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
if next_scroll_id and not events:
|
||||
self.es.clear_scroll(scroll_id=next_scroll_id)
|
||||
next_scroll_id = self.empty_scroll
|
||||
|
||||
return events, total_events, next_scroll_id
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id,
|
||||
@@ -489,6 +497,8 @@ class EventBLL(object):
|
||||
size=500,
|
||||
scroll_id=None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
@@ -502,20 +512,16 @@ class EventBLL(object):
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
query = {"bool": defaultdict(list)}
|
||||
|
||||
if metric or variant:
|
||||
must = query["bool"]["must"]
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
must = []
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
|
||||
if last_iter_count is None:
|
||||
must = query["bool"]["must"]
|
||||
must.append({"terms": {"task": task_ids}})
|
||||
else:
|
||||
should = query["bool"]["should"]
|
||||
should = []
|
||||
for i, task_id in enumerate(task_ids):
|
||||
last_iters = self.get_last_iters(
|
||||
es_index, task_id, event_type, last_iter_count
|
||||
@@ -534,27 +540,23 @@ class EventBLL(object):
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
|
||||
|
||||
routing = ",".join(task_ids)
|
||||
es_req = {
|
||||
"sort": sort,
|
||||
"size": min(size, 10000),
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.search(
|
||||
index=es_index,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
routing=routing,
|
||||
scroll="1h",
|
||||
index=es_index, body=es_req, ignore=404, scroll="1h",
|
||||
)
|
||||
|
||||
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
|
||||
next_scroll_id = es_res["_scroll_id"]
|
||||
total_events = es_res["hits"]["total"]
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
@@ -590,7 +592,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
metrics = {}
|
||||
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
|
||||
@@ -622,14 +624,14 @@ class EventBLL(object):
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
"order": {"_term": "asc"},
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
"order": {"_term": "asc"},
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_value": {
|
||||
@@ -659,7 +661,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
metrics = []
|
||||
max_timestamp = 0
|
||||
@@ -706,7 +708,7 @@ class EventBLL(object):
|
||||
"sort": ["iter"],
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
vectors = []
|
||||
iterations = []
|
||||
@@ -727,7 +729,7 @@ class EventBLL(object):
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iters,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -737,7 +739,7 @@ class EventBLL(object):
|
||||
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_last_iter"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@@ -759,8 +761,6 @@ class EventBLL(object):
|
||||
es_index = EventMetrics.get_index_name(company_id, "*")
|
||||
es_req = {"query": {"term": {"task": task_id}}}
|
||||
with translate_errors_context(), TimingContext("es", "delete_task_events"):
|
||||
es_res = self.es.delete_by_query(
|
||||
index=es_index, body=es_req, routing=task_id, refresh=True
|
||||
)
|
||||
es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)
|
||||
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import itertools
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Callable, Iterable
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
@@ -16,7 +16,7 @@ from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
from utilities import safe_get
|
||||
from tools import safe_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -30,14 +30,18 @@ class EventType(Enum):
|
||||
|
||||
|
||||
class EventMetrics:
|
||||
MAX_TASKS_COUNT = 50
|
||||
MAX_METRICS_COUNT = 200
|
||||
MAX_VARIANTS_COUNT = 500
|
||||
MAX_METRICS_COUNT = 100
|
||||
MAX_VARIANTS_COUNT = 100
|
||||
MAX_AGGS_ELEMENTS_COUNT = 50
|
||||
MAX_SAMPLE_BUCKETS = 6000
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
@property
|
||||
def _max_concurrency(self):
|
||||
return config.get("services.events.max_metrics_concurrency", 4)
|
||||
|
||||
@staticmethod
|
||||
def get_index_name(company_id, event_type):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
@@ -51,15 +55,48 @@ class EventMetrics:
|
||||
The amount of points in each histogram should not exceed
|
||||
the requested samples
|
||||
"""
|
||||
es_index = self.get_index_name(company_id, "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
return self._run_get_scalar_metrics_as_parallel(
|
||||
company_id,
|
||||
task_ids=[task_id],
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
get_func=self._get_scalar_average,
|
||||
return self._get_scalar_average_per_iter_core(
|
||||
task_id, es_index, samples, ScalarKey.resolve(key)
|
||||
)
|
||||
|
||||
def _get_scalar_average_per_iter_core(
|
||||
self,
|
||||
task_id: str,
|
||||
es_index: str,
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
run_parallel: bool = True,
|
||||
) -> dict:
|
||||
intervals = self._get_task_metric_intervals(
|
||||
es_index=es_index, task_id=task_id, samples=samples, field=key.field
|
||||
)
|
||||
if not intervals:
|
||||
return {}
|
||||
interval_groups = self._group_task_metric_intervals(intervals)
|
||||
|
||||
get_scalar_average = partial(
|
||||
self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
|
||||
)
|
||||
if run_parallel:
|
||||
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
pool.map(get_scalar_average, interval_groups)
|
||||
)
|
||||
else:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
get_scalar_average(group) for group in interval_groups
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
for metric_key, metric_values in metrics:
|
||||
ret[metric_key].update(metric_values)
|
||||
|
||||
return ret
|
||||
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self,
|
||||
company_id,
|
||||
@@ -72,159 +109,115 @@ class EventMetrics:
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
if len(task_ids) > self.MAX_TASKS_COUNT:
|
||||
raise errors.BadRequest(
|
||||
f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
|
||||
len(task_ids),
|
||||
)
|
||||
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id__in=task_ids),
|
||||
allow_public=allow_public,
|
||||
override_projection=("id", "name"),
|
||||
override_projection=("id", "name", "company"),
|
||||
return_dicts=False,
|
||||
)
|
||||
if len(task_objs) < len(task_ids):
|
||||
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
|
||||
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
|
||||
|
||||
task_name_by_id = {t.id: t.name for t in task_objs}
|
||||
|
||||
ret = self._run_get_scalar_metrics_as_parallel(
|
||||
company_id,
|
||||
task_ids=task_ids,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
get_func=self._get_scalar_average_per_task,
|
||||
)
|
||||
companies = {t.company for t in task_objs}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
for metric_data in ret.values():
|
||||
for variant_data in metric_data.values():
|
||||
for task_id, task_data in variant_data.items():
|
||||
task_data["name"] = task_name_by_id[task_id]
|
||||
|
||||
return ret
|
||||
|
||||
TaskMetric = Tuple[str, str, str]
|
||||
|
||||
MetricInterval = Tuple[int, Sequence[TaskMetric]]
|
||||
MetricData = Tuple[str, dict]
|
||||
|
||||
def _split_metrics_by_max_aggs_count(
|
||||
self, task_metrics: Sequence[TaskMetric]
|
||||
) -> Iterable[Sequence[TaskMetric]]:
|
||||
"""
|
||||
Return task metrics in groups where amount of task metrics in each group
|
||||
is roughly limited by MAX_AGGS_ELEMENTS_COUNT. The split is done on metrics and
|
||||
variants while always preserving all their tasks in the same group
|
||||
"""
|
||||
if len(task_metrics) < self.MAX_AGGS_ELEMENTS_COUNT:
|
||||
yield task_metrics
|
||||
return
|
||||
|
||||
tm_grouped = bucketize(task_metrics, key=itemgetter(1, 2))
|
||||
groups = []
|
||||
for group in tm_grouped.values():
|
||||
groups.append(group)
|
||||
if sum(map(len, groups)) >= self.MAX_AGGS_ELEMENTS_COUNT:
|
||||
yield list(itertools.chain(*groups))
|
||||
groups = []
|
||||
|
||||
if groups:
|
||||
yield list(itertools.chain(*groups))
|
||||
|
||||
return
|
||||
|
||||
def _run_get_scalar_metrics_as_parallel(
|
||||
self,
|
||||
company_id: str,
|
||||
task_ids: Sequence[str],
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
get_func: Callable[
|
||||
[MetricInterval, Sequence[str], str, ScalarKey], Sequence[MetricData]
|
||||
],
|
||||
) -> dict:
|
||||
"""
|
||||
Group metrics per interval length and execute get_func for each group in parallel
|
||||
:param company_id: id of the company
|
||||
:params task_ids: ids of the tasks to collect data for
|
||||
:param samples: maximum number of samples per metric
|
||||
:param get_func: callable that given metric names for the same interval
|
||||
performs histogram aggregation for the metrics and return the aggregated data
|
||||
"""
|
||||
es_index = self.get_index_name(company_id, "training_stats_scalar")
|
||||
es_index = self.get_index_name(next(iter(companies)), "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
intervals = self._get_metric_intervals(
|
||||
es_index=es_index, task_ids=task_ids, samples=samples, field=key.field
|
||||
get_scalar_average_per_iter = partial(
|
||||
self._get_scalar_average_per_iter_core,
|
||||
es_index=es_index,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
run_parallel=False,
|
||||
)
|
||||
|
||||
if not intervals:
|
||||
return {}
|
||||
|
||||
intervals = list(
|
||||
itertools.chain.from_iterable(
|
||||
zip(itertools.repeat(i), self._split_metrics_by_max_aggs_count(tms))
|
||||
for i, tms in intervals
|
||||
)
|
||||
)
|
||||
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
pool.map(
|
||||
partial(get_func, task_ids=task_ids, es_index=es_index, key=key),
|
||||
intervals,
|
||||
)
|
||||
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
|
||||
task_metrics = zip(
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
for metric_key, metric_values in metrics:
|
||||
ret[metric_key].update(metric_values)
|
||||
res = defaultdict(lambda: defaultdict(dict))
|
||||
for task_id, task_data in task_metrics:
|
||||
task_name = task_name_by_id[task_id]
|
||||
for metric_key, metric_data in task_data.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
variant_data["name"] = task_name
|
||||
res[metric_key][variant_key][task_id] = variant_data
|
||||
|
||||
return ret
|
||||
return res
|
||||
|
||||
def _get_metric_intervals(
|
||||
self, es_index, task_ids: Sequence[str], samples: int, field: str = "iter"
|
||||
MetricInterval = Tuple[str, str, int, int]
|
||||
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
|
||||
|
||||
@classmethod
|
||||
def _group_task_metric_intervals(
|
||||
cls, intervals: Sequence[MetricInterval]
|
||||
) -> Sequence[MetricIntervalGroup]:
|
||||
"""
|
||||
Group task metric intervals so that the following conditions are meat:
|
||||
- All the metrics in the same group have the same interval (with 10% rounding)
|
||||
- The amount of metrics in the group does not exceed MAX_AGGS_ELEMENTS_COUNT
|
||||
- The total count of samples in the group does not exceed MAX_SAMPLE_BUCKETS
|
||||
"""
|
||||
metric_interval_groups = []
|
||||
interval_group = []
|
||||
group_interval_upper_bound = 0
|
||||
group_max_interval = 0
|
||||
group_samples = 0
|
||||
for metric, variant, interval, size in sorted(intervals, key=itemgetter(2)):
|
||||
if (
|
||||
interval > group_interval_upper_bound
|
||||
or (group_samples + size) > cls.MAX_SAMPLE_BUCKETS
|
||||
or len(interval_group) >= cls.MAX_AGGS_ELEMENTS_COUNT
|
||||
):
|
||||
if interval_group:
|
||||
metric_interval_groups.append((group_max_interval, interval_group))
|
||||
interval_group = []
|
||||
group_max_interval = interval
|
||||
group_interval_upper_bound = interval + int(interval * 0.1)
|
||||
group_samples = 0
|
||||
interval_group.append((metric, variant))
|
||||
group_samples += size
|
||||
group_max_interval = max(group_max_interval, interval)
|
||||
if interval_group:
|
||||
metric_interval_groups.append((group_max_interval, interval_group))
|
||||
|
||||
return metric_interval_groups
|
||||
|
||||
def _get_task_metric_intervals(
|
||||
self, es_index, task_id: str, samples: int, field: str = "iter"
|
||||
) -> Sequence[MetricInterval]:
|
||||
"""
|
||||
Calculate interval per task metric variant so that the resulting
|
||||
amount of points does not exceed sample.
|
||||
Return metric variants grouped by interval value with 10% rounding
|
||||
For samples==0 return empty list
|
||||
Return the list og metric variant intervals as the following tuple:
|
||||
(metric, variant, interval, samples)
|
||||
"""
|
||||
default_intervals = [(1, [])]
|
||||
if not samples:
|
||||
return default_intervals
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"terms": {"task": task_ids}},
|
||||
"query": {"term": {"task": task_id}},
|
||||
"aggs": {
|
||||
"tasks": {
|
||||
"terms": {"field": "task", "size": self.MAX_TASKS_COUNT},
|
||||
"metrics": {
|
||||
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": self.MAX_METRICS_COUNT,
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"count": {"value_count": {"field": field}},
|
||||
"min_index": {"min": {"field": field}},
|
||||
"max_index": {"max": {"field": field}},
|
||||
},
|
||||
}
|
||||
"count": {"value_count": {"field": field}},
|
||||
"min_index": {"min": {"field": field}},
|
||||
"max_index": {"max": {"field": field}},
|
||||
},
|
||||
}
|
||||
},
|
||||
@@ -233,88 +226,78 @@ class EventMetrics:
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, routing=",".join(task_ids)
|
||||
)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return default_intervals
|
||||
return []
|
||||
|
||||
intervals = [
|
||||
(
|
||||
task["key"],
|
||||
metric["key"],
|
||||
variant["key"],
|
||||
self._calculate_metric_interval(variant, samples),
|
||||
)
|
||||
for task in aggs_result["tasks"]["buckets"]
|
||||
for metric in task["metrics"]["buckets"]
|
||||
return [
|
||||
self._build_metric_interval(metric["key"], variant["key"], variant, samples)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
for variant in metric["variants"]["buckets"]
|
||||
]
|
||||
|
||||
metric_intervals = []
|
||||
upper_border = 0
|
||||
interval_metrics = None
|
||||
for task, metric, variant, interval in sorted(intervals, key=itemgetter(3)):
|
||||
if not interval_metrics or interval > upper_border:
|
||||
interval_metrics = []
|
||||
metric_intervals.append((interval, interval_metrics))
|
||||
upper_border = interval + int(interval * 0.1)
|
||||
interval_metrics.append((task, metric, variant))
|
||||
|
||||
return metric_intervals
|
||||
|
||||
@staticmethod
|
||||
def _calculate_metric_interval(metric_variant: dict, samples: int) -> int:
|
||||
def _build_metric_interval(
|
||||
metric: str, variant: str, data: dict, samples: int
|
||||
) -> Tuple[str, str, int, int]:
|
||||
"""
|
||||
Calculate index interval per metric_variant variant so that the
|
||||
total amount of intervals does not exceeds the samples
|
||||
Return the interval and resulting amount of intervals
|
||||
"""
|
||||
count = safe_get(metric_variant, "count/value")
|
||||
if not count or count < samples:
|
||||
return 1
|
||||
count = safe_get(data, "count/value", default=0)
|
||||
if count < samples:
|
||||
return metric, variant, 1, count
|
||||
|
||||
min_index = safe_get(metric_variant, "min_index/value", default=0)
|
||||
max_index = safe_get(metric_variant, "max_index/value", default=min_index)
|
||||
return max(1, int(max_index - min_index + 1) // samples)
|
||||
min_index = safe_get(data, "min_index/value", default=0)
|
||||
max_index = safe_get(data, "max_index/value", default=min_index)
|
||||
index_range = max_index - min_index + 1
|
||||
interval = max(1, math.ceil(float(index_range) / samples))
|
||||
max_samples = math.ceil(float(index_range) / interval)
|
||||
return (
|
||||
metric,
|
||||
variant,
|
||||
interval,
|
||||
max_samples,
|
||||
)
|
||||
|
||||
MetricData = Tuple[str, dict]
|
||||
|
||||
def _get_scalar_average(
|
||||
self,
|
||||
metrics_interval: MetricInterval,
|
||||
task_ids: Sequence[str],
|
||||
metrics_interval: MetricIntervalGroup,
|
||||
task_id: str,
|
||||
es_index: str,
|
||||
key: ScalarKey,
|
||||
) -> Sequence[MetricData]:
|
||||
"""
|
||||
Retrieve scalar histograms per several metric variants that share the same interval
|
||||
Note: the function works with a single task only
|
||||
"""
|
||||
|
||||
assert len(task_ids) == 1
|
||||
interval, task_metrics = metrics_interval
|
||||
interval, metrics = metrics_interval
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": self.MAX_METRICS_COUNT,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
aggs_result = self._query_aggregation_for_metrics_and_tasks(
|
||||
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
|
||||
aggs_result = self._query_aggregation_for_task_metrics(
|
||||
es_index, aggs=aggs, task_id=task_id, metrics=metrics
|
||||
)
|
||||
|
||||
if not aggs_result:
|
||||
@@ -335,61 +318,6 @@ class EventMetrics:
|
||||
]
|
||||
return metrics
|
||||
|
||||
def _get_scalar_average_per_task(
|
||||
self,
|
||||
metrics_interval: MetricInterval,
|
||||
task_ids: Sequence[str],
|
||||
es_index: str,
|
||||
key: ScalarKey,
|
||||
) -> Sequence[MetricData]:
|
||||
"""
|
||||
Retrieve scalar histograms per several metric variants that share the same interval
|
||||
"""
|
||||
interval, task_metrics = metrics_interval
|
||||
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
|
||||
"aggs": {
|
||||
"tasks": {
|
||||
"terms": {
|
||||
"field": "task",
|
||||
"size": self.MAX_TASKS_COUNT,
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
aggs_result = self._query_aggregation_for_metrics_and_tasks(
|
||||
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
|
||||
)
|
||||
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
metrics = [
|
||||
(
|
||||
metric["key"],
|
||||
{
|
||||
variant["key"]: {
|
||||
task["key"]: key.get_iterations_data(task)
|
||||
for task in variant["tasks"]["buckets"]
|
||||
}
|
||||
for variant in metric["variants"]["buckets"]
|
||||
},
|
||||
)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
]
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
def _add_aggregation_average(aggregation):
|
||||
average_agg = {"avg_val": {"avg": {"field": "value"}}}
|
||||
@@ -398,69 +326,55 @@ class EventMetrics:
|
||||
for key, value in aggregation.items()
|
||||
}
|
||||
|
||||
def _query_aggregation_for_metrics_and_tasks(
|
||||
def _query_aggregation_for_task_metrics(
|
||||
self,
|
||||
es_index: str,
|
||||
aggs: dict,
|
||||
task_ids: Sequence[str],
|
||||
task_metrics: Sequence[TaskMetric],
|
||||
task_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
) -> dict:
|
||||
"""
|
||||
Return the result of elastic search query for the given aggregation filtered
|
||||
by the given task_ids and metrics
|
||||
"""
|
||||
if task_metrics:
|
||||
condition = {
|
||||
"should": [
|
||||
self._build_metric_terms(task, metric, variant)
|
||||
for task, metric, variant in task_metrics
|
||||
]
|
||||
}
|
||||
else:
|
||||
condition = {"must": [{"terms": {"task": task_ids}}]}
|
||||
must = [{"term": {"task": task_id}}]
|
||||
if metrics:
|
||||
should = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for metric, variant in metrics
|
||||
]
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"_source": {"excludes": []},
|
||||
"query": {"bool": condition},
|
||||
"query": {"bool": {"must": must}},
|
||||
"aggs": aggs,
|
||||
"version": True,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, routing=",".join(task_ids)
|
||||
)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
return es_res.get("aggregations")
|
||||
|
||||
@staticmethod
|
||||
def _build_metric_terms(task: str, metric: str, variant: str) -> dict:
|
||||
"""
|
||||
Build query term for a metric + variant
|
||||
"""
|
||||
return {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
def get_tasks_metrics(
|
||||
self, company_id, task_ids: Sequence, event_type: EventType
|
||||
) -> Sequence[Tuple]:
|
||||
) -> Sequence:
|
||||
"""
|
||||
For the requested tasks return all the metrics that
|
||||
reported events of the requested types
|
||||
"""
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type.value)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return [(tid, []) for tid in task_ids]
|
||||
return {}
|
||||
|
||||
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
|
||||
with ThreadPoolExecutor(max_concurrency) as pool:
|
||||
with ThreadPoolExecutor(self._max_concurrency) as pool:
|
||||
res = pool.map(
|
||||
partial(
|
||||
self._get_task_metrics, es_index=es_index, event_type=event_type,
|
||||
@@ -488,7 +402,7 @@ class EventMetrics:
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
return [
|
||||
metric["key"]
|
||||
|
||||
@@ -2,30 +2,12 @@ from typing import Optional, Tuple, Sequence
|
||||
|
||||
import attr
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apierrors import errors
|
||||
from apimodels import JsonSerializableMixin
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
from bll.redis_cache_manager import RedisCacheManager
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from timing_context import TimingContext
|
||||
|
||||
|
||||
class LogEventsScrollState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
task: str = StringField(required=True)
|
||||
last_min_timestamp: Optional[int] = IntField()
|
||||
last_max_timestamp: Optional[int] = IntField()
|
||||
|
||||
def reset(self):
|
||||
"""Reset the scrolling state """
|
||||
self.last_min_timestamp = self.last_max_timestamp = None
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
@@ -36,19 +18,8 @@ class TaskEventsResult:
|
||||
class LogEventsIterator:
|
||||
EVENT_TYPE = "log"
|
||||
|
||||
@property
|
||||
def state_expiration_sec(self):
|
||||
return config.get(
|
||||
f"services.events.events_retrieval.state_expiration_sec", 3600
|
||||
)
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=LogEventsScrollState,
|
||||
redis=redis,
|
||||
expiration_interval=self.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
@@ -56,48 +27,29 @@ class LogEventsIterator:
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
from_timestamp: Optional[int] = None,
|
||||
) -> TaskEventsResult:
|
||||
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
def init_state(state_: LogEventsScrollState):
|
||||
state_.task = task_id
|
||||
|
||||
def validate_state(state_: LogEventsScrollState):
|
||||
"""
|
||||
Checks that the task id stored in the state
|
||||
is equal to the one passed with the current call
|
||||
Refresh the state if requested
|
||||
"""
|
||||
if state_.task != task_id:
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task stored in the state does not match the passed one",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
state_.reset()
|
||||
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res = TaskEventsResult(next_scroll_id=state.id)
|
||||
res.events, res.total_events = self._get_events(
|
||||
es_index=es_index,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
state=state,
|
||||
)
|
||||
return res
|
||||
res = TaskEventsResult()
|
||||
res.events, res.total_events = self._get_events(
|
||||
es_index=es_index,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_timestamp=from_timestamp,
|
||||
)
|
||||
return res
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
es_index,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
state: LogEventsScrollState,
|
||||
from_timestamp: Optional[int],
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous timestamp either in the
|
||||
@@ -111,29 +63,21 @@ class LogEventsIterator:
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": {"term": {"task": state.task}},
|
||||
"query": {"term": {"task": task_id}},
|
||||
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if navigate_earlier and state.last_min_timestamp is not None:
|
||||
es_req["search_after"] = [state.last_min_timestamp]
|
||||
elif not navigate_earlier and state.last_max_timestamp is not None:
|
||||
es_req["search_after"] = [state.last_max_timestamp]
|
||||
if from_timestamp:
|
||||
es_req["search_after"] = [from_timestamp]
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
|
||||
es_result = self.es.search(index=es_index, body=es_req)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
if navigate_earlier:
|
||||
state.last_max_timestamp = events[0]["timestamp"]
|
||||
state.last_min_timestamp = events[-1]["timestamp"]
|
||||
else:
|
||||
state.last_min_timestamp = events[0]["timestamp"]
|
||||
state.last_max_timestamp = events[-1]["timestamp"]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
@@ -142,28 +86,29 @@ class LogEventsIterator:
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"timestamp": events[-1]["timestamp"]}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
|
||||
hits = es_result["hits"]["hits"]
|
||||
if not hits or len(hits) < 2:
|
||||
es_result = self.es.search(index=es_index, body=es_req)
|
||||
last_second_hits = es_result["hits"]["hits"]
|
||||
if not last_second_hits or len(last_second_hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
last_events = [hit["_source"] for hit in es_result["hits"]["hits"]]
|
||||
already_present_ids = set(ev["_id"] for ev in events)
|
||||
already_present_ids = set(hit["_id"] for hit in hits)
|
||||
last_second_events = [
|
||||
hit["_source"]
|
||||
for hit in last_second_hits
|
||||
if hit["_id"] not in already_present_ids
|
||||
]
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[
|
||||
*events,
|
||||
*(ev for ev in last_events if ev["_id"] not in already_present_ids),
|
||||
],
|
||||
[*events, *last_second_events],
|
||||
hits_total,
|
||||
)
|
||||
|
||||
@@ -111,7 +111,7 @@ class TimestampKey(ScalarKey):
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{interval}ms",
|
||||
"fixed_interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
}
|
||||
}
|
||||
@@ -150,7 +150,7 @@ class ISOTimeKey(ScalarKey):
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{interval}ms",
|
||||
"fixed_interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
"format": "strict_date_time",
|
||||
}
|
||||
|
||||
18
server/bll/model/__init__.py
Normal file
18
server/bll/model/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from database.model.model import Model
|
||||
from database.utils import get_company_or_none_constraint
|
||||
|
||||
|
||||
class ModelBLL:
|
||||
def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
|
||||
"""
|
||||
Return the list of unique frameworks used by company and public models
|
||||
If project ids passed then only models from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
return Model.objects(query).distinct(field="framework")
|
||||
@@ -18,7 +18,6 @@ log = config.logger(__file__)
|
||||
|
||||
class QueueMetrics:
|
||||
class EsKeys:
|
||||
DOC_TYPE = "metrics"
|
||||
WAITING_TIME_FIELD = "average_waiting_time"
|
||||
QUEUE_LENGTH_FIELD = "queue_length"
|
||||
TIMESTAMP_FIELD = "timestamp"
|
||||
@@ -66,7 +65,6 @@ class QueueMetrics:
|
||||
entries = [e for e in queue.entries if e.added]
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_type=self.EsKeys.DOC_TYPE,
|
||||
_source={
|
||||
self.EsKeys.TIMESTAMP_FIELD: timestamp,
|
||||
self.EsKeys.QUEUE_FIELD: queue.id,
|
||||
@@ -93,7 +91,6 @@ class QueueMetrics:
|
||||
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
|
||||
doc_type=self.EsKeys.DOC_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@@ -109,7 +106,7 @@ class QueueMetrics:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": cls.EsKeys.TIMESTAMP_FIELD,
|
||||
"interval": f"{interval}s",
|
||||
"fixed_interval": f"{interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
|
||||
@@ -19,7 +19,7 @@ from config.info import get_deployment_type
|
||||
from database.model import Company, User
|
||||
from database.model.queue import Queue
|
||||
from database.model.task.task import Task
|
||||
from utilities import safe_get
|
||||
from tools import safe_get
|
||||
from utilities.json import dumps
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from version import __version__ as current_version
|
||||
@@ -237,7 +237,6 @@ class StatisticsReporter:
|
||||
def _run_worker_stats_query(cls, company_id, es_req) -> dict:
|
||||
return worker_bll.es_client.search(
|
||||
index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*",
|
||||
doc_type="stat",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,5 +4,4 @@ from .utils import (
|
||||
update_project_time,
|
||||
validate_status_change,
|
||||
split_by,
|
||||
ParameterKeyEscaper,
|
||||
)
|
||||
|
||||
229
server/bll/task/hyperparams.py
Normal file
229
server/bll/task/hyperparams.py
Normal file
@@ -0,0 +1,229 @@
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Dict
|
||||
|
||||
from boltons import iterutils
|
||||
|
||||
from apierrors import errors
|
||||
from apimodels.tasks import (
|
||||
HyperParamKey,
|
||||
HyperParamItem,
|
||||
ReplaceHyperparams,
|
||||
Configuration,
|
||||
)
|
||||
from bll.task import TaskBLL
|
||||
from config import config
|
||||
from database.model.task.task import ParamsItem, Task, ConfigurationItem, TaskStatus
|
||||
from utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
log = config.logger(__file__)
|
||||
task_bll = TaskBLL()
|
||||
|
||||
|
||||
class HyperParams:
|
||||
_properties_section = "properties"
|
||||
|
||||
@classmethod
|
||||
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
only = ("id", "hyperparams")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
task.id: {"hyperparams": cls._get_params_list(items=task.hyperparams)}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_params_list(
|
||||
cls, items: Dict[str, Dict[str, ParamsItem]]
|
||||
) -> Sequence[dict]:
|
||||
ret = list(chain.from_iterable(v.values() for v in items.values()))
|
||||
return [
|
||||
p.to_proper_dict() for p in sorted(ret, key=attrgetter("section", "name"))
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _normalize_params(cls, params: Sequence) -> bool:
|
||||
"""
|
||||
Lower case properties section and return True if it is the only section
|
||||
"""
|
||||
for p in params:
|
||||
if p.section.lower() == cls._properties_section:
|
||||
p.section = cls._properties_section
|
||||
|
||||
return all(p.section == cls._properties_section for p in params)
|
||||
|
||||
@classmethod
|
||||
def delete_params(
|
||||
cls, company_id: str, task_id: str, hyperparams=Sequence[HyperParamKey]
|
||||
) -> int:
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = cls._get_task_for_update(
|
||||
company=company_id, id=task_id, allow_all_statuses=properties_only
|
||||
)
|
||||
|
||||
with_param, without_param = iterutils.partition(
|
||||
hyperparams, key=lambda p: bool(p.name)
|
||||
)
|
||||
sections_to_delete = {p.section for p in without_param}
|
||||
delete_cmds = {
|
||||
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
|
||||
for section in sections_to_delete
|
||||
}
|
||||
|
||||
for item in with_param:
|
||||
section = ParameterKeyEscaper.escape(item.section)
|
||||
if item.section in sections_to_delete:
|
||||
raise errors.bad_request.FieldsConflict(
|
||||
"Cannot delete section field if the whole section was scheduled for deletion"
|
||||
)
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def edit_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamItem],
|
||||
replace_hyperparams: str,
|
||||
) -> int:
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = cls._get_task_for_update(
|
||||
company=company_id, id=task_id, allow_all_statuses=properties_only
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
hyperparams = cls._db_dicts_from_list(hyperparams)
|
||||
if replace_hyperparams == ReplaceHyperparams.all:
|
||||
update_cmds["set__hyperparams"] = hyperparams
|
||||
elif replace_hyperparams == ReplaceHyperparams.section:
|
||||
for section, value in hyperparams.items():
|
||||
update_cmds[f"set__hyperparams__{section}"] = value
|
||||
else:
|
||||
for section, section_params in hyperparams.items():
|
||||
for name, value in section_params.items():
|
||||
update_cmds[f"set__hyperparams__{section}__{name}"] = value
|
||||
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
|
||||
sections = iterutils.bucketize(items, key=attrgetter("section"))
|
||||
return {
|
||||
ParameterKeyEscaper.escape(section): {
|
||||
ParameterKeyEscaper.escape(param.name): ParamsItem(**param.to_struct())
|
||||
for param in params
|
||||
}
|
||||
for section, params in sections.items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_configurations(
|
||||
cls, company_id: str, task_ids: Sequence[str], names: Sequence[str]
|
||||
) -> Dict[str, dict]:
|
||||
only = ["id"]
|
||||
if names:
|
||||
only.extend(
|
||||
f"configuration.{ParameterKeyEscaper.escape(name)}" for name in names
|
||||
)
|
||||
else:
|
||||
only.append("configuration")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
task.id: {
|
||||
"configuration": [
|
||||
c.to_proper_dict()
|
||||
for c in sorted(task.configuration.values(), key=attrgetter("name"))
|
||||
]
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_configuration_names(
|
||||
cls, company_id: str, task_ids: Sequence[str]
|
||||
) -> Dict[str, list]:
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"_id": {"$in": task_ids},
|
||||
}
|
||||
},
|
||||
{"$project": {"items": {"$objectToArray": "$configuration"}}},
|
||||
{"$unwind": "$items"},
|
||||
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
|
||||
]
|
||||
|
||||
tasks = Task.aggregate(pipeline)
|
||||
|
||||
return {
|
||||
task["_id"]: {
|
||||
"names": sorted(
|
||||
ParameterKeyEscaper.unescape(name) for name in task["names"]
|
||||
)
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_configuration(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
configuration: Sequence[Configuration],
|
||||
replace_configuration: bool,
|
||||
) -> int:
|
||||
task = cls._get_task_for_update(company=company_id, id=task_id)
|
||||
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
|
||||
for c in configuration
|
||||
}
|
||||
if replace_configuration:
|
||||
update_cmds["set__configuration"] = configuration
|
||||
else:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{name}"] = value
|
||||
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def delete_configuration(
|
||||
cls, company_id: str, task_id: str, configuration=Sequence[str]
|
||||
) -> int:
|
||||
task = cls._get_task_for_update(company=company_id, id=task_id)
|
||||
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@staticmethod
|
||||
def _get_task_for_update(
|
||||
company: str, id: str, allow_all_statuses: bool = False
|
||||
) -> Task:
|
||||
task = Task.get_for_writing(company=company, id=id, _only=("id", "status"))
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=id)
|
||||
|
||||
if allow_all_statuses:
|
||||
return task
|
||||
|
||||
if task.status != TaskStatus.created:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
expected=TaskStatus.created, status=task.status
|
||||
)
|
||||
return task
|
||||
201
server/bll/task/param_utils.py
Normal file
201
server/bll/task/param_utils.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import itertools
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import dpath
|
||||
|
||||
from apierrors import errors
|
||||
from database.model.task.task import Task
|
||||
from tools import safe_get
|
||||
from utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
|
||||
hyperparams_default_section = "Args"
|
||||
hyperparams_legacy_type = "legacy"
|
||||
tf_define_section = "TF_DEFINE"
|
||||
|
||||
|
||||
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Return parameter section and name. The section is either TF_DEFINE or the default one
|
||||
"""
|
||||
if default_section is None:
|
||||
return None, full_name
|
||||
|
||||
section, _, name = full_name.partition("/")
|
||||
if section != tf_define_section:
|
||||
return default_section, full_name
|
||||
|
||||
if not name:
|
||||
raise errors.bad_request.ValidationError("Parameter name cannot be empty")
|
||||
return section, name
|
||||
|
||||
|
||||
def _get_full_param_name(param: dict) -> str:
|
||||
section = param.get("section")
|
||||
if section != tf_define_section:
|
||||
return param["name"]
|
||||
|
||||
return "/".join((section, param["name"]))
|
||||
|
||||
|
||||
def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
"""
|
||||
removed = 0
|
||||
if not data:
|
||||
return removed
|
||||
|
||||
if with_sections:
|
||||
for section, section_data in list(data.items()):
|
||||
removed += _remove_legacy_params(section_data)
|
||||
if not section_data:
|
||||
"""If section is empty after removing legacy params then delete it"""
|
||||
del data[section]
|
||||
else:
|
||||
for key, param in list(data.items()):
|
||||
if param.get("type") == hyperparams_legacy_type:
|
||||
removed += 1
|
||||
del data[key]
|
||||
|
||||
return removed
|
||||
|
||||
|
||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
"""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
if with_sections:
|
||||
return itertools.chain.from_iterable(
|
||||
_get_legacy_params(section_data) for section_data in data.values()
|
||||
)
|
||||
|
||||
return [
|
||||
param for param in data.values() if param.get("type") == hyperparams_legacy_type
|
||||
]
|
||||
|
||||
|
||||
def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
"""
|
||||
If legacy hyper params or configuration is passed then replace the corresponding section in the new structure
|
||||
Escape all the section and param names for hyper params and configuration to make it mongo sage
|
||||
"""
|
||||
for old_params_field, new_params_field, default_section in (
|
||||
("execution/parameters", "hyperparams", hyperparams_default_section),
|
||||
("execution/model_desc", "configuration", None),
|
||||
):
|
||||
legacy_params = safe_get(fields, old_params_field)
|
||||
if legacy_params is None:
|
||||
continue
|
||||
|
||||
if (
|
||||
not safe_get(fields, new_params_field)
|
||||
and previous_task
|
||||
and previous_task[new_params_field]
|
||||
):
|
||||
previous_data = previous_task.to_proper_dict().get(new_params_field)
|
||||
removed = _remove_legacy_params(
|
||||
previous_data, with_sections=default_section is not None
|
||||
)
|
||||
if not legacy_params and not removed:
|
||||
# if we only need to delete legacy fields from the db
|
||||
# but they are not there then there is no point to proceed
|
||||
continue
|
||||
|
||||
fields_update = {new_params_field: previous_data}
|
||||
params_unprepare_from_saved(fields_update)
|
||||
fields.update(fields_update)
|
||||
|
||||
for full_name, value in legacy_params.items():
|
||||
section, name = split_param_name(full_name, default_section)
|
||||
new_path = list(filter(None, (new_params_field, section, name)))
|
||||
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
dpath.new(fields, new_path, new_param)
|
||||
dpath.delete(fields, old_params_field)
|
||||
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(key): {
|
||||
ParameterKeyEscaper.escape(k): v for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, escaped_params)
|
||||
|
||||
|
||||
def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
"""
|
||||
Unescape all section and param names for hyper params and configuration
|
||||
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
|
||||
"""
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
if params:
|
||||
unescaped_params = {
|
||||
ParameterKeyEscaper.unescape(key): {
|
||||
ParameterKeyEscaper.unescape(k): v for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, unescaped_params)
|
||||
|
||||
if copy_to_legacy:
|
||||
for new_params_field, old_params_field, use_sections in (
|
||||
(f"hyperparams", "execution/parameters", True),
|
||||
(f"configuration", "execution/model_desc", False),
|
||||
):
|
||||
legacy_params = _get_legacy_params(
|
||||
safe_get(fields, new_params_field), with_sections=use_sections
|
||||
)
|
||||
if legacy_params:
|
||||
dpath.new(
|
||||
fields,
|
||||
old_params_field,
|
||||
{_get_full_param_name(p): p["value"] for p in legacy_params},
|
||||
)
|
||||
|
||||
|
||||
def _process_path(path: str):
|
||||
"""
|
||||
Frontend does a partial escaping on the path so the all '.' in section and key names are escaped
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
raise errors.bad_request.ValidationError("invalid task field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
)
|
||||
|
||||
|
||||
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
|
||||
for old_prefix, new_prefix in (
|
||||
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
|
||||
("execution.model_desc", f"configuration"),
|
||||
):
|
||||
path: str
|
||||
paths = [path.replace(old_prefix, new_prefix) for path in paths]
|
||||
|
||||
for prefix in (
|
||||
"hyperparams.",
|
||||
"-hyperparams.",
|
||||
"configuration.",
|
||||
"-configuration.",
|
||||
):
|
||||
paths = [
|
||||
_process_path(path) if path.startswith(prefix) else path for path in paths
|
||||
]
|
||||
return paths
|
||||
@@ -5,6 +5,7 @@ from random import random
|
||||
from time import sleep
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
|
||||
|
||||
import dpath
|
||||
import pymongo.results
|
||||
import six
|
||||
from mongoengine import Q
|
||||
@@ -32,10 +33,11 @@ from database.model.task.task import (
|
||||
)
|
||||
from database.utils import get_company_or_none_constraint, id as create_id
|
||||
from service_repo import APICall
|
||||
from services.utils import validate_tags
|
||||
from timing_context import TimingContext
|
||||
from utilities.dicts import deep_merge
|
||||
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
|
||||
from utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import ChangeStatusRequest, validate_status_change
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
@@ -83,25 +85,24 @@ class TaskBLL(object):
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(
|
||||
company_id,
|
||||
task_id,
|
||||
required_status=None,
|
||||
required_dataset=None,
|
||||
only_fields=None,
|
||||
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
|
||||
):
|
||||
if only_fields:
|
||||
if isinstance(only_fields, string_types):
|
||||
only_fields = [only_fields]
|
||||
else:
|
||||
only_fields = list(only_fields)
|
||||
only_fields = only_fields + ["status"]
|
||||
|
||||
with TimingContext("mongo", "task_by_id_all"):
|
||||
qs = Task.objects(id=task_id, company=company_id)
|
||||
if only_fields:
|
||||
qs = (
|
||||
qs.only(only_fields)
|
||||
if isinstance(only_fields, string_types)
|
||||
else qs.only(*only_fields)
|
||||
)
|
||||
qs = qs.only(
|
||||
"status", "input"
|
||||
) # make sure all fields we rely on here are also returned
|
||||
task = qs.first()
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id=task_id),
|
||||
allow_public=allow_public,
|
||||
override_projection=only_fields,
|
||||
return_dicts=False,
|
||||
)
|
||||
task = None if not tasks else tasks[0]
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
@@ -109,17 +110,12 @@ class TaskBLL(object):
|
||||
if required_status and not task.status == required_status:
|
||||
raise errors.bad_request.InvalidTaskStatus(expected=required_status)
|
||||
|
||||
if required_dataset and required_dataset not in (
|
||||
entry.dataset for entry in task.input.view.entries
|
||||
):
|
||||
raise errors.bad_request.InvalidId(
|
||||
"not in input view", dataset=required_dataset
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def assert_exists(company_id, task_ids, only=None, allow_public=False):
|
||||
def assert_exists(
|
||||
company_id, task_ids, only=None, allow_public=False, return_tasks=True
|
||||
) -> Optional[Sequence[Task]]:
|
||||
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
|
||||
with translate_errors_context(), TimingContext("mongo", "task_exists"):
|
||||
ids = set(task_ids)
|
||||
@@ -130,14 +126,13 @@ class TaskBLL(object):
|
||||
return_dicts=False,
|
||||
)
|
||||
if only:
|
||||
res = q.only(*only)
|
||||
count = len(res)
|
||||
else:
|
||||
count = q.count()
|
||||
res = q.first()
|
||||
if count != len(ids):
|
||||
q = q.only(*only)
|
||||
|
||||
if q.count() != len(ids):
|
||||
raise errors.bad_request.InvalidTaskId(ids=task_ids)
|
||||
return res
|
||||
|
||||
if return_tasks:
|
||||
return list(q)
|
||||
|
||||
@staticmethod
|
||||
def create(call: APICall, fields: dict):
|
||||
@@ -179,21 +174,31 @@ class TaskBLL(object):
|
||||
project: Optional[str] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
hyperparams: Optional[dict] = None,
|
||||
configuration: Optional[dict] = None,
|
||||
execution_overrides: Optional[dict] = None,
|
||||
validate_references: bool = False,
|
||||
) -> Task:
|
||||
validate_tags(tags, system_tags)
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id)
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
execution_model_overriden = False
|
||||
params_dict = {
|
||||
field: value
|
||||
for field, value in (
|
||||
("hyperparams", hyperparams),
|
||||
("configuration", configuration),
|
||||
)
|
||||
if value is not None
|
||||
}
|
||||
if execution_overrides:
|
||||
parameters = execution_overrides.get("parameters")
|
||||
if parameters is not None:
|
||||
execution_overrides["parameters"] = {
|
||||
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
|
||||
}
|
||||
params_dict["execution"] = {}
|
||||
for legacy_param in ("parameters", "configuration"):
|
||||
legacy_value = execution_overrides.pop(legacy_param, None)
|
||||
if legacy_value is not None:
|
||||
params_dict["execution"] = legacy_value
|
||||
execution_dict = deep_merge(execution_dict, execution_overrides)
|
||||
execution_model_overriden = execution_overrides.get("model") is not None
|
||||
params_prepare_for_save(params_dict, previous_task=task)
|
||||
|
||||
artifacts = execution_dict.get("artifacts")
|
||||
if artifacts:
|
||||
@@ -221,6 +226,8 @@ class TaskBLL(object):
|
||||
if task.output
|
||||
else None,
|
||||
execution=execution_dict,
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
@@ -626,28 +633,34 @@ class TaskBLL(object):
|
||||
return [a.key for a in added], [a.key for a in updated]
|
||||
|
||||
@staticmethod
|
||||
def get_aggregated_project_execution_parameters(
|
||||
def get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids: Sequence[str] = None,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[str]]:
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": company_id,
|
||||
"execution.parameters": {"$exists": True, "$gt": {}},
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
**({"project": {"$in": project_ids}} if project_ids else {}),
|
||||
}
|
||||
},
|
||||
{"$project": {"parameters": {"$objectToArray": "$execution.parameters"}}},
|
||||
{"$unwind": "$parameters"},
|
||||
{"$group": {"_id": "$parameters.k"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
@@ -673,7 +686,12 @@ class TaskBLL(object):
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
ParameterKeyEscaper.unescape(r["_id"])
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
dpath.get(r, "_id/section")
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import TypeVar, Callable, Tuple, Sequence
|
||||
|
||||
import attr
|
||||
import six
|
||||
from boltons.dictutils import OneToOne
|
||||
|
||||
from apierrors import errors
|
||||
from database.errors import translate_errors_context
|
||||
@@ -172,26 +171,3 @@ def split_by(
|
||||
[item for cond, item in applied if cond],
|
||||
[item for cond, item in applied if not cond],
|
||||
)
|
||||
|
||||
|
||||
class ParameterKeyEscaper:
|
||||
_mapping = OneToOne({".": "%2E", "$": "%24"})
|
||||
|
||||
@classmethod
|
||||
def escape(cls, value):
|
||||
""" Quote a parameter key """
|
||||
value = value.strip().replace("%", "%%")
|
||||
for c, r in cls._mapping.items():
|
||||
value = value.replace(c, r)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _unescape(cls, value):
|
||||
for c, r in cls._mapping.inv.items():
|
||||
value = value.replace(c, r)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def unescape(cls, value):
|
||||
""" Unquote a quoted parameter key """
|
||||
return "%".join(map(cls._unescape, value.split("%%")))
|
||||
|
||||
@@ -35,14 +35,21 @@ class SetFieldsResolver:
|
||||
SET_MODIFIERS = ("min", "max")
|
||||
|
||||
def __init__(self, set_fields: Dict[str, Any]):
|
||||
self.orig_fields = set_fields
|
||||
self.fields = {
|
||||
f: fname
|
||||
for f, modifier, dunder, fname in (
|
||||
(f,) + f.partition("__") for f in set_fields.keys()
|
||||
)
|
||||
if dunder and modifier in self.SET_MODIFIERS
|
||||
}
|
||||
self.orig_fields = {}
|
||||
self.fields = {}
|
||||
self.add_fields(**set_fields)
|
||||
|
||||
def add_fields(self, **set_fields: Any):
|
||||
self.orig_fields.update(set_fields)
|
||||
self.fields.update(
|
||||
{
|
||||
f: fname
|
||||
for f, modifier, dunder, fname in (
|
||||
(f,) + f.partition("__") for f in set_fields.keys()
|
||||
)
|
||||
if dunder and modifier in self.SET_MODIFIERS
|
||||
}
|
||||
)
|
||||
|
||||
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
|
||||
if name in self.fields and doc.get_field_value(self.fields[name]) is None:
|
||||
|
||||
@@ -21,6 +21,7 @@ from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.auth import User
|
||||
from database.model.company import Company
|
||||
from database.model.project import Project
|
||||
from database.model.queue import Queue
|
||||
from database.model.task.task import Task
|
||||
from redis_manager import redman
|
||||
@@ -49,6 +50,7 @@ class WorkerBLL:
|
||||
ip: str = "",
|
||||
queues: Sequence[str] = None,
|
||||
timeout: int = 0,
|
||||
tags: Sequence[str] = None,
|
||||
) -> WorkerEntry:
|
||||
"""
|
||||
Register a worker
|
||||
@@ -58,6 +60,7 @@ class WorkerBLL:
|
||||
:param ip: the real ip of the worker
|
||||
:param queues: queues reported as being monitored by the worker
|
||||
:param timeout: registration expiration timeout in seconds
|
||||
:param tags: a list of tags for this worker
|
||||
:raise bad_request.InvalidUserId: in case the calling user or company does not exist
|
||||
:return: worker entry instance
|
||||
"""
|
||||
@@ -91,6 +94,7 @@ class WorkerBLL:
|
||||
register_time=now,
|
||||
register_timeout=timeout,
|
||||
last_activity_time=now,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
|
||||
@@ -113,12 +117,15 @@ class WorkerBLL:
|
||||
raise bad_request.WorkerNotRegistered(worker=worker)
|
||||
|
||||
def status_report(
|
||||
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest
|
||||
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Write worker status report
|
||||
:param company_id: worker's company ID
|
||||
:param user_id: user_id ID under which this worker is running
|
||||
:param ip: worker IP
|
||||
:param report: the report itself
|
||||
:param tags: tags for this worker
|
||||
:raise bad_request.InvalidTaskId: the reported task was not found
|
||||
:return: worker entry instance
|
||||
"""
|
||||
@@ -129,6 +136,9 @@ class WorkerBLL:
|
||||
now = datetime.utcnow()
|
||||
entry.last_activity_time = now
|
||||
|
||||
if tags is not None:
|
||||
entry.tags = tags
|
||||
|
||||
if report.machine_stats:
|
||||
self._log_stats_to_es(
|
||||
company_id=company_id,
|
||||
@@ -146,6 +156,7 @@ class WorkerBLL:
|
||||
|
||||
if not report.task:
|
||||
entry.task = None
|
||||
entry.project = None
|
||||
else:
|
||||
with translate_errors_context():
|
||||
query = dict(id=report.task, company=company_id)
|
||||
@@ -160,6 +171,12 @@ class WorkerBLL:
|
||||
raise bad_request.InvalidTaskId(**query)
|
||||
entry.task = IdNameEntry(id=task.id, name=task.name)
|
||||
|
||||
entry.project = None
|
||||
if task.project:
|
||||
project = Project.objects(id=task.project).only("name").first()
|
||||
if project:
|
||||
entry.project = IdNameEntry(id=project.id, name=project.name)
|
||||
|
||||
entry.last_report_time = now
|
||||
except APIError:
|
||||
raise
|
||||
@@ -369,7 +386,6 @@ class WorkerBLL:
|
||||
def make_doc(category, metric, variant, value) -> dict:
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_type="stat",
|
||||
_source=dict(
|
||||
timestamp=timestamp,
|
||||
worker=worker,
|
||||
|
||||
@@ -25,7 +25,6 @@ class WorkerStats:
|
||||
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self.worker_stats_prefix_for_company(company_id)}*",
|
||||
doc_type="stat",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@@ -53,7 +52,7 @@ class WorkerStats:
|
||||
|
||||
res = self._search_company_stats(company_id, es_req)
|
||||
|
||||
if not res["hits"]["total"]:
|
||||
if not res["hits"]["total"]["value"]:
|
||||
raise bad_request.WorkerStatsNotFound(
|
||||
f"No statistic metrics found for the company {company_id} and workers {worker_ids}"
|
||||
)
|
||||
@@ -87,7 +86,7 @@ class WorkerStats:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{request.interval}s",
|
||||
"fixed_interval": f"{request.interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
@@ -216,7 +215,7 @@ class WorkerStats:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{interval}s",
|
||||
"fixed_interval": f"{interval}s",
|
||||
},
|
||||
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
|
||||
}
|
||||
|
||||
@@ -26,6 +26,17 @@
|
||||
check_max_version: false
|
||||
}
|
||||
|
||||
pre_populate {
|
||||
enabled: false
|
||||
zip_files: ["/path/to/export.zip"]
|
||||
fail_on_error: false
|
||||
# artifacts_path: "/mnt/fileserver"
|
||||
}
|
||||
|
||||
# time in seconds to take an exclusive lock to init es and mongodb
|
||||
# not including the pre_populate
|
||||
db_init_timout: 120
|
||||
|
||||
mongo {
|
||||
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
|
||||
# but not declared in a data model
|
||||
@@ -34,11 +45,16 @@
|
||||
aggregate {
|
||||
allow_disk_use: true
|
||||
}
|
||||
}
|
||||
|
||||
pre_populate {
|
||||
enabled: false
|
||||
zip_file: "/path/to/export.zip"
|
||||
fail_on_error: false
|
||||
elastic {
|
||||
probing {
|
||||
# settings for inital probing of elastic connection
|
||||
max_retries: 4
|
||||
timeout: 30
|
||||
}
|
||||
upgrade_monitoring {
|
||||
v16_migration_verification: true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
elastic {
|
||||
events {
|
||||
hosts: [{host: "127.0.0.1", port: 9200}]
|
||||
hosts: [{host: "127.0.0.1", port: 9211}]
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
max_retries: 5
|
||||
max_retries: 3
|
||||
retry_on_timeout: true
|
||||
}
|
||||
index_version: "1"
|
||||
}
|
||||
|
||||
workers {
|
||||
hosts: [{host:"127.0.0.1", port:9200}]
|
||||
hosts: [{host:"127.0.0.1", port:9211}]
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
max_retries: 5
|
||||
max_retries: 3
|
||||
retry_on_timeout: true
|
||||
}
|
||||
index_version: "1"
|
||||
|
||||
16
server/config/default/services/auth.conf
Normal file
16
server/config/default/services/auth.conf
Normal file
@@ -0,0 +1,16 @@
|
||||
fixed_users {
|
||||
guest {
|
||||
enabled: false
|
||||
|
||||
default_company: "025315a9321f49f8be07f5ac48fbcf92"
|
||||
|
||||
name: "Guest"
|
||||
username: "guest"
|
||||
password: "guest"
|
||||
|
||||
# Allow access only to the following endpoints when using user/pass credentials
|
||||
allow_endpoints: [
|
||||
"auth.login"
|
||||
]
|
||||
}
|
||||
}
|
||||
8
server/config/default/services/projects.conf
Normal file
8
server/config/default/services/projects.conf
Normal file
@@ -0,0 +1,8 @@
|
||||
# Order of featured projects, by name or ID
|
||||
featured_order: [
|
||||
# {id: "<project-id>"}
|
||||
# OR
|
||||
# {name: "<project-name>"}
|
||||
# OR
|
||||
# {name_regex: "<python-regex>"}
|
||||
]
|
||||
@@ -11,4 +11,6 @@ non_responsive_tasks_watchdog {
|
||||
artifacts {
|
||||
update_attempts: 10
|
||||
update_retry_msec: 500
|
||||
}
|
||||
}
|
||||
|
||||
multi_task_histogram_limit: 100
|
||||
|
||||
@@ -41,3 +41,7 @@ def get_deployment_type() -> str:
|
||||
|
||||
def get_default_company():
|
||||
return config.get("apiserver.default_company")
|
||||
|
||||
|
||||
missed_es_upgrade = False
|
||||
es_connection_error = False
|
||||
|
||||
@@ -79,6 +79,10 @@ def get_entries():
|
||||
return _entries
|
||||
|
||||
|
||||
def get_hosts():
|
||||
return [entry.host for entry in get_entries()]
|
||||
|
||||
|
||||
def get_aliases():
|
||||
return [entry.alias for entry in get_entries()]
|
||||
|
||||
|
||||
@@ -32,6 +32,8 @@ class Role(object):
|
||||
""" Company user """
|
||||
annotator = "annotator"
|
||||
""" Annotator with limited access"""
|
||||
guest = "guest"
|
||||
""" Guest user. Read Only."""
|
||||
|
||||
@classmethod
|
||||
def get_system_roles(cls) -> set:
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple
|
||||
|
||||
from boltons.iterutils import first, bucketize
|
||||
from boltons.iterutils import first, bucketize, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apierrors import errors
|
||||
from apierrors.base import BaseError
|
||||
from config import config
|
||||
from database.errors import MakeGetAllQueryError
|
||||
from database.projection import project_dict, ProjectionHelper
|
||||
@@ -347,6 +348,20 @@ class GetMixin(PropsMixin):
|
||||
return []
|
||||
return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
|
||||
|
||||
@classmethod
|
||||
def split_projection(
|
||||
cls, projection: Sequence[str]
|
||||
) -> Tuple[Collection[str], Collection[str]]:
|
||||
"""Return include and exclude lists based on passed projection and class definition"""
|
||||
if projection:
|
||||
include, exclude = partition(
|
||||
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
|
||||
)
|
||||
else:
|
||||
include, exclude = [], []
|
||||
exclude = {x.lstrip(ProjectionHelper.exclusion_prefix) for x in exclude}
|
||||
return include, set(cls.get_exclude_fields()).union(exclude).difference(include)
|
||||
|
||||
@classmethod
|
||||
def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
|
||||
parameters.pop("only_fields", None)
|
||||
@@ -483,10 +498,25 @@ class GetMixin(PropsMixin):
|
||||
query=_query, parameters=parameters, override_projection=override_projection
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_many_public(
|
||||
cls, query: Q = None, projection: Collection[str] = None,
|
||||
):
|
||||
"""
|
||||
Fetch all public documents matching a provided query.
|
||||
:param query: Optional query object (mongoengine.Q).
|
||||
:param projection: A list of projection fields.
|
||||
:return: A list of documents matching the query.
|
||||
"""
|
||||
q = get_company_or_none_constraint()
|
||||
_query = (q & query) if query else q
|
||||
|
||||
return cls._get_many_no_company(query=_query, override_projection=projection)
|
||||
|
||||
@classmethod
|
||||
def _get_many_no_company(
|
||||
cls: Union["GetMixin", Document],
|
||||
query,
|
||||
query: Q,
|
||||
parameters=None,
|
||||
override_projection=None,
|
||||
):
|
||||
@@ -509,7 +539,9 @@ class GetMixin(PropsMixin):
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
only = cls.get_projection(parameters, override_projection)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
qs = cls.objects(query)
|
||||
if search_text:
|
||||
@@ -517,13 +549,14 @@ class GetMixin(PropsMixin):
|
||||
if order_by:
|
||||
# add ordering
|
||||
qs = qs.order_by(*order_by)
|
||||
if only:
|
||||
|
||||
if include:
|
||||
# add projection
|
||||
qs = qs.only(*only)
|
||||
else:
|
||||
exclude = set(cls.get_exclude_fields()).difference(only)
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
qs = qs.only(*include)
|
||||
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
|
||||
if page is not None and page_size:
|
||||
# add paging
|
||||
qs = qs.skip(page * page_size).limit(page_size)
|
||||
@@ -559,7 +592,9 @@ class GetMixin(PropsMixin):
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
only = cls.get_projection(parameters, override_projection)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
query_sets = [cls.objects(query)]
|
||||
if order_by:
|
||||
@@ -596,16 +631,15 @@ class GetMixin(PropsMixin):
|
||||
if search_text:
|
||||
query_sets = [qs.search_text(search_text) for qs in query_sets]
|
||||
|
||||
if only:
|
||||
if include:
|
||||
# add projection
|
||||
query_sets = [qs.only(*only) for qs in query_sets]
|
||||
else:
|
||||
exclude = set(cls.get_exclude_fields())
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
query_sets = [qs.only(*include) for qs in query_sets]
|
||||
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
|
||||
if page is None or not page_size:
|
||||
return [obj.to_proper_dict(only=only) for qs in query_sets for obj in qs]
|
||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||
|
||||
# add paging
|
||||
ret = []
|
||||
@@ -616,7 +650,8 @@ class GetMixin(PropsMixin):
|
||||
start -= qs_size
|
||||
continue
|
||||
ret.extend(
|
||||
obj.to_proper_dict(only=only) for obj in qs.skip(start).limit(page_size)
|
||||
obj.to_proper_dict(only=include)
|
||||
for obj in qs.skip(start).limit(page_size)
|
||||
)
|
||||
if len(ret) >= page_size:
|
||||
break
|
||||
@@ -728,6 +763,31 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
)
|
||||
return cls.objects.aggregate(pipeline, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def set_public(
|
||||
cls: Type[Document],
|
||||
company_id: str,
|
||||
ids: Sequence[str],
|
||||
invalid_cls: Type[BaseError],
|
||||
enabled: bool = True,
|
||||
):
|
||||
if enabled:
|
||||
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
|
||||
update = dict(set__company_origin=company_id, unset__company=1)
|
||||
else:
|
||||
items = list(
|
||||
cls.objects(
|
||||
id__in=ids, company__in=(None, ""), company_origin=company_id
|
||||
).only("id")
|
||||
)
|
||||
update = dict(set__company=company_id, unset__company_origin=1)
|
||||
|
||||
if len(items) < len(ids):
|
||||
missing = tuple(set(ids).difference(i.id for i in items))
|
||||
raise invalid_cls(ids=missing)
|
||||
|
||||
return {"updated": cls.objects(id__in=ids).update(**update)}
|
||||
|
||||
|
||||
def validate_id(cls, company, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -19,6 +19,7 @@ class Model(DbModelMixin, Document):
|
||||
"parent",
|
||||
"project",
|
||||
"task",
|
||||
("company", "framework"),
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
{
|
||||
@@ -71,3 +72,4 @@ class Model(DbModelMixin, Document):
|
||||
ui_cache = SafeDictField(
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from mongoengine import StringField, DateTimeField
|
||||
from mongoengine import StringField, DateTimeField, IntField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
@@ -40,3 +40,7 @@ class Project(AttributedDocument):
|
||||
system_tags = SafeSortedListField(StringField(required=True))
|
||||
default_output_destination = StrippedStringField()
|
||||
last_update = DateTimeField()
|
||||
featured = IntField(default=9999)
|
||||
logo_url = StringField()
|
||||
logo_blob = StringField(exclude_by_default=True)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
|
||||
@@ -49,13 +49,13 @@ class TaskSystemTags(object):
|
||||
development = "development"
|
||||
|
||||
|
||||
class Script(EmbeddedDocument):
|
||||
class Script(EmbeddedDocument, ProperDictMixin):
|
||||
binary = StringField(default="python")
|
||||
repository = StringField(required=True)
|
||||
repository = StringField(default="")
|
||||
tag = StringField()
|
||||
branch = StringField()
|
||||
version_num = StringField()
|
||||
entry_point = StringField(required=True)
|
||||
entry_point = StringField(default="")
|
||||
working_dir = StringField()
|
||||
requirements = SafeDictField()
|
||||
diff = StringField()
|
||||
@@ -84,7 +84,23 @@ class Artifact(EmbeddedDocument):
|
||||
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
|
||||
|
||||
|
||||
class ParamsItem(EmbeddedDocument, ProperDictMixin):
|
||||
section = StringField(required=True)
|
||||
name = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
|
||||
name = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class Execution(EmbeddedDocument, ProperDictMixin):
|
||||
meta = {"strict": strict}
|
||||
test_split = IntField(default=0)
|
||||
parameters = SafeDictField(default=dict)
|
||||
model = StringField(reference_field="Model")
|
||||
@@ -115,9 +131,12 @@ external_task_types = set(get_options(TaskType))
|
||||
|
||||
|
||||
class Task(AttributedDocument):
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {
|
||||
"execution.parameters.": {"locale": "en_US", "numericOrdering": True},
|
||||
"last_metrics.": {"locale": "en_US", "numericOrdering": True}
|
||||
"execution.parameters.": _numeric_locale,
|
||||
"last_metrics.": _numeric_locale,
|
||||
"hyperparams.": _numeric_locale,
|
||||
"configuration.": _numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
@@ -186,10 +205,15 @@ class Task(AttributedDocument):
|
||||
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
script: Script = EmbeddedDocumentField(Script)
|
||||
script: Script = EmbeddedDocumentField(Script, default=Script)
|
||||
last_worker = StringField()
|
||||
last_worker_report = DateTimeField()
|
||||
last_update = DateTimeField()
|
||||
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
duration = IntField() # task duration in seconds
|
||||
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
|
||||
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
|
||||
runtime = SafeDictField(default=dict)
|
||||
|
||||
@@ -45,7 +45,7 @@ def project_dict(data, projection, separator=SEP):
|
||||
)
|
||||
|
||||
dst[path_part] = [
|
||||
copy_path(path_parts[depth + 1:], s, d)
|
||||
copy_path(path_parts[depth + 1 :], s, d)
|
||||
for s, d in zip(src_part, dst[path_part])
|
||||
]
|
||||
|
||||
@@ -96,6 +96,7 @@ class _ProxyManager:
|
||||
|
||||
class ProjectionHelper(object):
|
||||
pool = ThreadPoolExecutor()
|
||||
exclusion_prefix = "-"
|
||||
|
||||
@property
|
||||
def doc_projection(self):
|
||||
@@ -128,20 +129,28 @@ class ProjectionHelper(object):
|
||||
[]
|
||||
) # Projection information for reference fields (used in join queries)
|
||||
for field in projection:
|
||||
field_ = field.lstrip(self.exclusion_prefix)
|
||||
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
|
||||
if not field.startswith(ref_field):
|
||||
if not field_.startswith(ref_field):
|
||||
# Doesn't start with a reference field
|
||||
continue
|
||||
if field == ref_field:
|
||||
if field_ == ref_field:
|
||||
# Field is exactly a reference field. In this case we won't perform any inner projection (for that,
|
||||
# use '<reference field name>.*')
|
||||
continue
|
||||
subfield = field[len(ref_field):]
|
||||
subfield = field_[len(ref_field) :]
|
||||
if not subfield.startswith(SEP):
|
||||
# Starts with something that looks like a reference field, but isn't
|
||||
continue
|
||||
|
||||
ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
|
||||
ref_projection_info.append(
|
||||
(
|
||||
ref_field,
|
||||
ref_field_cls,
|
||||
("" if field_[0] == field[0] else self.exclusion_prefix)
|
||||
+ subfield[1:],
|
||||
)
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Not a reference field, just add to the top-level projection
|
||||
@@ -149,7 +158,7 @@ class ProjectionHelper(object):
|
||||
orig_field = field
|
||||
if field.endswith(".*"):
|
||||
field = field[:-2]
|
||||
if not field:
|
||||
if not field.lstrip(self.exclusion_prefix):
|
||||
raise errors.bad_request.InvalidFields(
|
||||
field=orig_field, object=doc_cls.__name__
|
||||
)
|
||||
@@ -199,7 +208,7 @@ class ProjectionHelper(object):
|
||||
# Make sure this doesn't contain any reference field we'll join anyway
|
||||
# (i.e. in case only_fields=[project, project.name])
|
||||
doc_projection = normalize_cls_projection(
|
||||
doc_cls, doc_projection.difference(ref_projection).union({"id"})
|
||||
doc_cls, doc_projection.difference(ref_projection)
|
||||
)
|
||||
|
||||
# Make sure that in case one or more field is a subfield of another field, we only use the the top-level field.
|
||||
@@ -218,7 +227,10 @@ class ProjectionHelper(object):
|
||||
|
||||
# Make sure we didn't get any invalid projection fields for this class
|
||||
invalid_fields = [
|
||||
f for f in doc_projection if f.split(SEP)[0] not in doc_cls.get_fields()
|
||||
f
|
||||
for f in doc_projection
|
||||
if f.partition(SEP)[0].lstrip(self.exclusion_prefix)
|
||||
not in doc_cls.get_fields()
|
||||
]
|
||||
if invalid_fields:
|
||||
raise errors.bad_request.InvalidFields(
|
||||
@@ -234,6 +246,13 @@ class ProjectionHelper(object):
|
||||
doc_projection.add(field)
|
||||
doc_projection = list(doc_projection)
|
||||
|
||||
# If there are include fields (not only exclude) then add an id field
|
||||
if (
|
||||
not all(p.startswith(self.exclusion_prefix) for p in doc_projection)
|
||||
and "id" not in doc_projection
|
||||
):
|
||||
doc_projection.append("id")
|
||||
|
||||
self._doc_projection = doc_projection
|
||||
self._ref_projection = ref_projection
|
||||
|
||||
@@ -314,6 +333,7 @@ class ProjectionHelper(object):
|
||||
]
|
||||
|
||||
if items:
|
||||
|
||||
def do_projection(item):
|
||||
ref_field_name, data, ids = item
|
||||
|
||||
|
||||
@@ -4,53 +4,54 @@ Apply elasticsearch mappings to given hosts.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
HERE = Path(__file__).resolve().parent
|
||||
|
||||
session = requests.Session()
|
||||
adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5))
|
||||
session.mount('http://', adapter)
|
||||
|
||||
def apply_mappings_to_cluster(
|
||||
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
|
||||
):
|
||||
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
|
||||
|
||||
def apply_mappings_to_host(host: str):
|
||||
def _send_mapping(f):
|
||||
def _send_template(f):
|
||||
with f.open() as json_data:
|
||||
data = json.load(json_data)
|
||||
es_server = host
|
||||
url = f"{es_server}/_template/{f.stem}"
|
||||
|
||||
session.delete(url)
|
||||
r = session.post(
|
||||
url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=json.dumps(data),
|
||||
)
|
||||
return {"mapping": f.stem, "result": r.text}
|
||||
template_name = f.stem
|
||||
res = es.indices.put_template(template_name, body=data)
|
||||
return {"mapping": template_name, "result": res}
|
||||
|
||||
p = HERE / "mappings"
|
||||
return [
|
||||
_send_mapping(f) for f in p.iterdir() if f.is_file() and f.suffix == ".json"
|
||||
]
|
||||
if key:
|
||||
files = (p / key).glob("*.json")
|
||||
else:
|
||||
files = p.glob("**/*.json")
|
||||
|
||||
es = Elasticsearch(hosts=hosts, **(es_args or {}))
|
||||
return [_send_template(f) for f in files]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
parser.add_argument("hosts", nargs="+")
|
||||
parser.add_argument("--key", help="host key, e.g. events, datasets etc.")
|
||||
parser.add_argument(
|
||||
"--hosts",
|
||||
nargs="+",
|
||||
help="list of es hosts from the same cluster, where each host is http[s]://[user:password@]host:port",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
for host in parse_args().hosts:
|
||||
print(">>>>> Applying mapping to " + host)
|
||||
res = apply_mappings_to_host(host)
|
||||
print(res)
|
||||
args = parse_args()
|
||||
print(">>>>> Applying mapping to " + str(args.hosts))
|
||||
res = apply_mappings_to_cluster(args.hosts, args.key)
|
||||
print(res)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
from furl import furl
|
||||
import logging
|
||||
from time import sleep
|
||||
from typing import Type, Optional, Sequence, Any, Union
|
||||
|
||||
import urllib3.exceptions
|
||||
from elasticsearch import Elasticsearch, exceptions
|
||||
|
||||
import es_factory
|
||||
from config import config
|
||||
from elastic.apply_mappings import apply_mappings_to_host
|
||||
from es_factory import get_cluster_config
|
||||
from elastic.apply_mappings import apply_mappings_to_cluster
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -15,13 +20,94 @@ class MissingElasticConfiguration(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def init_es_data():
|
||||
hosts_config = get_cluster_config("events").get("hosts")
|
||||
if not hosts_config:
|
||||
raise MissingElasticConfiguration("for cluster 'events'")
|
||||
class ElasticConnectionError(Exception):
|
||||
"""
|
||||
Exception when could not connect to elastic during init
|
||||
"""
|
||||
|
||||
for conf in hosts_config:
|
||||
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
|
||||
log.info(f"Applying mappings to host: {host}")
|
||||
res = apply_mappings_to_host(host)
|
||||
pass
|
||||
|
||||
|
||||
class ConnectionErrorFilter(logging.Filter):
|
||||
def __init__(
|
||||
self,
|
||||
level: Optional[Union[int, str]] = None,
|
||||
err_type: Optional[Type] = None,
|
||||
args_prefix: Optional[Sequence[Any]] = None,
|
||||
):
|
||||
super(ConnectionErrorFilter, self).__init__()
|
||||
if level is None:
|
||||
self.level = None
|
||||
else:
|
||||
try:
|
||||
self.level = int(level)
|
||||
except ValueError:
|
||||
self.level = logging.getLevelName(level)
|
||||
|
||||
self.err_type = err_type
|
||||
self.args = args_prefix and tuple(args_prefix)
|
||||
self.last_blocked = None
|
||||
|
||||
def filter(self, record):
|
||||
try:
|
||||
allow = (
|
||||
(self.err_type is None or record.exc_info[0] != self.err_type)
|
||||
and (self.level is None or record.levelno != self.level)
|
||||
and (self.args is None or record.args[: len(self.args)] != self.args)
|
||||
)
|
||||
if not allow:
|
||||
self.last_blocked = record
|
||||
return allow
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
def check_elastic_empty() -> bool:
|
||||
"""
|
||||
Check for elasticsearch connection
|
||||
Use probing settings and not the default es cluster ones
|
||||
so that we can handle correctly the connection rejects due to ES not fully started yet
|
||||
:return:
|
||||
"""
|
||||
cluster_conf = es_factory.get_cluster_config("events")
|
||||
max_retries = config.get("apiserver.elastic.probing.max_retries", 4)
|
||||
timeout = config.get("apiserver.elastic.probing.timeout", 30)
|
||||
|
||||
es_logger = logging.getLogger("elasticsearch")
|
||||
log_filter = ConnectionErrorFilter(
|
||||
err_type=urllib3.exceptions.NewConnectionError, args_prefix=("GET",)
|
||||
)
|
||||
|
||||
try:
|
||||
es_logger.addFilter(log_filter)
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
|
||||
return not es.indices.get_template(name="events*")
|
||||
except exceptions.NotFoundError as ex:
|
||||
log.error(ex)
|
||||
return True
|
||||
except exceptions.ConnectionError as ex:
|
||||
if retry >= max_retries - 1:
|
||||
raise ElasticConnectionError(
|
||||
f"Error connecting to Elasticsearch: {str(ex)}"
|
||||
)
|
||||
log.warn(
|
||||
f"Could not connect to ElasticSearch Service. Retry {retry+1} of {max_retries}. Waiting for {timeout}sec"
|
||||
)
|
||||
sleep(timeout)
|
||||
finally:
|
||||
es_logger.removeFilter(log_filter)
|
||||
|
||||
|
||||
def init_es_data():
|
||||
for name in es_factory.get_all_cluster_names():
|
||||
cluster_conf = es_factory.get_cluster_config(name)
|
||||
hosts_config = cluster_conf.get("hosts")
|
||||
if not hosts_config:
|
||||
raise MissingElasticConfiguration(f"for cluster '{name}'")
|
||||
|
||||
log.info(f"Applying mappings to ES host: {hosts_config}")
|
||||
args = cluster_conf.get("args", {})
|
||||
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
|
||||
log.info(res)
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
{
|
||||
"template": "events-*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"_routing": {
|
||||
"required": true
|
||||
},
|
||||
"properties": {
|
||||
"@timestamp": { "type": "date" },
|
||||
"task": { "type": "keyword" },
|
||||
"type": { "type": "keyword" },
|
||||
"worker": { "type": "keyword" },
|
||||
"timestamp": { "type": "date" },
|
||||
"iter": { "type": "long" },
|
||||
"metric": { "type": "keyword" },
|
||||
"variant": { "type": "keyword" },
|
||||
"value": { "type": "float" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
40
server/elastic/mappings/events/events.json
Normal file
40
server/elastic/mappings/events/events.json
Normal file
@@ -0,0 +1,40 @@
|
||||
{
|
||||
"index_patterns": "events-*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"@timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"task": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"type": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"worker": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"iter": {
|
||||
"type": "long"
|
||||
},
|
||||
"metric": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"variant": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
15
server/elastic/mappings/events/events_log.json
Normal file
15
server/elastic/mappings/events/events_log.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"index_patterns": "events-log-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"msg": {
|
||||
"type": "text",
|
||||
"index": false
|
||||
},
|
||||
"level": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
12
server/elastic/mappings/events/events_plot.json
Normal file
12
server/elastic/mappings/events/events_plot.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"index_patterns": "events-plot-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"plot_str": {
|
||||
"type": "text",
|
||||
"index": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"index_patterns": "events-training_debug_image-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"url": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
{
|
||||
"template": "events-log-*",
|
||||
"order" : 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"msg": { "type":"text", "index": false },
|
||||
"level": { "type":"keyword" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
{
|
||||
"template": "events-plot-*",
|
||||
"order" : 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"plot_str": { "type":"text", "index": false }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
{
|
||||
"template": "events-training_debug_image-*",
|
||||
"order" : 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"key": { "type": "keyword" },
|
||||
"url": { "type": "keyword" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
{
|
||||
"template": "queue_metrics_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"metrics": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"queue": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"average_waiting_time": {
|
||||
"type": "float"
|
||||
},
|
||||
"queue_length": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
{
|
||||
"template": "worker_stats_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"stat": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": { "type": "date" },
|
||||
"worker": { "type": "keyword" },
|
||||
"category": { "type": "keyword" },
|
||||
"metric": { "type": "keyword" },
|
||||
"variant": { "type": "keyword" },
|
||||
"value": { "type": "float" },
|
||||
"unit": { "type": "keyword" },
|
||||
"task": { "type": "keyword" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
25
server/elastic/mappings/workers/queue_metrics.json
Normal file
25
server/elastic/mappings/workers/queue_metrics.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"index_patterns": "queue_metrics_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"queue": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"average_waiting_time": {
|
||||
"type": "float"
|
||||
},
|
||||
"queue_length": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
37
server/elastic/mappings/workers/worker_stats.json
Normal file
37
server/elastic/mappings/workers/worker_stats.json
Normal file
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"index_patterns": "worker_stats_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"worker": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"category": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"metric": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"variant": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
},
|
||||
"unit": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"task": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -65,6 +65,10 @@ def connect(cluster_name):
|
||||
return _instances[cluster_name]
|
||||
|
||||
|
||||
def get_all_cluster_names():
|
||||
return list(config.get("hosts.elastic"))
|
||||
|
||||
|
||||
def get_cluster_config(cluster_name):
|
||||
"""
|
||||
Returns cluster config for the specified cluster path
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from pathlib import Path
|
||||
from typing import Sequence, Union
|
||||
|
||||
from config import config
|
||||
from config.info import get_default_company
|
||||
from database.model.auth import Role
|
||||
from service_repo.auth.fixed_user import FixedUser
|
||||
from .migration import _apply_migrations
|
||||
from .migration import _apply_migrations, check_mongo_empty, get_last_server_version
|
||||
from .pre_populate import PrePopulate
|
||||
from .user import ensure_fixed_user, _ensure_auth_user, _ensure_backend_user
|
||||
from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
|
||||
@@ -11,33 +13,51 @@ from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
|
||||
log = config.logger(__package__)
|
||||
|
||||
|
||||
def _pre_populate(company_id: str, zip_file: str):
|
||||
if not zip_file or not Path(zip_file).is_file():
|
||||
msg = f"Invalid pre-populate zip file: {zip_file}"
|
||||
if config.get("apiserver.pre_populate.fail_on_error", False):
|
||||
log.error(msg)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
log.warning(msg)
|
||||
else:
|
||||
log.info(f"Pre-populating using {zip_file}")
|
||||
|
||||
PrePopulate.import_from_zip(
|
||||
zip_file,
|
||||
artifacts_path=config.get("apiserver.pre_populate.artifacts_path", None),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_zip_files(zip_files: Union[Sequence[str], str]) -> Sequence[str]:
|
||||
if isinstance(zip_files, str):
|
||||
zip_files = [zip_files]
|
||||
for p in map(Path, zip_files):
|
||||
if p.is_file():
|
||||
yield p
|
||||
if p.is_dir():
|
||||
yield from p.glob("*.zip")
|
||||
log.warning(f"Invalid pre-populate entry {str(p)}, skipping")
|
||||
|
||||
|
||||
def pre_populate_data():
|
||||
for zip_file in _resolve_zip_files(config.get("apiserver.pre_populate.zip_files")):
|
||||
_pre_populate(company_id=get_default_company(), zip_file=zip_file)
|
||||
|
||||
PrePopulate.update_featured_projects_order()
|
||||
|
||||
|
||||
def init_mongo_data():
|
||||
try:
|
||||
empty_dbs = _apply_migrations(log)
|
||||
_apply_migrations(log)
|
||||
|
||||
_ensure_uuid()
|
||||
|
||||
company_id = _ensure_company(log)
|
||||
company_id = _ensure_company(get_default_company(), "trains", log)
|
||||
|
||||
_ensure_default_queue(company_id)
|
||||
|
||||
if empty_dbs and config.get("apiserver.mongo.pre_populate.enabled", False):
|
||||
zip_file = config.get("apiserver.mongo.pre_populate.zip_file")
|
||||
if not zip_file or not Path(zip_file).is_file():
|
||||
msg = f"Failed pre-populating database: invalid zip file {zip_file}"
|
||||
if config.get("apiserver.mongo.pre_populate.fail_on_error", False):
|
||||
log.error(msg)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
log.warning(msg)
|
||||
else:
|
||||
|
||||
user_id = _ensure_backend_user(
|
||||
"__allegroai__", company_id, "Allegro.ai"
|
||||
)
|
||||
|
||||
PrePopulate.import_from_zip(zip_file, user_id=user_id)
|
||||
|
||||
fixed_mode = FixedUser.enabled()
|
||||
|
||||
for user, credentials in config.get("secure.credentials", {}).items():
|
||||
@@ -56,9 +76,13 @@ def init_mongo_data():
|
||||
if fixed_mode:
|
||||
log.info("Fixed users mode is enabled")
|
||||
FixedUser.validate()
|
||||
|
||||
if FixedUser.guest_enabled():
|
||||
_ensure_company(FixedUser.get_guest_user().company, "guests", log)
|
||||
|
||||
for user in FixedUser.from_config():
|
||||
try:
|
||||
ensure_fixed_user(user, company_id, log=log)
|
||||
ensure_fixed_user(user, log=log)
|
||||
except Exception as ex:
|
||||
log.error(f"Failed creating fixed user {user.name}: {ex}")
|
||||
except Exception as ex:
|
||||
|
||||
@@ -13,7 +13,26 @@ from database.model.version import Version as DatabaseVersion
|
||||
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
|
||||
|
||||
|
||||
def _apply_migrations(log: Logger) -> bool:
|
||||
def check_mongo_empty() -> bool:
|
||||
return not all(
|
||||
get_db(alias).collection_names()
|
||||
for alias in database.utils.get_options(Database)
|
||||
)
|
||||
|
||||
|
||||
def get_last_server_version() -> Version:
|
||||
try:
|
||||
previous_versions = sorted(
|
||||
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
|
||||
reverse=True,
|
||||
)
|
||||
except ValueError as ex:
|
||||
raise ValueError(f"Invalid database version number encountered: {ex}")
|
||||
|
||||
return previous_versions[0] if previous_versions else Version("0.0.0")
|
||||
|
||||
|
||||
def _apply_migrations(log: Logger):
|
||||
"""
|
||||
Apply migrations as found in the migration dir.
|
||||
Returns a boolean indicating whether the database was empty prior to migration.
|
||||
@@ -25,20 +44,8 @@ def _apply_migrations(log: Logger) -> bool:
|
||||
if not migration_dir.is_dir():
|
||||
raise ValueError(f"Invalid migration dir {migration_dir}")
|
||||
|
||||
empty_dbs = not any(
|
||||
get_db(alias).collection_names()
|
||||
for alias in database.utils.get_options(Database)
|
||||
)
|
||||
|
||||
try:
|
||||
previous_versions = sorted(
|
||||
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
|
||||
reverse=True,
|
||||
)
|
||||
except ValueError as ex:
|
||||
raise ValueError(f"Invalid database version number encountered: {ex}")
|
||||
|
||||
last_version = previous_versions[0] if previous_versions else Version("0.0.0")
|
||||
empty_dbs = check_mongo_empty()
|
||||
last_version = get_last_server_version()
|
||||
|
||||
try:
|
||||
new_scripts = {
|
||||
@@ -82,5 +89,3 @@ def _apply_migrations(log: Logger) -> bool:
|
||||
).save()
|
||||
|
||||
log.info("Finished mongodb migrations")
|
||||
|
||||
return empty_dbs
|
||||
|
||||
@@ -1,31 +1,390 @@
|
||||
import hashlib
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from os.path import splitext
|
||||
from typing import List, Optional, Any, Type, Set, Dict
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Optional,
|
||||
Any,
|
||||
Type,
|
||||
Set,
|
||||
Dict,
|
||||
Sequence,
|
||||
Tuple,
|
||||
BinaryIO,
|
||||
Union,
|
||||
Mapping,
|
||||
)
|
||||
from urllib.parse import unquote, urlparse
|
||||
from zipfile import ZipFile, ZIP_BZIP2
|
||||
|
||||
import dpath
|
||||
import mongoengine
|
||||
from tqdm import tqdm
|
||||
from boltons.iterutils import chunked_iter
|
||||
from furl import furl
|
||||
from mongoengine import Q
|
||||
|
||||
from bll.event import EventBLL
|
||||
from bll.task.param_utils import (
|
||||
split_param_name,
|
||||
hyperparams_default_section,
|
||||
hyperparams_legacy_type,
|
||||
)
|
||||
from config import config
|
||||
from config.info import get_default_company
|
||||
from database.model import EntityVisibility
|
||||
from database.model.model import Model
|
||||
from database.model.project import Project
|
||||
from database.model.task.task import Task, ArtifactModes, TaskStatus
|
||||
from database.utils import get_options
|
||||
from tools import safe_get
|
||||
from utilities import json
|
||||
from .user import _ensure_backend_user
|
||||
|
||||
|
||||
class PrePopulate:
|
||||
@classmethod
|
||||
def export_to_zip(
|
||||
cls, filename: str, experiments: List[str] = None, projects: List[str] = None
|
||||
):
|
||||
with ZipFile(filename, mode="w", compression=ZIP_BZIP2) as zfile:
|
||||
cls._export(zfile, experiments, projects)
|
||||
event_bll = EventBLL()
|
||||
events_file_suffix = "_events"
|
||||
export_tag_prefix = "Exported:"
|
||||
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
|
||||
metadata_filename = "metadata.json"
|
||||
zip_args = dict(mode="w", compression=ZIP_BZIP2)
|
||||
artifacts_ext = ".artifacts"
|
||||
img_source_regex = re.compile(
|
||||
r"['\"]source['\"]:\s?['\"](https?://(?:localhost:8081|files.*?)/.*?)['\"]",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
|
||||
class JsonLinesWriter:
|
||||
def __init__(self, file: BinaryIO):
|
||||
self.file = file
|
||||
self.empty = True
|
||||
|
||||
def __enter__(self):
|
||||
self._write("[")
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
self._write("\n]")
|
||||
|
||||
def _write(self, data: str):
|
||||
self.file.write(data.encode("utf-8"))
|
||||
|
||||
def write(self, line: str):
|
||||
if not self.empty:
|
||||
self._write(",")
|
||||
self._write("\n" + line)
|
||||
self.empty = False
|
||||
|
||||
@staticmethod
|
||||
def _get_last_update_time(entity) -> datetime:
|
||||
return getattr(entity, "last_update", None) or getattr(entity, "created")
|
||||
|
||||
@classmethod
|
||||
def import_from_zip(cls, filename: str, user_id: str = None):
|
||||
def _check_for_update(
|
||||
cls, map_file: Path, entities: dict, metadata_hash: str
|
||||
) -> Tuple[bool, Sequence[str]]:
|
||||
if not map_file.is_file():
|
||||
return True, []
|
||||
|
||||
files = []
|
||||
try:
|
||||
map_data = json.loads(map_file.read_text())
|
||||
files = map_data.get("files", [])
|
||||
for file in files:
|
||||
if not Path(file).is_file():
|
||||
return True, files
|
||||
|
||||
new_times = {
|
||||
item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc)
|
||||
for item in chain.from_iterable(entities.values())
|
||||
}
|
||||
old_times = map_data.get("entities", {})
|
||||
|
||||
if set(new_times.keys()) != set(old_times.keys()):
|
||||
return True, files
|
||||
|
||||
for id_, new_timestamp in new_times.items():
|
||||
if new_timestamp != old_times[id_]:
|
||||
return True, files
|
||||
|
||||
if metadata_hash != map_data.get("metadata_hash", ""):
|
||||
return True, files
|
||||
|
||||
except Exception as ex:
|
||||
print("Error reading map file. " + str(ex))
|
||||
return True, files
|
||||
|
||||
return False, files
|
||||
|
||||
@classmethod
|
||||
def _write_update_file(
|
||||
cls,
|
||||
map_file: Path,
|
||||
entities: dict,
|
||||
created_files: Sequence[str],
|
||||
metadata_hash: str,
|
||||
):
|
||||
map_file.write_text(
|
||||
json.dumps(
|
||||
dict(
|
||||
files=created_files,
|
||||
entities={
|
||||
entity.id: cls._get_last_update_time(entity)
|
||||
for entity in chain.from_iterable(entities.values())
|
||||
},
|
||||
metadata_hash=metadata_hash,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]:
|
||||
def is_fileserver_link(a: str) -> bool:
|
||||
a = a.lower()
|
||||
if a.startswith("https://files."):
|
||||
return True
|
||||
if a.startswith("http"):
|
||||
parsed = urlparse(a)
|
||||
if parsed.scheme in {"http", "https"} and parsed.netloc.endswith(
|
||||
"8081"
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
fileserver_links = [a for a in artifacts if is_fileserver_link(a)]
|
||||
print(
|
||||
f"Found {len(fileserver_links)} files on the fileserver from {len(artifacts)} total"
|
||||
)
|
||||
|
||||
return fileserver_links
|
||||
|
||||
@classmethod
|
||||
def export_to_zip(
|
||||
cls,
|
||||
filename: str,
|
||||
experiments: Sequence[str] = None,
|
||||
projects: Sequence[str] = None,
|
||||
artifacts_path: str = None,
|
||||
task_statuses: Sequence[str] = None,
|
||||
tag_exported_entities: bool = False,
|
||||
metadata: Mapping[str, Any] = None,
|
||||
) -> Sequence[str]:
|
||||
if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
|
||||
raise ValueError("Invalid task statuses")
|
||||
|
||||
file = Path(filename)
|
||||
entities = cls._resolve_entities(
|
||||
experiments=experiments, projects=projects, task_statuses=task_statuses
|
||||
)
|
||||
|
||||
hash_ = hashlib.md5()
|
||||
if metadata:
|
||||
meta_str = json.dumps(metadata)
|
||||
hash_.update(meta_str.encode())
|
||||
metadata_hash = hash_.hexdigest()
|
||||
else:
|
||||
meta_str, metadata_hash = "", ""
|
||||
|
||||
map_file = file.with_suffix(".map")
|
||||
updated, old_files = cls._check_for_update(
|
||||
map_file, entities=entities, metadata_hash=metadata_hash
|
||||
)
|
||||
if not updated:
|
||||
print(f"There are no updates from the last export")
|
||||
return old_files
|
||||
|
||||
for old in old_files:
|
||||
old_path = Path(old)
|
||||
if old_path.is_file():
|
||||
old_path.unlink()
|
||||
|
||||
with ZipFile(file, **cls.zip_args) as zfile:
|
||||
if metadata:
|
||||
zfile.writestr(cls.metadata_filename, meta_str)
|
||||
artifacts = cls._export(
|
||||
zfile,
|
||||
entities=entities,
|
||||
hash_=hash_,
|
||||
tag_entities=tag_exported_entities,
|
||||
)
|
||||
|
||||
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
|
||||
file.replace(file_with_hash)
|
||||
created_files = [str(file_with_hash)]
|
||||
|
||||
artifacts = cls._filter_artifacts(artifacts)
|
||||
if artifacts and artifacts_path and os.path.isdir(artifacts_path):
|
||||
artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
|
||||
with ZipFile(artifacts_file, **cls.zip_args) as zfile:
|
||||
cls._export_artifacts(zfile, artifacts, artifacts_path)
|
||||
created_files.append(str(artifacts_file))
|
||||
|
||||
cls._write_update_file(
|
||||
map_file,
|
||||
entities=entities,
|
||||
created_files=created_files,
|
||||
metadata_hash=metadata_hash,
|
||||
)
|
||||
|
||||
return created_files
|
||||
|
||||
@classmethod
|
||||
def import_from_zip(
|
||||
cls,
|
||||
filename: str,
|
||||
artifacts_path: str,
|
||||
company_id: Optional[str] = None,
|
||||
user_id: str = "",
|
||||
user_name: str = "",
|
||||
):
|
||||
metadata = None
|
||||
|
||||
with ZipFile(filename) as zfile:
|
||||
cls._import(zfile, user_id)
|
||||
try:
|
||||
with zfile.open(cls.metadata_filename) as f:
|
||||
metadata = json.loads(f.read())
|
||||
|
||||
meta_public = metadata.get("public")
|
||||
if company_id is None and meta_public is not None:
|
||||
company_id = "" if meta_public else get_default_company()
|
||||
|
||||
if not user_id:
|
||||
meta_user_id = metadata.get("user_id", "")
|
||||
meta_user_name = metadata.get("user_name", "")
|
||||
user_id, user_name = meta_user_id, meta_user_name
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not user_id:
|
||||
user_id, user_name = "__allegroai__", "Allegro.ai"
|
||||
|
||||
# Make sure we won't end up with an invalid company ID
|
||||
if company_id is None:
|
||||
company_id = ""
|
||||
|
||||
# Always use a public user for pre-populated data
|
||||
user_id = _ensure_backend_user(
|
||||
user_id=user_id, user_name=user_name, company_id="",
|
||||
)
|
||||
|
||||
cls._import(zfile, company_id, user_id, metadata)
|
||||
|
||||
if artifacts_path and os.path.isdir(artifacts_path):
|
||||
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
|
||||
if artifacts_file.is_file():
|
||||
print(f"Unzipping artifacts into {artifacts_path}")
|
||||
with ZipFile(artifacts_file) as zfile:
|
||||
zfile.extractall(artifacts_path)
|
||||
|
||||
@classmethod
|
||||
def upgrade_zip(cls, filename) -> Sequence:
|
||||
hash_ = hashlib.md5()
|
||||
task_file = cls._get_base_filename(Task) + ".json"
|
||||
temp_file = Path("temp.zip")
|
||||
file = Path(filename)
|
||||
with ZipFile(file) as reader, ZipFile(temp_file, **cls.zip_args) as writer:
|
||||
for file_info in reader.filelist:
|
||||
if file_info.orig_filename == task_file:
|
||||
with reader.open(file_info) as f:
|
||||
content = cls._upgrade_tasks(f)
|
||||
else:
|
||||
content = reader.read(file_info)
|
||||
writer.writestr(file_info, content)
|
||||
hash_.update(content)
|
||||
|
||||
base_file_name, _, old_hash = file.stem.rpartition("_")
|
||||
new_hash = hash_.hexdigest()
|
||||
if old_hash == new_hash:
|
||||
print(f"The file {filename} was not updated")
|
||||
temp_file.unlink()
|
||||
return []
|
||||
|
||||
new_file = file.with_name(f"{base_file_name}_{new_hash}{file.suffix}")
|
||||
temp_file.replace(new_file)
|
||||
upadated = [str(new_file)]
|
||||
|
||||
artifacts_file = file.with_suffix(cls.artifacts_ext)
|
||||
if artifacts_file.is_file():
|
||||
new_artifacts = new_file.with_suffix(cls.artifacts_ext)
|
||||
artifacts_file.replace(new_artifacts)
|
||||
upadated.append(str(new_artifacts))
|
||||
|
||||
print(f"File {str(file)} replaced with {str(new_file)}")
|
||||
file.unlink()
|
||||
|
||||
return upadated
|
||||
|
||||
@staticmethod
|
||||
def _upgrade_task_data(task_data: dict):
|
||||
for old_param_field, new_param_field, default_section in (
|
||||
("execution/parameters", "hyperparams", hyperparams_default_section),
|
||||
("execution/model_desc", "configuration", None),
|
||||
):
|
||||
legacy = safe_get(task_data, old_param_field)
|
||||
if not legacy:
|
||||
continue
|
||||
for full_name, value in legacy.items():
|
||||
section, name = split_param_name(full_name, default_section)
|
||||
new_path = list(filter(None, (new_param_field, section, name)))
|
||||
if not safe_get(task_data, new_path):
|
||||
new_param = dict(
|
||||
name=name, type=hyperparams_legacy_type, value=str(value)
|
||||
)
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
dpath.new(task_data, new_path, new_param)
|
||||
dpath.delete(task_data, old_param_field)
|
||||
|
||||
@classmethod
|
||||
def _upgrade_tasks(cls, f: BinaryIO) -> bytes:
|
||||
"""
|
||||
Build content array that contains fixed tasks from the passed file
|
||||
For each task the old execution.parameters and model.design are
|
||||
converted to the new structure.
|
||||
The fix is done on Task objects (not the dictionary) so that
|
||||
the fields are serialized back in the same order as they were in the original file
|
||||
"""
|
||||
with BytesIO() as temp:
|
||||
with cls.JsonLinesWriter(temp) as w:
|
||||
for line in cls.json_lines(f):
|
||||
task_data = Task.from_json(line).to_proper_dict()
|
||||
cls._upgrade_task_data(task_data)
|
||||
new_task = Task(**task_data)
|
||||
w.write(new_task.to_json())
|
||||
return temp.getvalue()
|
||||
|
||||
@classmethod
|
||||
def update_featured_projects_order(cls):
|
||||
featured_order = config.get("services.projects.featured_order", [])
|
||||
if not featured_order:
|
||||
return
|
||||
|
||||
def get_index(p: Project):
|
||||
for index, entry in enumerate(featured_order):
|
||||
if (
|
||||
entry.get("id", None) == p.id
|
||||
or entry.get("name", None) == p.name
|
||||
or ("name_regex" in entry and re.match(entry["name_regex"], p.name))
|
||||
):
|
||||
return index
|
||||
return 999
|
||||
|
||||
for project in Project.get_many_public(projection=["id", "name"]):
|
||||
featured_index = get_index(project)
|
||||
Project.objects(id=project.id).update(featured=featured_index)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_type(
|
||||
cls: Type[mongoengine.Document], ids: Optional[List[str]]
|
||||
) -> List[Any]:
|
||||
cls: Type[mongoengine.Document], ids: Optional[Sequence[str]]
|
||||
) -> Sequence[Any]:
|
||||
ids = set(ids)
|
||||
items = list(cls.objects(id__in=list(ids)))
|
||||
resolved = {i.id for i in items}
|
||||
@@ -43,20 +402,24 @@ class PrePopulate:
|
||||
|
||||
@classmethod
|
||||
def _resolve_entities(
|
||||
cls, experiments: List[str] = None, projects: List[str] = None
|
||||
cls,
|
||||
experiments: Sequence[str] = None,
|
||||
projects: Sequence[str] = None,
|
||||
task_statuses: Sequence[str] = None,
|
||||
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
|
||||
from database.model.project import Project
|
||||
from database.model.task.task import Task
|
||||
|
||||
entities = defaultdict(set)
|
||||
|
||||
if projects:
|
||||
print("Reading projects...")
|
||||
entities[Project].update(cls._resolve_type(Project, projects))
|
||||
print("--> Reading project experiments...")
|
||||
objs = Task.objects(
|
||||
project__in=list(set(filter(None, (p.id for p in entities[Project]))))
|
||||
query = Q(
|
||||
project__in=list(set(filter(None, (p.id for p in entities[Project])))),
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
)
|
||||
if task_statuses:
|
||||
query &= Q(status__in=list(set(task_statuses)))
|
||||
objs = Task.objects(query)
|
||||
entities[Task].update(o for o in objs if o.id not in (experiments or []))
|
||||
|
||||
if experiments:
|
||||
@@ -69,85 +432,297 @@ class PrePopulate:
|
||||
project_ids = {p.id for p in entities[Project]}
|
||||
entities[Project].update(o for o in objs if o.id not in project_ids)
|
||||
|
||||
model_ids = {
|
||||
model_id
|
||||
for task in entities[Task]
|
||||
for model_id in (task.output.model, task.execution.model)
|
||||
if model_id
|
||||
}
|
||||
if model_ids:
|
||||
print("Reading models...")
|
||||
entities[Model] = set(Model.objects(id__in=list(model_ids)))
|
||||
|
||||
return entities
|
||||
|
||||
@classmethod
|
||||
def _cleanup_task(cls, task):
|
||||
from database.model.task.task import TaskStatus
|
||||
def _filter_out_export_tags(cls, tags: Sequence[str]) -> Sequence[str]:
|
||||
if not tags:
|
||||
return tags
|
||||
return [tag for tag in tags if not tag.startswith(cls.export_tag_prefix)]
|
||||
|
||||
task.completed = None
|
||||
task.started = None
|
||||
if task.execution:
|
||||
task.execution.model = None
|
||||
task.execution.model_desc = None
|
||||
task.execution.model_labels = None
|
||||
if task.output:
|
||||
task.output.model = None
|
||||
@classmethod
|
||||
def _cleanup_model(cls, model: Model):
|
||||
model.company = ""
|
||||
model.user = ""
|
||||
model.tags = cls._filter_out_export_tags(model.tags)
|
||||
|
||||
task.status = TaskStatus.created
|
||||
@classmethod
|
||||
def _cleanup_task(cls, task: Task):
|
||||
task.comment = "Auto generated by Allegro.ai"
|
||||
task.created = datetime.utcnow()
|
||||
task.last_iteration = 0
|
||||
task.last_update = task.created
|
||||
task.status_changed = task.created
|
||||
task.status_message = ""
|
||||
task.status_reason = ""
|
||||
task.user = ""
|
||||
task.company = ""
|
||||
task.tags = cls._filter_out_export_tags(task.tags)
|
||||
if task.output:
|
||||
task.output.destination = None
|
||||
|
||||
@classmethod
|
||||
def _cleanup_project(cls, project: Project):
|
||||
project.user = ""
|
||||
project.company = ""
|
||||
project.tags = cls._filter_out_export_tags(project.tags)
|
||||
|
||||
@classmethod
|
||||
def _cleanup_entity(cls, entity_cls, entity):
|
||||
from database.model.task.task import Task
|
||||
if entity_cls == Task:
|
||||
cls._cleanup_task(entity)
|
||||
elif entity_cls == Model:
|
||||
cls._cleanup_model(entity)
|
||||
elif entity == Project:
|
||||
cls._cleanup_project(entity)
|
||||
|
||||
@classmethod
|
||||
def _add_tag(cls, items: Sequence[Union[Project, Task, Model]], tag: str):
|
||||
try:
|
||||
for item in items:
|
||||
item.update(upsert=False, tags=sorted(item.tags + [tag]))
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _export_task_events(
|
||||
cls, task: Task, base_filename: str, writer: ZipFile, hash_
|
||||
) -> Sequence[str]:
|
||||
artifacts = []
|
||||
filename = f"{base_filename}_{task.id}{cls.events_file_suffix}.json"
|
||||
print(f"Writing task events into {writer.filename}:{filename}")
|
||||
with BytesIO() as f:
|
||||
with cls.JsonLinesWriter(f) as w:
|
||||
scroll_id = None
|
||||
while True:
|
||||
res = cls.event_bll.get_task_events(
|
||||
task.company, task.id, scroll_id=scroll_id
|
||||
)
|
||||
if not res.events:
|
||||
break
|
||||
scroll_id = res.next_scroll_id
|
||||
for event in res.events:
|
||||
event_type = event.get("type")
|
||||
if event_type == "training_debug_image":
|
||||
url = cls._get_fixed_url(event.get("url"))
|
||||
if url:
|
||||
event["url"] = url
|
||||
artifacts.append(url)
|
||||
elif event_type == "plot":
|
||||
plot_str: str = event.get("plot_str", "")
|
||||
for match in cls.img_source_regex.findall(plot_str):
|
||||
url = cls._get_fixed_url(match)
|
||||
if match != url:
|
||||
plot_str = plot_str.replace(match, url)
|
||||
artifacts.append(url)
|
||||
w.write(json.dumps(event))
|
||||
data = f.getvalue()
|
||||
hash_.update(data)
|
||||
writer.writestr(filename, data)
|
||||
|
||||
return artifacts
|
||||
|
||||
@staticmethod
|
||||
def _get_fixed_url(url: Optional[str]) -> Optional[str]:
|
||||
if not (url and url.lower().startswith("s3://")):
|
||||
return url
|
||||
try:
|
||||
fixed = furl(url)
|
||||
fixed.scheme = "https"
|
||||
fixed.host += ".s3.amazonaws.com"
|
||||
return fixed.url
|
||||
except Exception as ex:
|
||||
print(f"Failed processing link {url}. " + str(ex))
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def _export_entity_related_data(
|
||||
cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
|
||||
):
|
||||
if entity_cls == Task:
|
||||
return [
|
||||
*cls._get_task_output_artifacts(entity),
|
||||
*cls._export_task_events(entity, base_filename, writer, hash_),
|
||||
]
|
||||
|
||||
if entity_cls == Model:
|
||||
entity.uri = cls._get_fixed_url(entity.uri)
|
||||
return [entity.uri] if entity.uri else []
|
||||
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _get_task_output_artifacts(cls, task: Task) -> Sequence[str]:
|
||||
if not task.execution.artifacts:
|
||||
return []
|
||||
|
||||
for a in task.execution.artifacts:
|
||||
if a.mode == ArtifactModes.output:
|
||||
a.uri = cls._get_fixed_url(a.uri)
|
||||
|
||||
return [
|
||||
a.uri
|
||||
for a in task.execution.artifacts
|
||||
if a.mode == ArtifactModes.output and a.uri
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _export_artifacts(
|
||||
cls, writer: ZipFile, artifacts: Sequence[str], artifacts_path: str
|
||||
):
|
||||
unique_paths = set(unquote(str(furl(artifact).path)) for artifact in artifacts)
|
||||
print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
|
||||
for path in unique_paths:
|
||||
path = path.lstrip("/")
|
||||
full_path = os.path.join(artifacts_path, path)
|
||||
if os.path.isfile(full_path):
|
||||
writer.write(full_path, path)
|
||||
else:
|
||||
print(f"Artifact {full_path} not found")
|
||||
|
||||
@staticmethod
|
||||
def _get_base_filename(cls_: type):
|
||||
return f"{cls_.__module__}.{cls_.__name__}"
|
||||
|
||||
@classmethod
|
||||
def _export(
|
||||
cls, writer: ZipFile, experiments: List[str] = None, projects: List[str] = None
|
||||
):
|
||||
entities = cls._resolve_entities(experiments, projects)
|
||||
|
||||
for cls_, items in entities.items():
|
||||
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
|
||||
) -> Sequence[str]:
|
||||
"""
|
||||
Export the requested experiments, projects and models and return the list of artifact files
|
||||
Always do the export on sorted items since the order of items influence hash
|
||||
"""
|
||||
artifacts = []
|
||||
now = datetime.utcnow()
|
||||
for cls_ in sorted(entities, key=attrgetter("__name__")):
|
||||
items = sorted(entities[cls_], key=attrgetter("id"))
|
||||
if not items:
|
||||
continue
|
||||
filename = f"{cls_.__module__}.{cls_.__name__}.json"
|
||||
base_filename = cls._get_base_filename(cls_)
|
||||
for item in items:
|
||||
artifacts.extend(
|
||||
cls._export_entity_related_data(
|
||||
cls_, item, base_filename, writer, hash_
|
||||
)
|
||||
)
|
||||
filename = base_filename + ".json"
|
||||
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
|
||||
with writer.open(filename, "w") as f:
|
||||
f.write("[\n".encode("utf-8"))
|
||||
last = len(items) - 1
|
||||
for i, item in enumerate(items):
|
||||
cls._cleanup_entity(cls_, item)
|
||||
f.write(item.to_json().encode("utf-8"))
|
||||
if i != last:
|
||||
f.write(",".encode("utf-8"))
|
||||
f.write("\n".encode("utf-8"))
|
||||
f.write("]\n".encode("utf-8"))
|
||||
with BytesIO() as f:
|
||||
with cls.JsonLinesWriter(f) as w:
|
||||
for item in items:
|
||||
cls._cleanup_entity(cls_, item)
|
||||
w.write(item.to_json())
|
||||
data = f.getvalue()
|
||||
hash_.update(data)
|
||||
writer.writestr(filename, data)
|
||||
|
||||
if tag_entities:
|
||||
cls._add_tag(items, now.strftime(cls.export_tag))
|
||||
|
||||
return artifacts
|
||||
|
||||
@staticmethod
|
||||
def _import(reader: ZipFile, user_id: str = None):
|
||||
for file_info in reader.filelist:
|
||||
full_name = splitext(file_info.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
module_name, _, class_name = full_name.rpartition(".")
|
||||
module = importlib.import_module(module_name)
|
||||
cls_: Type[mongoengine.Document] = getattr(module, class_name)
|
||||
def json_lines(file: BinaryIO):
|
||||
for line in file:
|
||||
clean = (
|
||||
line.decode("utf-8")
|
||||
.rstrip("\r\n")
|
||||
.strip()
|
||||
.lstrip("[")
|
||||
.rstrip(",]")
|
||||
.strip()
|
||||
)
|
||||
if not clean:
|
||||
continue
|
||||
yield clean
|
||||
|
||||
with reader.open(file_info) as f:
|
||||
for item in tqdm(
|
||||
f.readlines(),
|
||||
desc=f"Writing {cls_.__name__.lower()}s into database",
|
||||
unit="doc",
|
||||
):
|
||||
item = (
|
||||
item.decode("utf-8")
|
||||
.strip()
|
||||
.lstrip("[")
|
||||
.rstrip("]")
|
||||
.rstrip(",")
|
||||
.strip()
|
||||
)
|
||||
if not item:
|
||||
continue
|
||||
doc = cls_.from_json(item)
|
||||
if user_id is not None and hasattr(doc, "user"):
|
||||
doc.user = user_id
|
||||
doc.save(force_insert=True)
|
||||
@classmethod
|
||||
def _import(
|
||||
cls,
|
||||
reader: ZipFile,
|
||||
company_id: str = "",
|
||||
user_id: str = None,
|
||||
metadata: Mapping[str, Any] = None,
|
||||
):
|
||||
"""
|
||||
Import entities and events from the zip file
|
||||
Start from entities since event import will require the tasks already in DB
|
||||
"""
|
||||
event_file_ending = cls.events_file_suffix + ".json"
|
||||
entity_files = (
|
||||
fi
|
||||
for fi in reader.filelist
|
||||
if not fi.orig_filename.endswith(event_file_ending)
|
||||
and fi.orig_filename != cls.metadata_filename
|
||||
)
|
||||
event_files = (
|
||||
fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending)
|
||||
)
|
||||
for files, reader_func in (
|
||||
(entity_files, partial(cls._import_entity, metadata=metadata or {})),
|
||||
(event_files, cls._import_events),
|
||||
):
|
||||
for file_info in files:
|
||||
with reader.open(file_info) as f:
|
||||
full_name = splitext(file_info.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
reader_func(f, full_name, company_id, user_id)
|
||||
|
||||
@classmethod
|
||||
def _import_entity(
|
||||
cls,
|
||||
f: BinaryIO,
|
||||
full_name: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
metadata: Mapping[str, Any],
|
||||
):
|
||||
module_name, _, class_name = full_name.rpartition(".")
|
||||
module = importlib.import_module(module_name)
|
||||
cls_: Type[mongoengine.Document] = getattr(module, class_name)
|
||||
print(f"Writing {cls_.__name__.lower()}s into database")
|
||||
|
||||
override_project_count = 0
|
||||
for item in cls.json_lines(f):
|
||||
doc = cls_.from_json(item, created=True)
|
||||
if hasattr(doc, "user"):
|
||||
doc.user = user_id
|
||||
if hasattr(doc, "company"):
|
||||
doc.company = company_id
|
||||
if isinstance(doc, Project):
|
||||
override_project_name = metadata.get("project_name", None)
|
||||
if override_project_name:
|
||||
if override_project_count:
|
||||
override_project_name = (
|
||||
f"{override_project_name} {override_project_count + 1}"
|
||||
)
|
||||
override_project_count += 1
|
||||
doc.name = override_project_name
|
||||
|
||||
doc.logo_url = metadata.get("logo_url", None)
|
||||
doc.logo_blob = metadata.get("logo_blob", None)
|
||||
|
||||
cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
|
||||
set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
|
||||
)
|
||||
|
||||
doc.save()
|
||||
|
||||
if isinstance(doc, Task):
|
||||
cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
|
||||
|
||||
@classmethod
|
||||
def _import_events(cls, f: BinaryIO, full_name: str, company_id: str, _):
|
||||
_, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
|
||||
print(f"Writing events for task {task_id} into database")
|
||||
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
|
||||
events = [json.loads(item) for item in events_chunk]
|
||||
cls.event_bll.add_events(
|
||||
company_id, events=events, worker="", allow_locked_tasks=True
|
||||
)
|
||||
|
||||
@@ -58,15 +58,23 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
|
||||
return user_id
|
||||
|
||||
|
||||
def ensure_fixed_user(user: FixedUser, company_id: str, log: Logger):
|
||||
if User.objects(id=user.user_id).first():
|
||||
def ensure_fixed_user(user: FixedUser, log: Logger):
|
||||
db_user = User.objects(company=user.company, id=user.user_id).first()
|
||||
if db_user:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
log.info(f"Updating user name: {user.name}")
|
||||
given_name, _, family_name = user.name.partition(" ")
|
||||
db_user.update(name=user.name, given_name=given_name, family_name=family_name)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
data = attr.asdict(user)
|
||||
data["id"] = user.user_id
|
||||
data["email"] = f"{user.user_id}@example.com"
|
||||
data["role"] = Role.user
|
||||
data["role"] = Role.guest if user.is_guest else Role.user
|
||||
|
||||
_ensure_auth_user(user_data=data, company_id=company_id, log=log)
|
||||
_ensure_auth_user(user_data=data, company_id=user.company, log=log)
|
||||
|
||||
return _ensure_backend_user(user.user_id, company_id, user.name)
|
||||
return _ensure_backend_user(user.user_id, user.company, user.name)
|
||||
|
||||
@@ -3,7 +3,6 @@ from uuid import uuid4
|
||||
|
||||
from bll.queue import QueueBLL
|
||||
from config import config
|
||||
from config.info import get_default_company
|
||||
from database.model.company import Company
|
||||
from database.model.queue import Queue
|
||||
from database.model.settings import Settings, SettingKeys
|
||||
@@ -11,13 +10,11 @@ from database.model.settings import Settings, SettingKeys
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
def _ensure_company(log: Logger):
|
||||
company_id = get_default_company()
|
||||
def _ensure_company(company_id, company_name, log: Logger):
|
||||
company = Company.objects(id=company_id).only("id").first()
|
||||
if company:
|
||||
return company_id
|
||||
|
||||
company_name = "trains"
|
||||
log.info(f"Creating company: {company_name}")
|
||||
company = Company(id=company_id, name=company_name)
|
||||
company.save()
|
||||
|
||||
36
server/mongo/migrations/0.16.0.py
Normal file
36
server/mongo/migrations/0.16.0.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
from bll.task.param_utils import (
|
||||
hyperparams_legacy_type,
|
||||
hyperparams_default_section,
|
||||
split_param_name,
|
||||
)
|
||||
from tools import safe_get
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
hyperparam_fields = ("execution.parameters", "hyperparams")
|
||||
configuration_fields = ("execution.model_desc", "configuration")
|
||||
collection: Collection = db["task"]
|
||||
for doc in collection.find(projection=hyperparam_fields + configuration_fields):
|
||||
set_commands = {}
|
||||
for (old_field, new_field), default_section in zip(
|
||||
(hyperparam_fields, configuration_fields),
|
||||
(hyperparams_default_section, None),
|
||||
):
|
||||
legacy = safe_get(doc, old_field, separator=".")
|
||||
if not legacy:
|
||||
continue
|
||||
for full_name, value in legacy.items():
|
||||
section, name = split_param_name(full_name, default_section)
|
||||
new_path = list(filter(None, (new_field, section, name)))
|
||||
# if safe_get(doc, new_path) is not None:
|
||||
# continue
|
||||
new_value = dict(
|
||||
name=name, type=hyperparams_legacy_type, value=str(value)
|
||||
)
|
||||
if section is not None:
|
||||
new_value["section"] = section
|
||||
set_commands[".".join(new_path)] = new_value
|
||||
if set_commands:
|
||||
collection.update_one({"_id": doc["_id"]}, {"$set": set_commands})
|
||||
11
server/mongo/migrations/0.16.1.py
Normal file
11
server/mongo/migrations/0.16.1.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
collection: Collection = db["project"]
|
||||
featured = "featured"
|
||||
query = {featured: {"$exists": False}}
|
||||
for doc in collection.find(filter=query, projection=()):
|
||||
collection.update_one(
|
||||
{"_id": doc["_id"]}, {"$set": {featured: 9999}},
|
||||
)
|
||||
@@ -1,7 +1,8 @@
|
||||
attrs>=19.1.0
|
||||
boltons>=19.1.0
|
||||
boto3==1.14.13
|
||||
dpath>=1.4.2,<2.0
|
||||
elasticsearch>=5.0.0,<6.0.0
|
||||
elasticsearch>=7.0.0,<8.0.0
|
||||
fastjsonschema>=2.8
|
||||
Flask-Compress>=1.4.0
|
||||
Flask-Cors>=3.0.5
|
||||
@@ -24,7 +25,7 @@ python-rapidjson>=0.6.3
|
||||
redis>=2.10.5
|
||||
related>=0.7.2
|
||||
requests>=2.13.0
|
||||
semantic_version>=2.8.0,<3
|
||||
semantic_version>=2.8.3,<3
|
||||
six
|
||||
tqdm
|
||||
validators>=0.12.4
|
||||
@@ -328,6 +328,11 @@ fixed_users_mode {
|
||||
description: "Fixed users mode enabled"
|
||||
type: boolean
|
||||
}
|
||||
server_errors {
|
||||
description: "Server initialization errors"
|
||||
type: object
|
||||
additionalProperties: True
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
16
server/schema/services/debug.conf
Normal file
16
server/schema/services/debug.conf
Normal file
@@ -0,0 +1,16 @@
|
||||
_description: "debugging utilities"
|
||||
ping {
|
||||
authorize: false
|
||||
"2.9" {
|
||||
description: "Ping server"
|
||||
request {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties: {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -530,59 +530,56 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
// "2.7" {
|
||||
// description: "Get 'log' events for this task"
|
||||
// request {
|
||||
// type: object
|
||||
// required: [
|
||||
// task
|
||||
// ]
|
||||
// properties {
|
||||
// task {
|
||||
// type: string
|
||||
// description: "Task ID"
|
||||
// }
|
||||
// batch_size {
|
||||
// type: integer
|
||||
// description: "The amount of log events to return"
|
||||
// }
|
||||
// navigate_earlier {
|
||||
// type: boolean
|
||||
// description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
|
||||
// }
|
||||
// refresh {
|
||||
// type: boolean
|
||||
// description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
|
||||
// }
|
||||
// scroll_id {
|
||||
// type: string
|
||||
// description: "Scroll ID of previous call (used for getting more results)"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// response {
|
||||
// type: object
|
||||
// properties {
|
||||
// events {
|
||||
// type: array
|
||||
// items { type: object }
|
||||
// description: "Log items list"
|
||||
// }
|
||||
// returned {
|
||||
// type: integer
|
||||
// description: "Number of log events returned"
|
||||
// }
|
||||
// total {
|
||||
// type: number
|
||||
// description: "Total number of log events available for this query"
|
||||
// }
|
||||
// scroll_id {
|
||||
// type: string
|
||||
// description: "Scroll ID for getting more results"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
"2.9" {
|
||||
description: "Get 'log' events for this task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
batch_size {
|
||||
type: integer
|
||||
description: "The amount of log events to return"
|
||||
}
|
||||
navigate_earlier {
|
||||
type: boolean
|
||||
description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order, unless order='asc'). Otherwise from the earliest to the latest ones (in timestamp ascending order, unless order='desc'). The default is True"
|
||||
}
|
||||
from_timestamp {
|
||||
type: number
|
||||
description: "Epoch time in UTC ms to use as the navigation start. Optional. If not provided, reference timestamp is determined by the 'navigate_earlier' parameter (if true, reference timestamp is the last timestamp and if false, reference timestamp is the first timestamp)"
|
||||
}
|
||||
order {
|
||||
type: string
|
||||
description: "If set, changes the order in which log events are returned based on the value of 'navigate_earlier'"
|
||||
enum: [asc, desc]
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
events {
|
||||
type: array
|
||||
items { type: object }
|
||||
description: "Log items list"
|
||||
}
|
||||
returned {
|
||||
type: integer
|
||||
description: "Number of log events returned"
|
||||
}
|
||||
total {
|
||||
type: number
|
||||
description: "Total number of log events available for this query"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_events {
|
||||
"2.1" {
|
||||
@@ -856,7 +853,7 @@
|
||||
description: "Task ID"
|
||||
}
|
||||
samples {
|
||||
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
|
||||
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 6000."
|
||||
type: integer
|
||||
}
|
||||
key {
|
||||
@@ -894,7 +891,7 @@
|
||||
]
|
||||
properties {
|
||||
tasks {
|
||||
description: "List of task Task IDs"
|
||||
description: "List of task Task IDs. Maximum amount of tasks is 10"
|
||||
type: array
|
||||
items {
|
||||
type: string
|
||||
@@ -902,7 +899,7 @@
|
||||
}
|
||||
}
|
||||
samples {
|
||||
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
|
||||
description: "The amount of histogram points to return. Optional, the default value is 6000"
|
||||
type: integer
|
||||
}
|
||||
key {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -405,6 +405,11 @@ get_all_ex {
|
||||
enum: [ active, archived ]
|
||||
default: active
|
||||
}
|
||||
non_public {
|
||||
description: "Return only non-public projects"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -527,8 +532,8 @@ get_unique_metric_variants {
|
||||
}
|
||||
}
|
||||
get_hyper_parameters {
|
||||
"2.2" {
|
||||
description: """Get a list of all hyper parameter names used in tasks within the given project."""
|
||||
"2.9" {
|
||||
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
@@ -552,9 +557,9 @@ get_hyper_parameters {
|
||||
type: object
|
||||
properties {
|
||||
parameters {
|
||||
description: "A list of hyper parameter names"
|
||||
description: "A list of parameter sections and names"
|
||||
type: array
|
||||
items {type: string}
|
||||
items {type: object}
|
||||
}
|
||||
remaining {
|
||||
description: "Remaining results"
|
||||
@@ -568,6 +573,7 @@ get_hyper_parameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_task_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the tasks under the specified projects"
|
||||
@@ -575,10 +581,61 @@ get_task_tags {
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
|
||||
get_model_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the models under the specified projects"
|
||||
request = ${_definitions.tags_request}
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
|
||||
make_public {
|
||||
"2.9" {
|
||||
description: """Convert company projects to public"""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
description: "Ids of the projects to convert"
|
||||
type: array
|
||||
items { type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of projects updated"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
make_private {
|
||||
"2.9" {
|
||||
description: """Convert public projects to private"""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
description: "Ids of the projects to convert. Only the projects originated by the company can be converted"
|
||||
type: array
|
||||
items { type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of projects updated"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -297,7 +297,80 @@ _definitions {
|
||||
"$ref": "#/definitions/last_metrics_event"
|
||||
}
|
||||
}
|
||||
|
||||
params_item {
|
||||
type: object
|
||||
properties {
|
||||
section {
|
||||
description: "Section that the parameter belongs to"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Name of the parameter. The combination of section and name should be unique"
|
||||
type: string
|
||||
}
|
||||
value {
|
||||
description: "Value of the parameter"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "Type of the parameter. Optional"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "The parameter description. Optional"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
configuration_item {
|
||||
type: object
|
||||
properties {
|
||||
name {
|
||||
description: "Name of the parameter. Should be unique"
|
||||
type: string
|
||||
}
|
||||
value {
|
||||
description: "Value of the parameter"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "Type of the parameter. Optional"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "The parameter description. Optional"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
param_key {
|
||||
type: object
|
||||
properties {
|
||||
section {
|
||||
description: "Section that the parameter belongs to"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Name of the parameter. If the name is ommitted then the corresponding operation is performed on the whole section"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
section_params {
|
||||
description: "Task section params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/params_item"
|
||||
}
|
||||
}
|
||||
replace_hyperparams_enum {
|
||||
type: string
|
||||
enum: [
|
||||
none,
|
||||
section,
|
||||
all
|
||||
]
|
||||
}
|
||||
task {
|
||||
type: object
|
||||
properties {
|
||||
@@ -418,9 +491,24 @@ _definitions {
|
||||
"$ref": "#/definitions/last_metrics_variants"
|
||||
}
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_by_id {
|
||||
"2.1" {
|
||||
description: "Gets task information"
|
||||
@@ -625,6 +713,20 @@ clone {
|
||||
description: "The project of the cloned task. If not provided then taken from the original task"
|
||||
type: string
|
||||
}
|
||||
new_task_hyperparams {
|
||||
description: "The hyper params for the new task. If not provided then taken from the original task"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
new_task_configuration {
|
||||
description: "The configuration for the new task. If not provided then taken from the original task"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
execution_overrides {
|
||||
description: "The execution params for the cloned task. The params not specified are taken from the original task"
|
||||
"$ref": "#/definitions/execution"
|
||||
@@ -698,6 +800,20 @@ create {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -759,6 +875,20 @@ validate {
|
||||
description: "Task execution params"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
script {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
@@ -909,6 +1039,20 @@ edit {
|
||||
description: "Task execution params"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
script {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
@@ -1441,4 +1585,263 @@ add_or_update_artifacts {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
make_public {
|
||||
"2.9" {
|
||||
description: """Convert company tasks to public"""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
description: "Ids of the tasks to convert"
|
||||
type: array
|
||||
items { type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
make_private {
|
||||
"2.9" {
|
||||
description: """Convert public tasks to private"""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
description: "Ids of the tasks to convert. Only the tasks originated by the company can be converted"
|
||||
type: array
|
||||
items { type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_hyper_params {
|
||||
"2.9": {
|
||||
description: "Get the list of task hyper parameters"
|
||||
request {
|
||||
type: object
|
||||
required: [tasks]
|
||||
properties {
|
||||
tasks {
|
||||
description: "Task IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
params {
|
||||
type: object
|
||||
description: "Hyper parameters (keyed by task ID)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
edit_hyper_params {
|
||||
"2.9" {
|
||||
description: "Add or update task hyper parameters"
|
||||
request {
|
||||
type: object
|
||||
required: [ task, hyperparams ]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper parameters. The new ones will be added and the already existing ones will be updated"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/params_item"}
|
||||
}
|
||||
replace_hyperparams {
|
||||
description: """Can be set to one of the following:
|
||||
'all' - all the hyper parameters will be replaced with the provided ones
|
||||
'section' - the sections that present in the new parameters will be replaced with the provided parameters
|
||||
'none' (the default value) - only the specific parameters will be updated or added"""
|
||||
"$ref": "#/definitions/replace_hyperparams_enum"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Indicates if the task was updated successfully"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_hyper_params {
|
||||
"2.9": {
|
||||
description: "Delete task hyper parameters"
|
||||
request {
|
||||
type: object
|
||||
required: [ task, hyperparams ]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
hyperparams {
|
||||
description: "List of hyper parameters to delete. In case a parameter with an empty name is passed all the section will be deleted"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/param_key" }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
deleted {
|
||||
description: "Indicates if the task was updated successfully"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_configurations {
|
||||
"2.9": {
|
||||
description: "Get the list of task configurations"
|
||||
request {
|
||||
type: object
|
||||
required: [tasks]
|
||||
properties {
|
||||
tasks {
|
||||
description: "Task IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
names {
|
||||
description: "Names of the configuration items to retreive. If not passed or empty then all the configurations will be retreived."
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
configurations {
|
||||
type: object
|
||||
description: "Configurations (keyed by task ID)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_configuration_names {
|
||||
"2.9": {
|
||||
description: "Get the list of task configuration items names"
|
||||
request {
|
||||
type: object
|
||||
required: [tasks]
|
||||
properties {
|
||||
tasks {
|
||||
description: "Task IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
configurations {
|
||||
type: object
|
||||
description: "Names of task configuration items (keyed by task ID)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
edit_configuration {
|
||||
"2.9" {
|
||||
description: "Add or update task configuration"
|
||||
request {
|
||||
type: object
|
||||
required: [ task, configuration ]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration items. The new ones will be added and the already existing ones will be updated"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/configuration_item"}
|
||||
}
|
||||
replace_configuration {
|
||||
description: "If set then the all the configuration items will be replaced with the provided ones. Otherwise only the provided configuration items will be updated or added"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Indicates if the task was updated successfully"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_configuration {
|
||||
"2.9": {
|
||||
description: "Delete task configuration items"
|
||||
request {
|
||||
type: object
|
||||
required: [ task, configuration ]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
configuration {
|
||||
description: "List of configuration itemss to delete"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
deleted {
|
||||
description: "Indicates if the task was updated successfully"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -135,6 +135,10 @@
|
||||
description: "Task currently being run by the worker"
|
||||
"$ref": "#/definitions/current_task_entry"
|
||||
}
|
||||
project {
|
||||
description: "Project in which currently executing task resides"
|
||||
"$ref": "#/definitions/id_name_entry"
|
||||
}
|
||||
queue {
|
||||
description: "Queue from which running task was taken"
|
||||
"$ref": "#/definitions/queue_entry"
|
||||
@@ -144,6 +148,11 @@
|
||||
type: array
|
||||
items { "$ref": "#/definitions/queue_entry" }
|
||||
}
|
||||
tags {
|
||||
description: "User tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,11 +160,11 @@
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Worker ID"
|
||||
description: "ID"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Worker name"
|
||||
description: "Name"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
@@ -301,6 +310,11 @@
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
tags {
|
||||
description: "User tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -363,6 +377,11 @@
|
||||
description: "The machine statistics."
|
||||
"$ref": "#/definitions/machine_stats"
|
||||
}
|
||||
tags {
|
||||
description: "New user tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
import atexit
|
||||
from argparse import ArgumentParser
|
||||
from hashlib import md5
|
||||
|
||||
from flask import Flask, request, Response
|
||||
from flask_compress import Compress
|
||||
from flask_cors import CORS
|
||||
from semantic_version import Version
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
import database
|
||||
from apierrors.base import BaseError
|
||||
from bll.statistics.stats_reporter import StatisticsReporter
|
||||
from config import config
|
||||
from elastic.initialize import init_es_data
|
||||
from mongo.initialize import init_mongo_data
|
||||
from config import config, info
|
||||
from elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
|
||||
from mongo.initialize import (
|
||||
init_mongo_data,
|
||||
pre_populate_data,
|
||||
check_mongo_empty,
|
||||
get_last_server_version,
|
||||
)
|
||||
from service_repo import ServiceRepo, APICall
|
||||
from service_repo.auth import AuthType
|
||||
from service_repo.errors import PathParsingError
|
||||
from sync import distributed_lock
|
||||
from timing_context import TimingContext
|
||||
from updates import check_updates_thread
|
||||
from utilities import json
|
||||
@@ -33,8 +41,47 @@ app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get("apiserver.pretty_json")
|
||||
|
||||
database.initialize()
|
||||
|
||||
init_es_data()
|
||||
init_mongo_data()
|
||||
# build a key that uniquely identifies specific mongo instance
|
||||
hosts_string = ";".join(sorted(database.get_hosts()))
|
||||
key = "db_init_" + md5(hosts_string.encode()).hexdigest()
|
||||
with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 120)):
|
||||
upgrade_monitoring = config.get(
|
||||
"apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
|
||||
)
|
||||
try:
|
||||
empty_es = check_elastic_empty()
|
||||
except ElasticConnectionError as err:
|
||||
if not upgrade_monitoring:
|
||||
raise
|
||||
log.error(err)
|
||||
info.es_connection_error = True
|
||||
|
||||
empty_db = check_mongo_empty()
|
||||
if (
|
||||
upgrade_monitoring
|
||||
and not empty_db
|
||||
and (info.es_connection_error or empty_es)
|
||||
and get_last_server_version() < Version("0.16.0")
|
||||
):
|
||||
log.info(f"ES database seems not migrated")
|
||||
info.missed_es_upgrade = True
|
||||
|
||||
if info.es_connection_error and not info.missed_es_upgrade:
|
||||
raise Exception(
|
||||
"Error starting server: failed connecting to ElasticSearch service"
|
||||
)
|
||||
|
||||
if not info.missed_es_upgrade:
|
||||
init_es_data()
|
||||
init_mongo_data()
|
||||
|
||||
if (
|
||||
not info.missed_es_upgrade
|
||||
and empty_db
|
||||
and config.get("apiserver.pre_populate.enabled", False)
|
||||
):
|
||||
pre_populate_data()
|
||||
|
||||
|
||||
ServiceRepo.load("services")
|
||||
log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
|
||||
|
||||
@@ -8,7 +8,7 @@ from .endpoint import EndpointFunc, Endpoint
|
||||
from .service_repo import ServiceRepo
|
||||
|
||||
|
||||
__all__ = ["endpoint"]
|
||||
__all__ = ["APICall", "endpoint"]
|
||||
|
||||
|
||||
LegacyEndpointFunc = Callable[[APICall], None]
|
||||
|
||||
@@ -69,6 +69,10 @@ def authorize_credentials(auth_data, service, action, call_data_items):
|
||||
if fixed_user:
|
||||
if secret_key != fixed_user.password:
|
||||
raise errors.unauthorized.InvalidCredentials('bad username or password')
|
||||
|
||||
if fixed_user.is_guest and not FixedUser.is_guest_endpoint(service, action):
|
||||
raise errors.unauthorized.InvalidCredentials('endpoint not allowed for guest')
|
||||
|
||||
query = Q(id=fixed_user.user_id)
|
||||
|
||||
with TimingContext("mongo", "user_by_cred"), translate_errors_context('authorizing request'):
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import hashlib
|
||||
from functools import lru_cache
|
||||
from typing import Sequence, TypeVar
|
||||
from typing import Sequence, Optional
|
||||
|
||||
import attr
|
||||
|
||||
from config import config
|
||||
from config.info import get_default_company
|
||||
|
||||
T = TypeVar("T", bound="FixedUser")
|
||||
|
||||
|
||||
class FixedUsersError(Exception):
|
||||
pass
|
||||
@@ -21,6 +19,8 @@ class FixedUser:
|
||||
name: str
|
||||
company: str = get_default_company()
|
||||
|
||||
is_guest: bool = False
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.user_id = hashlib.md5(f"{self.company}:{self.username}".encode()).hexdigest()
|
||||
|
||||
@@ -28,6 +28,10 @@ class FixedUser:
|
||||
def enabled(cls):
|
||||
return config.get("apiserver.auth.fixed_users.enabled", False)
|
||||
|
||||
@classmethod
|
||||
def guest_enabled(cls):
|
||||
return cls.enabled() and config.get("services.auth.fixed_users.guest.enabled", False)
|
||||
|
||||
@classmethod
|
||||
def validate(cls):
|
||||
if not cls.enabled():
|
||||
@@ -39,18 +43,50 @@ class FixedUser:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@lru_cache()
|
||||
def from_config(cls) -> Sequence[T]:
|
||||
return [
|
||||
# @lru_cache()
|
||||
def from_config(cls) -> Sequence["FixedUser"]:
|
||||
users = [
|
||||
cls(**user) for user in config.get("apiserver.auth.fixed_users.users", [])
|
||||
]
|
||||
|
||||
if cls.guest_enabled():
|
||||
users.insert(
|
||||
0,
|
||||
cls.get_guest_user()
|
||||
)
|
||||
|
||||
return users
|
||||
|
||||
@classmethod
|
||||
@lru_cache()
|
||||
def get_by_username(cls, username) -> T:
|
||||
def get_by_username(cls, username) -> "FixedUser":
|
||||
return next(
|
||||
(user for user in cls.from_config() if user.username == username), None
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@lru_cache()
|
||||
def is_guest_endpoint(cls, service, action):
|
||||
"""
|
||||
Validate a potential guest user,
|
||||
This method will verify the user is indeed the guest user,
|
||||
and that the guest user may access the service/action using its username/password
|
||||
"""
|
||||
return any(
|
||||
ep == ".".join((service, action))
|
||||
for ep in config.get("services.auth.fixed_users.guest.allow_endpoints", [])
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_guest_user(cls) -> Optional["FixedUser"]:
|
||||
if cls.guest_enabled():
|
||||
return cls(
|
||||
is_guest=True,
|
||||
username=config.get("services.auth.fixed_users.guest.username"),
|
||||
password=config.get("services.auth.fixed_users.guest.password"),
|
||||
name=config.get("services.auth.fixed_users.guest.name"),
|
||||
company=config.get("services.auth.fixed_users.guest.default_company"),
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.user_id)
|
||||
|
||||
@@ -16,7 +16,7 @@ from apimodels.auth import (
|
||||
)
|
||||
from apimodels.base import UpdateResponse
|
||||
from bll.auth import AuthBLL
|
||||
from config import config
|
||||
from config import config, info
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.auth import User
|
||||
from service_repo import APICall, endpoint
|
||||
@@ -176,4 +176,24 @@ def update(call, company_id, _):
|
||||
|
||||
@endpoint("auth.fixed_users_mode")
|
||||
def fixed_users_mode(call: APICall, *_, **__):
|
||||
call.result.data = dict(enabled=FixedUser.enabled())
|
||||
server_errors = {
|
||||
name: error
|
||||
for name, error in zip(
|
||||
("missed_es_upgrade", "es_connection_error"),
|
||||
(info.missed_es_upgrade, info.es_connection_error),
|
||||
)
|
||||
if error
|
||||
}
|
||||
|
||||
data = {
|
||||
"enabled": FixedUser.enabled(),
|
||||
"guest": {"enabled": FixedUser.guest_enabled()},
|
||||
"server_errors": server_errors,
|
||||
}
|
||||
guest_user = FixedUser.get_guest_user()
|
||||
if guest_user:
|
||||
data["guest"]["name"] = guest_user.name
|
||||
data["guest"]["username"] = guest_user.username
|
||||
data["guest"]["password"] = guest_user.password
|
||||
|
||||
call.result.data = data
|
||||
|
||||
6
server/services/debug.py
Normal file
6
server/services/debug.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from service_repo import APICall, endpoint
|
||||
|
||||
|
||||
@endpoint("debug.ping")
|
||||
def ping(call: APICall, _, __):
|
||||
call.result.data = {"msg": "Because it trains cats and dogs"}
|
||||
@@ -12,6 +12,7 @@ from apimodels.events import (
|
||||
IterationEvents,
|
||||
TaskMetricsRequest,
|
||||
LogEventsRequest,
|
||||
LogOrderEnum,
|
||||
)
|
||||
from bll.event import EventBLL
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
@@ -24,7 +25,7 @@ event_bll = EventBLL()
|
||||
|
||||
|
||||
@endpoint("events.add")
|
||||
def add(call: APICall, company_id, req_model):
|
||||
def add(call: APICall, company_id, _):
|
||||
data = call.data.copy()
|
||||
allow_locked = data.pop("allow_locked", False)
|
||||
added, err_count, err_info = event_bll.add_events(
|
||||
@@ -35,7 +36,7 @@ def add(call: APICall, company_id, req_model):
|
||||
|
||||
|
||||
@endpoint("events.add_batch")
|
||||
def add_batch(call: APICall, company_id, req_model):
|
||||
def add_batch(call: APICall, company_id, _):
|
||||
events = call.batched_data
|
||||
if events is None or len(events) == 0:
|
||||
raise errors.bad_request.BatchContainsNoItems()
|
||||
@@ -46,14 +47,16 @@ def add_batch(call: APICall, company_id, req_model):
|
||||
|
||||
|
||||
@endpoint("events.get_task_log", required_fields=["task"])
|
||||
def get_task_log_v1_5(call, company_id, req_model):
|
||||
def get_task_log_v1_5(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
order = call.data.get("order") or "desc"
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
batch_size = int(call.data.get("batch_size") or 500)
|
||||
events, scroll_id, total_events = event_bll.scroll_task_events(
|
||||
company_id,
|
||||
task.company,
|
||||
task_id,
|
||||
order,
|
||||
event_type="log",
|
||||
@@ -66,9 +69,11 @@ def get_task_log_v1_5(call, company_id, req_model):
|
||||
|
||||
|
||||
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
|
||||
def get_task_log_v1_7(call, company_id, req_model):
|
||||
def get_task_log_v1_7(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
|
||||
order = call.data.get("order") or "desc"
|
||||
from_ = call.data.get("from") or "head"
|
||||
@@ -78,7 +83,7 @@ def get_task_log_v1_7(call, company_id, req_model):
|
||||
scroll_order = "asc" if (from_ == "head") else "desc"
|
||||
|
||||
events, scroll_id, total_events = event_bll.scroll_task_events(
|
||||
company_id=company_id,
|
||||
company_id=task.company,
|
||||
task_id=task_id,
|
||||
order=scroll_order,
|
||||
event_type="log",
|
||||
@@ -94,33 +99,40 @@ def get_task_log_v1_7(call, company_id, req_model):
|
||||
)
|
||||
|
||||
|
||||
# uncomment this once the front end is ready
|
||||
# @endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
|
||||
# def get_task_log(call, company_id, req_model: LogEventsRequest):
|
||||
# task_id = req_model.task
|
||||
# task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
#
|
||||
# res = event_bll.log_events_iterator.get_task_events(
|
||||
# company_id=company_id,
|
||||
# task_id=task_id,
|
||||
# batch_size=req_model.batch_size,
|
||||
# navigate_earlier=req_model.navigate_earlier,
|
||||
# refresh=req_model.refresh,
|
||||
# state_id=req_model.scroll_id,
|
||||
# )
|
||||
#
|
||||
# call.result.data = dict(
|
||||
# events=res.events,
|
||||
# returned=len(res.events),
|
||||
# total=res.total_events,
|
||||
# scroll_id=res.next_scroll_id,
|
||||
# )
|
||||
@endpoint("events.get_task_log", min_version="2.9", request_data_model=LogEventsRequest)
|
||||
def get_task_log(call, company_id, request: LogEventsRequest):
|
||||
task_id = request.task
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
|
||||
res = event_bll.log_events_iterator.get_task_events(
|
||||
company_id=task.company,
|
||||
task_id=task_id,
|
||||
batch_size=request.batch_size,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
from_timestamp=request.from_timestamp,
|
||||
)
|
||||
|
||||
if (
|
||||
request.order and (
|
||||
(request.navigate_earlier and request.order == LogOrderEnum.asc)
|
||||
or (not request.navigate_earlier and request.order == LogOrderEnum.desc)
|
||||
)
|
||||
):
|
||||
res.events.reverse()
|
||||
|
||||
call.result.data = dict(
|
||||
events=res.events, returned=len(res.events), total=res.total_events
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.download_task_log", required_fields=["task"])
|
||||
def download_task_log(call, company_id, req_model):
|
||||
def download_task_log(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
|
||||
line_type = call.data.get("line_type", "json").lower()
|
||||
line_format = str(call.data.get("line_format", "{asctime} {worker} {level} {msg}"))
|
||||
@@ -163,7 +175,7 @@ def download_task_log(call, company_id, req_model):
|
||||
batch_size = 1000
|
||||
while True:
|
||||
log_events, scroll_id, _ = event_bll.scroll_task_events(
|
||||
company_id,
|
||||
task.company,
|
||||
task_id,
|
||||
order="asc",
|
||||
event_type="log",
|
||||
@@ -173,7 +185,7 @@ def download_task_log(call, company_id, req_model):
|
||||
if not log_events:
|
||||
break
|
||||
for ev in log_events:
|
||||
ev["asctime"] = ev.pop("@timestamp")
|
||||
ev["asctime"] = ev.pop("timestamp")
|
||||
if is_json:
|
||||
ev.pop("type")
|
||||
ev.pop("task")
|
||||
@@ -196,23 +208,27 @@ def download_task_log(call, company_id, req_model):
|
||||
|
||||
|
||||
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
|
||||
def get_vector_metrics_and_variants(call, company_id, req_model):
|
||||
def get_vector_metrics_and_variants(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
call.result.data = dict(
|
||||
metrics=event_bll.get_metrics_and_variants(
|
||||
company_id, task_id, "training_stats_vector"
|
||||
task.company, task_id, "training_stats_vector"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
|
||||
def get_scalar_metrics_and_variants(call, company_id, req_model):
|
||||
def get_scalar_metrics_and_variants(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
call.result.data = dict(
|
||||
metrics=event_bll.get_metrics_and_variants(
|
||||
company_id, task_id, "training_stats_scalar"
|
||||
task.company, task_id, "training_stats_scalar"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -222,13 +238,15 @@ def get_scalar_metrics_and_variants(call, company_id, req_model):
|
||||
"events.vector_metrics_iter_histogram",
|
||||
required_fields=["task", "metric", "variant"],
|
||||
)
|
||||
def vector_metrics_iter_histogram(call, company_id, req_model):
|
||||
def vector_metrics_iter_histogram(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
metric = call.data["metric"]
|
||||
variant = call.data["variant"]
|
||||
iterations, vectors = event_bll.get_vector_metrics_per_iter(
|
||||
company_id, task_id, metric, variant
|
||||
task.company, task_id, metric, variant
|
||||
)
|
||||
call.result.data = dict(
|
||||
metric=metric, variant=variant, vectors=vectors, iterations=iterations
|
||||
@@ -243,9 +261,11 @@ def get_task_events(call, company_id, _):
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
order = call.data.get("order") or "asc"
|
||||
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
result = event_bll.get_task_events(
|
||||
company_id,
|
||||
task.company,
|
||||
task_id,
|
||||
sort=[{"timestamp": {"order": order}}],
|
||||
event_type=event_type,
|
||||
@@ -262,14 +282,16 @@ def get_task_events(call, company_id, _):
|
||||
|
||||
|
||||
@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
|
||||
def get_scalar_metric_data(call, company_id, req_model):
|
||||
def get_scalar_metric_data(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
metric = call.data["metric"]
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
result = event_bll.get_task_events(
|
||||
company_id,
|
||||
task.company,
|
||||
task_id,
|
||||
event_type="training_stats_scalar",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
@@ -286,13 +308,15 @@ def get_scalar_metric_data(call, company_id, req_model):
|
||||
|
||||
|
||||
@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
|
||||
def get_task_latest_scalar_values(call, company_id, req_model):
|
||||
def get_task_latest_scalar_values(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task = task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
|
||||
company_id, task_id
|
||||
task.company, task_id
|
||||
)
|
||||
es_index = EventMetrics.get_index_name(company_id, "*")
|
||||
es_index = EventMetrics.get_index_name(task.company, "*")
|
||||
last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
|
||||
call.result.data = dict(
|
||||
metrics=metrics,
|
||||
@@ -309,11 +333,13 @@ def get_task_latest_scalar_values(call, company_id, req_model):
|
||||
request_data_model=ScalarMetricsIterHistogramRequest,
|
||||
)
|
||||
def scalar_metrics_iter_histogram(
|
||||
call, company_id, req_model: ScalarMetricsIterHistogramRequest
|
||||
call, company_id, request: ScalarMetricsIterHistogramRequest
|
||||
):
|
||||
task_bll.assert_exists(call.identity.company, req_model.task, allow_public=True)
|
||||
task = task_bll.assert_exists(
|
||||
company_id, request.task, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
|
||||
company_id, task_id=req_model.task, samples=req_model.samples, key=req_model.key
|
||||
task.company, task_id=request.task, samples=request.samples, key=request.key
|
||||
)
|
||||
call.result.data = metrics
|
||||
|
||||
@@ -341,21 +367,27 @@ def multi_task_scalar_metrics_iter_histogram(
|
||||
|
||||
|
||||
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
|
||||
def get_multi_task_plots_v1_7(call, company_id, req_model):
|
||||
def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
task_ids = call.data["tasks"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=call.identity.company,
|
||||
only=("id", "name"),
|
||||
company_id=company_id,
|
||||
only=("id", "name", "company"),
|
||||
task_ids=task_ids,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
companies = {t.company for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
||||
result = event_bll.get_task_events(
|
||||
company_id,
|
||||
next(iter(companies)),
|
||||
task_ids,
|
||||
event_type="plot",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
@@ -385,13 +417,19 @@ def get_multi_task_plots(call, company_id, req_model):
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=call.identity.company,
|
||||
only=("id", "name"),
|
||||
only=("id", "name", "company"),
|
||||
task_ids=task_ids,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
companies = {t.company for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
result = event_bll.get_task_events(
|
||||
company_id,
|
||||
next(iter(companies)),
|
||||
task_ids,
|
||||
event_type="plot",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
@@ -414,12 +452,14 @@ def get_multi_task_plots(call, company_id, req_model):
|
||||
|
||||
|
||||
@endpoint("events.get_task_plots", required_fields=["task"])
|
||||
def get_task_plots_v1_7(call, company_id, req_model):
|
||||
def get_task_plots_v1_7(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
# events, next_scroll_id, total_events = event_bll.get_task_events(
|
||||
# company, task_id,
|
||||
# event_type="plot",
|
||||
@@ -429,7 +469,7 @@ def get_task_plots_v1_7(call, company_id, req_model):
|
||||
|
||||
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
||||
result = event_bll.get_task_events(
|
||||
company_id,
|
||||
task.company,
|
||||
task_id,
|
||||
event_type="plot",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
@@ -448,14 +488,16 @@ def get_task_plots_v1_7(call, company_id, req_model):
|
||||
|
||||
|
||||
@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"])
|
||||
def get_task_plots(call, company_id, req_model):
|
||||
def get_task_plots(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
result = event_bll.get_task_plots(
|
||||
company_id,
|
||||
task.company,
|
||||
tasks=[task_id],
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iterations_per_plot=iters,
|
||||
@@ -473,12 +515,14 @@ def get_task_plots(call, company_id, req_model):
|
||||
|
||||
|
||||
@endpoint("events.debug_images", required_fields=["task"])
|
||||
def get_debug_images_v1_7(call, company_id, req_model):
|
||||
def get_debug_images_v1_7(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters") or 1
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
# events, next_scroll_id, total_events = event_bll.get_task_events(
|
||||
# company, task_id,
|
||||
# event_type="training_debug_image",
|
||||
@@ -488,7 +532,7 @@ def get_debug_images_v1_7(call, company_id, req_model):
|
||||
|
||||
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
||||
result = event_bll.get_task_events(
|
||||
company_id,
|
||||
task.company,
|
||||
task_id,
|
||||
event_type="training_debug_image",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
@@ -508,14 +552,16 @@ def get_debug_images_v1_7(call, company_id, req_model):
|
||||
|
||||
|
||||
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
|
||||
def get_debug_images_v1_8(call, company_id, req_model):
|
||||
def get_debug_images_v1_8(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters") or 1
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
result = event_bll.get_task_events(
|
||||
company_id,
|
||||
task.company,
|
||||
task_id,
|
||||
event_type="training_debug_image",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
@@ -540,16 +586,25 @@ def get_debug_images_v1_8(call, company_id, req_model):
|
||||
request_data_model=DebugImagesRequest,
|
||||
response_data_model=DebugImageResponse,
|
||||
)
|
||||
def get_debug_images(call, company_id, req_model: DebugImagesRequest):
|
||||
tasks = set(m.task for m in req_model.metrics)
|
||||
task_bll.assert_exists(call.identity.company, task_ids=tasks, allow_public=True)
|
||||
def get_debug_images(call, company_id, request: DebugImagesRequest):
|
||||
task_ids = {m.task for m in request.metrics}
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id, task_ids=task_ids, allow_public=True, only=("company",)
|
||||
)
|
||||
|
||||
companies = {t.company for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
result = event_bll.debug_images_iterator.get_task_events(
|
||||
company_id=company_id,
|
||||
metrics=[(m.task, m.metric) for m in req_model.metrics],
|
||||
iter_count=req_model.iters,
|
||||
navigate_earlier=req_model.navigate_earlier,
|
||||
refresh=req_model.refresh,
|
||||
state_id=req_model.scroll_id,
|
||||
company_id=next(iter(companies)),
|
||||
metrics=[(m.task, m.metric) for m in request.metrics],
|
||||
iter_count=request.iters,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
refresh=request.refresh,
|
||||
state_id=request.scroll_id,
|
||||
)
|
||||
|
||||
call.result.data_model = DebugImageResponse(
|
||||
@@ -569,12 +624,12 @@ def get_debug_images(call, company_id, req_model: DebugImagesRequest):
|
||||
|
||||
|
||||
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
|
||||
def get_tasks_metrics(call: APICall, company_id, req_model: TaskMetricsRequest):
|
||||
task_bll.assert_exists(
|
||||
call.identity.company, task_ids=req_model.tasks, allow_public=True
|
||||
)
|
||||
def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_ids=request.tasks, allow_public=True, only=("company",)
|
||||
)[0]
|
||||
res = event_bll.metrics.get_tasks_metrics(
|
||||
company_id, task_ids=req_model.tasks, event_type=req_model.event_type
|
||||
task.company, task_ids=request.tasks, event_type=request.event_type
|
||||
)
|
||||
call.result.data = {
|
||||
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
|
||||
@@ -586,7 +641,7 @@ def delete_for_task(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
allow_locked = call.data.get("allow_locked", False)
|
||||
|
||||
task_bll.assert_exists(company_id, task_id)
|
||||
task_bll.assert_exists(company_id, task_id, return_tasks=False)
|
||||
call.result.data = dict(
|
||||
deleted=event_bll.delete_task_events(
|
||||
company_id, task_id, allow_locked=allow_locked
|
||||
|
||||
@@ -5,14 +5,17 @@ from mongoengine import Q, EmbeddedDocument
|
||||
|
||||
import database
|
||||
from apierrors import errors
|
||||
from apimodels.base import UpdateResponse
|
||||
from apierrors.errors.bad_request import InvalidModelId
|
||||
from apimodels.base import UpdateResponse, MakePublicRequest
|
||||
from apimodels.models import (
|
||||
CreateModelRequest,
|
||||
CreateModelResponse,
|
||||
PublishModelRequest,
|
||||
PublishModelResponse,
|
||||
ModelTaskPublishResponse,
|
||||
GetFrameworksRequest,
|
||||
)
|
||||
from bll.model import ModelBLL
|
||||
from bll.organization import OrgBLL, Tags
|
||||
from bll.task import TaskBLL
|
||||
from config import config
|
||||
@@ -32,6 +35,7 @@ from timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
model_bll = ModelBLL()
|
||||
|
||||
|
||||
@endpoint("models.get_by_id", required_fields=["model"])
|
||||
@@ -107,6 +111,15 @@ def get_all(call: APICall, company_id, _):
|
||||
call.result.data = {"models": models}
|
||||
|
||||
|
||||
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
|
||||
def get_frameworks(call: APICall, company_id, request: GetFrameworksRequest):
|
||||
call.result.data = {
|
||||
"frameworks": sorted(
|
||||
model_bll.get_frameworks(company_id, project_ids=request.projects)
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
create_fields = {
|
||||
"name": None,
|
||||
"tags": list,
|
||||
@@ -455,3 +468,21 @@ def update(call: APICall, company_id, _):
|
||||
if del_count:
|
||||
_reset_cached_tags(company_id, projects=[model.project])
|
||||
call.result.data = dict(deleted=del_count > 0)
|
||||
|
||||
|
||||
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = Model.set_public(
|
||||
company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"models.make_private", min_version="2.9", request_data_model=MakePublicRequest
|
||||
)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = Model.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
|
||||
)
|
||||
|
||||
@@ -8,10 +8,10 @@ from mongoengine import Q
|
||||
|
||||
import database
|
||||
from apierrors import errors
|
||||
from apimodels.base import UpdateResponse
|
||||
from apierrors.errors.bad_request import InvalidProjectId
|
||||
from apimodels.base import UpdateResponse, MakePublicRequest
|
||||
from apimodels.projects import (
|
||||
GetHyperParamReq,
|
||||
GetHyperParamResp,
|
||||
ProjectReq,
|
||||
ProjectTagsRequest,
|
||||
)
|
||||
@@ -185,6 +185,7 @@ def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None
|
||||
def get_all_ex(call: APICall):
|
||||
include_stats = call.data.get("include_stats")
|
||||
stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
|
||||
allow_public = not call.data.get("non_public", False)
|
||||
|
||||
if stats_for_state:
|
||||
try:
|
||||
@@ -200,7 +201,7 @@ def get_all_ex(call: APICall):
|
||||
company=call.identity.company,
|
||||
query_dict=call.data,
|
||||
query_options=get_all_query_options,
|
||||
allow_public=True,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
conform_output_tags(call, projects)
|
||||
|
||||
@@ -375,13 +376,12 @@ def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectR
|
||||
|
||||
@endpoint(
|
||||
"projects.get_hyper_parameters",
|
||||
min_version="2.2",
|
||||
min_version="2.9",
|
||||
request_data_model=GetHyperParamReq,
|
||||
response_data_model=GetHyperParamResp,
|
||||
)
|
||||
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):
|
||||
|
||||
total, remaining, parameters = TaskBLL.get_aggregated_project_execution_parameters(
|
||||
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids=[request.project] if request.project else None,
|
||||
page=request.page,
|
||||
@@ -421,3 +421,23 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
|
||||
projects=request.projects,
|
||||
)
|
||||
call.result.data = get_tags_response(ret)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.make_public", min_version="2.9", request_data_model=MakePublicRequest
|
||||
)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = Project.set_public(
|
||||
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=True
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.make_private", min_version="2.9", request_data_model=MakePublicRequest
|
||||
)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = Project.set_public(
|
||||
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
|
||||
)
|
||||
|
||||
@@ -11,7 +11,8 @@ from mongoengine.queryset.transform import COMPARISON_OPERATORS
|
||||
from pymongo import UpdateOne
|
||||
|
||||
from apierrors import errors, APIError
|
||||
from apimodels.base import UpdateResponse, IdResponse
|
||||
from apierrors.errors.bad_request import InvalidTaskId
|
||||
from apimodels.base import UpdateResponse, IdResponse, MakePublicRequest
|
||||
from apimodels.tasks import (
|
||||
StartedResponse,
|
||||
ResetResponse,
|
||||
@@ -31,6 +32,13 @@ from apimodels.tasks import (
|
||||
AddOrUpdateArtifactsResponse,
|
||||
GetTypesRequest,
|
||||
ResetRequest,
|
||||
GetHyperParamsRequest,
|
||||
EditHyperParamsRequest,
|
||||
DeleteHyperParamsRequest,
|
||||
GetConfigurationsRequest,
|
||||
EditConfigurationRequest,
|
||||
DeleteConfigurationRequest,
|
||||
GetConfigurationNamesRequest,
|
||||
)
|
||||
from bll.event import EventBLL
|
||||
from bll.organization import OrgBLL, Tags
|
||||
@@ -40,9 +48,14 @@ from bll.task import (
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
split_by,
|
||||
ParameterKeyEscaper,
|
||||
)
|
||||
from bll.task.hyperparams import HyperParams
|
||||
from bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
|
||||
from bll.task.param_utils import (
|
||||
params_prepare_for_save,
|
||||
params_unprepare_from_saved,
|
||||
escape_paths,
|
||||
)
|
||||
from bll.util import SetFieldsResolver
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.model import Model
|
||||
@@ -56,9 +69,9 @@ from database.model.task.task import (
|
||||
)
|
||||
from database.utils import get_fields, parse_from_call
|
||||
from service_repo import APICall, endpoint
|
||||
from services.utils import conform_tag_fields, conform_output_tags
|
||||
from service_repo.base import PartialVersion
|
||||
from services.utils import conform_tag_fields, conform_output_tags, validate_tags
|
||||
from timing_context import TimingContext
|
||||
from utilities import safe_get
|
||||
|
||||
task_fields = set(Task.get_fields())
|
||||
task_script_fields = set(get_fields(Script))
|
||||
@@ -78,10 +91,24 @@ def set_task_status_from_call(
|
||||
task = TaskBLL.get_task_with_access(
|
||||
request.task,
|
||||
company_id=company_id,
|
||||
only=tuple({"status", "project"} | fields_resolver.get_names()),
|
||||
only=tuple(
|
||||
{"status", "project", "started", "duration"} | fields_resolver.get_names()
|
||||
),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
if "duration" not in fields_resolver.get_names():
|
||||
if new_status == Task.started:
|
||||
fields_resolver.add_fields(min__duration=max(0, task.duration or 0))
|
||||
elif new_status in (
|
||||
TaskStatus.completed,
|
||||
TaskStatus.failed,
|
||||
TaskStatus.stopped,
|
||||
):
|
||||
fields_resolver.add_fields(
|
||||
duration=int((task.started - datetime.utcnow()).total_seconds())
|
||||
)
|
||||
|
||||
status_reason = request.status_reason
|
||||
status_message = request.status_message
|
||||
force = request.force
|
||||
@@ -105,30 +132,13 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
|
||||
|
||||
|
||||
def escape_execution_parameters(call: APICall):
|
||||
default_prefix = "execution.parameters."
|
||||
|
||||
def escape_paths(paths, prefix=default_prefix):
|
||||
escaped_paths = []
|
||||
for path in paths:
|
||||
if path == prefix:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"invalid task field", path=path
|
||||
)
|
||||
escaped_paths.append(
|
||||
prefix + ParameterKeyEscaper.escape(path[len(prefix) :])
|
||||
if path.startswith(prefix)
|
||||
else path
|
||||
)
|
||||
return escaped_paths
|
||||
|
||||
projection = Task.get_projection(call.data)
|
||||
if projection:
|
||||
Task.set_projection(call.data, escape_paths(projection))
|
||||
|
||||
ordering = Task.get_ordering(call.data)
|
||||
if ordering:
|
||||
ordering = Task.set_ordering(call.data, escape_paths(ordering, default_prefix))
|
||||
Task.set_ordering(call.data, escape_paths(ordering, "-" + default_prefix))
|
||||
Task.set_ordering(call.data, escape_paths(ordering))
|
||||
|
||||
|
||||
@endpoint("tasks.get_all_ex", required_fields=[])
|
||||
@@ -260,12 +270,15 @@ create_fields = {
|
||||
"input": None,
|
||||
"output_dest": None,
|
||||
"execution": None,
|
||||
"hyperparams": None,
|
||||
"configuration": None,
|
||||
"script": None,
|
||||
}
|
||||
|
||||
|
||||
def prepare_for_save(call: APICall, fields: dict):
|
||||
def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
params_prepare_for_save(fields, previous_task=previous_task)
|
||||
|
||||
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
|
||||
for field in task_script_fields:
|
||||
@@ -278,12 +291,6 @@ def prepare_for_save(call: APICall, fields: dict):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
parameters = safe_get(fields, "execution/parameters")
|
||||
if parameters is not None:
|
||||
# Escape keys to make them mongo-safe
|
||||
parameters = {ParameterKeyEscaper.escape(k): v for k, v in parameters.items()}
|
||||
dpath.set(fields, "execution/parameters", parameters)
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
@@ -293,18 +300,15 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict])
|
||||
|
||||
conform_output_tags(call, tasks_data)
|
||||
|
||||
for task_data in tasks_data:
|
||||
parameters = safe_get(task_data, "execution/parameters")
|
||||
if parameters is not None:
|
||||
# Escape keys to make them mongo-safe
|
||||
parameters = {
|
||||
ParameterKeyEscaper.unescape(k): v for k, v in parameters.items()
|
||||
}
|
||||
dpath.set(task_data, "execution/parameters", parameters)
|
||||
for data in tasks_data:
|
||||
params_unprepare_from_saved(
|
||||
fields=data,
|
||||
copy_to_legacy=call.requested_endpoint_version < PartialVersion("2.9"),
|
||||
)
|
||||
|
||||
|
||||
def prepare_create_fields(
|
||||
call: APICall, valid_fields=None, output=None, previous_task: Task = None
|
||||
call: APICall, valid_fields=None, output=None, previous_task: Task = None,
|
||||
):
|
||||
valid_fields = valid_fields if valid_fields is not None else create_fields
|
||||
t_fields = task_fields
|
||||
@@ -322,7 +326,7 @@ def prepare_create_fields(
|
||||
output = Output(destination=output_dest)
|
||||
fields["output"] = output
|
||||
|
||||
return prepare_for_save(call, fields)
|
||||
return prepare_for_save(call, fields, previous_task=previous_task)
|
||||
|
||||
|
||||
def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
|
||||
@@ -354,9 +358,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
|
||||
|
||||
|
||||
def _reset_cached_tags(company: str, projects: Sequence[str]):
|
||||
org_bll.reset_tags(
|
||||
company, Tags.Task, projects=projects
|
||||
)
|
||||
org_bll.reset_tags(company, Tags.Task, projects=projects)
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -377,6 +379,7 @@ def create(call: APICall, company_id, req_model: CreateRequest):
|
||||
"tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse
|
||||
)
|
||||
def clone_task(call: APICall, company_id, request: CloneRequest):
|
||||
validate_tags(request.new_task_tags, request.new_task_system_tags)
|
||||
task = task_bll.clone_task(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
@@ -387,6 +390,8 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
|
||||
project=request.new_task_project,
|
||||
tags=request.new_task_tags,
|
||||
system_tags=request.new_task_system_tags,
|
||||
hyperparams=request.new_hyperparams,
|
||||
configuration=request.new_configuration,
|
||||
execution_overrides=request.execution_overrides,
|
||||
validate_references=request.validate_references,
|
||||
)
|
||||
@@ -572,9 +577,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
if updated:
|
||||
new_project = fixed_fields.get("project", task.project)
|
||||
if new_project != task.project:
|
||||
_reset_cached_tags(
|
||||
company_id, projects=[new_project, task.project]
|
||||
)
|
||||
_reset_cached_tags(company_id, projects=[new_project, task.project])
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=task.project, fields=fixed_fields
|
||||
@@ -586,6 +589,100 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
call.result.data_model = UpdateResponse(updated=0)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
|
||||
)
|
||||
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
|
||||
with translate_errors_context():
|
||||
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
|
||||
|
||||
call.result.data = {
|
||||
"params": [{"task": task, **data} for task, data in tasks_params.items()]
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.edit_hyper_params", request_data_model=EditHyperParamsRequest)
|
||||
def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_params(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
replace_hyperparams=request.replace_hyperparams,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.delete_hyper_params", request_data_model=DeleteHyperParamsRequest)
|
||||
def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_params(
|
||||
company_id, task_id=request.task, hyperparams=request.hyperparams
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
|
||||
)
|
||||
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
|
||||
with translate_errors_context():
|
||||
tasks_params = HyperParams.get_configurations(
|
||||
company_id, task_ids=request.tasks, names=request.names
|
||||
)
|
||||
|
||||
call.result.data = {
|
||||
"configurations": [
|
||||
{"task": task, **data} for task, data in tasks_params.items()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_configuration_names", request_data_model=GetConfigurationNamesRequest,
|
||||
)
|
||||
def get_configuration_names(
|
||||
call: APICall, company_id, request: GetConfigurationNamesRequest
|
||||
):
|
||||
with translate_errors_context():
|
||||
tasks_params = HyperParams.get_configuration_names(
|
||||
company_id, task_ids=request.tasks
|
||||
)
|
||||
|
||||
call.result.data = {
|
||||
"configurations": [
|
||||
{"task": task, **data} for task, data in tasks_params.items()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.edit_configuration", request_data_model=EditConfigurationRequest)
|
||||
def edit_configuration(call: APICall, company_id, request: EditConfigurationRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_configuration(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
replace_configuration=request.replace_configuration,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.delete_configuration", request_data_model=DeleteConfigurationRequest)
|
||||
def delete_configuration(
|
||||
call: APICall, company_id, request: DeleteConfigurationRequest
|
||||
):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_configuration(
|
||||
company_id, task_id=request.task, configuration=request.configuration
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.enqueue",
|
||||
request_data_model=EnqueueRequest,
|
||||
@@ -1004,3 +1101,19 @@ def add_or_update_artifacts(
|
||||
task_id=request.task, company_id=company_id, artifacts=request.artifacts
|
||||
)
|
||||
call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated)
|
||||
|
||||
|
||||
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
|
||||
)
|
||||
|
||||
|
||||
@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
|
||||
)
|
||||
|
||||
@@ -46,10 +46,10 @@ def get_all(call: APICall, company_id: str, request: GetAllRequest):
|
||||
|
||||
|
||||
@endpoint("workers.register", min_version="2.4", request_data_model=RegisterRequest)
|
||||
def register(call: APICall, company_id, req_model: RegisterRequest):
|
||||
worker = req_model.worker
|
||||
timeout = req_model.timeout
|
||||
queues = req_model.queues
|
||||
def register(call: APICall, company_id, request: RegisterRequest):
|
||||
worker = request.worker
|
||||
timeout = request.timeout
|
||||
queues = request.queues
|
||||
|
||||
if not timeout or timeout <= 0:
|
||||
raise bad_request.WorkerRegistrationFailed(
|
||||
@@ -63,6 +63,7 @@ def register(call: APICall, company_id, req_model: RegisterRequest):
|
||||
ip=call.real_ip,
|
||||
queues=queues,
|
||||
timeout=timeout,
|
||||
tags=request.tags,
|
||||
)
|
||||
|
||||
|
||||
@@ -78,6 +79,7 @@ def status_report(call: APICall, company_id, request: StatusReportRequest):
|
||||
user_id=call.identity.user,
|
||||
ip=call.real_ip,
|
||||
report=request,
|
||||
tags=request.tags,
|
||||
)
|
||||
|
||||
|
||||
|
||||
28
server/sync.py
Normal file
28
server/sync.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from time import sleep
|
||||
|
||||
from redis_manager import redman
|
||||
|
||||
_redis = redman.connection("apiserver")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def distributed_lock(name: str, timeout: int, max_wait: int = 0):
|
||||
"""
|
||||
Context manager that acquires a distributed lock on enter
|
||||
and releases it on exit. The has a ttl equal to timeout seconds
|
||||
If the lock can not be acquired for wait seconds (defaults to timeout * 2)
|
||||
then the exception is thrown
|
||||
"""
|
||||
lock_name = f"dist_lock_{name}"
|
||||
start = time.time()
|
||||
max_wait = max_wait or timeout * 2
|
||||
while not _redis.set(lock_name, value="", ex=timeout, nx=True):
|
||||
sleep(1)
|
||||
if time.time() - start > max_wait:
|
||||
raise Exception(f"Could not acquire {name} lock for {max_wait} seconds")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_redis.delete(lock_name)
|
||||
@@ -1,3 +1,4 @@
|
||||
from apierrors.errors.bad_request import InvalidModelId
|
||||
from tests.automated import TestService
|
||||
|
||||
MODEL_CANNOT_BE_UPDATED_CODES = (400, 203)
|
||||
@@ -7,6 +8,9 @@ IN_PROGRESS = "in_progress"
|
||||
|
||||
|
||||
class TestModelsService(TestService):
|
||||
def setUp(self, version="2.9"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def test_publish_output_model_running_task(self):
|
||||
task_id, model_id = self._create_task_and_model()
|
||||
self._assert_model_ready(model_id, False)
|
||||
@@ -164,6 +168,58 @@ class TestModelsService(TestService):
|
||||
1000
|
||||
)
|
||||
|
||||
def test_get_frameworks(self):
|
||||
framework_1 = "Test framework 1"
|
||||
framework_2 = "Test framework 2"
|
||||
|
||||
# create model on top level
|
||||
self._create_model(name="framework model test", framework=framework_1)
|
||||
|
||||
# create model under a project as make it inherit its framework from the task
|
||||
project = self.create_temp("projects", name="Frameworks test", description="")
|
||||
task = self._create_task(project=project, execution=dict(framework=framework_2))
|
||||
self.api.models.update_for_task(
|
||||
task=task,
|
||||
name="framework output model test",
|
||||
uri="file:///b",
|
||||
iteration=999,
|
||||
)
|
||||
|
||||
# get all frameworks
|
||||
res = self.api.models.get_frameworks()
|
||||
self.assertTrue({framework_1, framework_2}.issubset(set(res.frameworks)))
|
||||
|
||||
# get frameworks under the project
|
||||
res = self.api.models.get_frameworks(projects=[project])
|
||||
self.assertEqual([framework_2], res.frameworks)
|
||||
|
||||
# empty result
|
||||
self.api.tasks.delete(task=task, force=True)
|
||||
res = self.api.models.get_frameworks(projects=[project])
|
||||
self.assertEqual([], res.frameworks)
|
||||
|
||||
def test_make_public(self):
|
||||
m1 = self._create_model(name="public model test")
|
||||
|
||||
# model with company_origin not set to the current company cannot be converted to private
|
||||
with self.api.raises(InvalidModelId):
|
||||
self.api.models.make_private(ids=[m1])
|
||||
|
||||
# public model can be retrieved but not updated
|
||||
res = self.api.models.make_public(ids=[m1])
|
||||
self.assertEqual(res.updated, 1)
|
||||
res = self.api.models.get_all(id=[m1])
|
||||
self.assertEqual([m.id for m in res.models], [m1])
|
||||
with self.api.raises(InvalidModelId):
|
||||
self.api.models.update(model=m1, name="public model test change 1")
|
||||
|
||||
# task made private again and can be both retrieved and updated
|
||||
res = self.api.models.make_private(ids=[m1])
|
||||
self.assertEqual(res.updated, 1)
|
||||
res = self.api.models.get_all(id=[m1])
|
||||
self.assertEqual([m.id for m in res.models], [m1])
|
||||
self.api.models.update(model=m1, name="public model test change 2")
|
||||
|
||||
def _assert_task_status(self, task_id, status):
|
||||
task = self.api.tasks.get_by_id(task=task_id).task
|
||||
assert task.status == status
|
||||
@@ -178,24 +234,23 @@ class TestModelsService(TestService):
|
||||
def _assert_update_task_failure(self):
|
||||
return self.api.raises(TASK_CANNOT_BE_UPDATED_CODES)
|
||||
|
||||
def _create_model(self):
|
||||
model_id = self.create_temp(
|
||||
def _create_model(self, **kwargs):
|
||||
return self.create_temp(
|
||||
service="models",
|
||||
name='test',
|
||||
uri='file:///a',
|
||||
labels={}
|
||||
delete_params=dict(can_fail=True, force=True),
|
||||
name=kwargs.pop("name", 'test'),
|
||||
uri=kwargs.pop("name", 'file:///a'),
|
||||
labels=kwargs.pop("labels", {}),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.defer(self.api.models.delete, can_fail=True, model=model_id, force=True)
|
||||
|
||||
return model_id
|
||||
|
||||
def _create_task(self):
|
||||
def _create_task(self, **kwargs):
|
||||
task_id = self.create_temp(
|
||||
service="tasks",
|
||||
type='testing',
|
||||
name='server-test',
|
||||
input=dict(view={}),
|
||||
type=kwargs.pop("type", 'testing'),
|
||||
name=kwargs.pop("name", 'server-test'),
|
||||
input=kwargs.pop("input", dict(view={})),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return task_id
|
||||
|
||||
@@ -6,14 +6,86 @@ log = config.logger(__file__)
|
||||
|
||||
|
||||
class TestProjection(TestService):
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(version="2.6")
|
||||
|
||||
def _temp_task(self, **kwargs):
|
||||
self.update_missing(
|
||||
kwargs,
|
||||
type="testing",
|
||||
name="test projection",
|
||||
input=dict(view=dict()),
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
return self.create_temp("tasks", **kwargs)
|
||||
|
||||
def _temp_project(self):
|
||||
return self.create_temp(
|
||||
"projects",
|
||||
name="Test projection",
|
||||
description="test",
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
|
||||
def test_overlapping_fields(self):
|
||||
message = "task started"
|
||||
task_id = self.create_temp(
|
||||
"tasks", name="test", type="testing", input=dict(view=dict())
|
||||
)
|
||||
task_id = self._temp_task()
|
||||
self.api.tasks.started(task=task_id, status_message=message)
|
||||
task = self.api.tasks.get_all_ex(
|
||||
id=[task_id], only_fields=["status", "status_message"]
|
||||
).tasks[0]
|
||||
assert task["status"] == TaskStatus.in_progress
|
||||
assert task["status_message"] == message
|
||||
|
||||
def test_task_projection(self):
|
||||
project = self._temp_project()
|
||||
task1 = self._temp_task(project=project)
|
||||
task2 = self._temp_task(project=project)
|
||||
self.api.tasks.started(task=task2, status_message="Started")
|
||||
|
||||
res = self.api.tasks.get_all_ex(
|
||||
project=[project],
|
||||
only_fields=[
|
||||
"system_tags",
|
||||
"company",
|
||||
"type",
|
||||
"name",
|
||||
"tags",
|
||||
"status",
|
||||
"project.name",
|
||||
"user.name",
|
||||
"started",
|
||||
"last_update",
|
||||
"last_iteration",
|
||||
"comment",
|
||||
],
|
||||
order_by=["-started"],
|
||||
page=0,
|
||||
page_size=15,
|
||||
system_tags=["-archived"],
|
||||
type=[
|
||||
"__$not",
|
||||
"annotation_manual",
|
||||
"__$not",
|
||||
"annotation",
|
||||
"__$not",
|
||||
"dataset_import",
|
||||
],
|
||||
).tasks
|
||||
self.assertEqual([task2, task1], [t.id for t in res])
|
||||
self.assertEqual("Test projection", res[0].project.name)
|
||||
|
||||
def test_exclude_projection(self):
|
||||
task_id = self._temp_task()
|
||||
|
||||
res = self.api.tasks.get_all_ex(
|
||||
id=[task_id]
|
||||
).tasks[0]
|
||||
self.assertEqual("test projection", res.name)
|
||||
|
||||
task = self.api.tasks.get_all_ex(
|
||||
id=[task_id],
|
||||
only_fields=["-name"]
|
||||
).tasks[0]
|
||||
self.assertFalse("name" in task)
|
||||
self.assertEqual("testing", res.type)
|
||||
|
||||
34
server/tests/automated/test_projects_edit.py
Normal file
34
server/tests/automated/test_projects_edit.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from apierrors.errors.bad_request import InvalidProjectId
|
||||
from apierrors.errors.forbidden import NoWritePermission
|
||||
from config import config
|
||||
from tests.automated import TestService
|
||||
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class TestProjectsEdit(TestService):
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(version="2.9")
|
||||
|
||||
def test_make_public(self):
|
||||
p1 = self.create_temp("projects", name="Test public", description="test")
|
||||
|
||||
# project with company_origin not set to the current company cannot be converted to private
|
||||
with self.api.raises(InvalidProjectId):
|
||||
self.api.projects.make_private(ids=[p1])
|
||||
|
||||
# public project can be retrieved but not updated
|
||||
res = self.api.projects.make_public(ids=[p1])
|
||||
self.assertEqual(res.updated, 1)
|
||||
res = self.api.projects.get_all(id=[p1])
|
||||
self.assertEqual([p.id for p in res.projects], [p1])
|
||||
with self.api.raises(NoWritePermission):
|
||||
self.api.projects.update(project=p1, name="Test public change 1")
|
||||
|
||||
# task made private again and can be both retrieved and updated
|
||||
res = self.api.projects.make_private(ids=[p1])
|
||||
self.assertEqual(res.updated, 1)
|
||||
res = self.api.projects.get_all(id=[p1])
|
||||
self.assertEqual([p.id for p in res.projects], [p1])
|
||||
self.api.projects.update(project=p1, name="Test public change 2")
|
||||
@@ -6,7 +6,7 @@ import operator
|
||||
import unittest
|
||||
from functools import partial
|
||||
from statistics import mean
|
||||
from typing import Sequence
|
||||
from typing import Sequence, Optional, Tuple
|
||||
|
||||
from boltons.iterutils import first
|
||||
|
||||
@@ -16,7 +16,7 @@ from tests.automated import TestService
|
||||
|
||||
|
||||
class TestTaskEvents(TestService):
|
||||
def setUp(self, version="2.7"):
|
||||
def setUp(self, version="2.9"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def _temp_task(self, name="test task events"):
|
||||
@@ -213,7 +213,6 @@ class TestTaskEvents(TestService):
|
||||
self.assertEqual(len(res.events), 1)
|
||||
|
||||
def test_task_logs(self):
|
||||
# this test will fail until the new api is uncommented
|
||||
task = self._temp_task()
|
||||
timestamp = es_factory.get_timestamp_millis()
|
||||
events = [
|
||||
@@ -229,32 +228,29 @@ class TestTaskEvents(TestService):
|
||||
self.send_batch(events)
|
||||
|
||||
# test forward navigation
|
||||
scroll_id = None
|
||||
for page in range(3):
|
||||
scroll_id = self._assert_log_events(
|
||||
task=task, scroll_id=scroll_id, expected_page=page
|
||||
ftime, ltime = None, None
|
||||
for page in range(2):
|
||||
ftime, ltime = self._assert_log_events(
|
||||
task=task, timestamp=ltime, expected_page=page
|
||||
)
|
||||
|
||||
# test backwards navigation
|
||||
scroll_id = self._assert_log_events(
|
||||
task=task, scroll_id=scroll_id, navigate_earlier=False
|
||||
)
|
||||
self._assert_log_events(task=task, timestamp=ftime, navigate_earlier=False)
|
||||
|
||||
# refresh
|
||||
self._assert_log_events(task=task, scroll_id=scroll_id)
|
||||
self._assert_log_events(task=task, scroll_id=scroll_id, refresh=True)
|
||||
# test order
|
||||
self._assert_log_events(task=task, order="asc")
|
||||
|
||||
def _assert_log_events(
|
||||
self,
|
||||
task,
|
||||
scroll_id,
|
||||
batch_size: int = 5,
|
||||
timestamp: Optional[int] = None,
|
||||
expected_total: int = 10,
|
||||
expected_page: int = 0,
|
||||
**extra_params,
|
||||
):
|
||||
) -> Tuple[int, int]:
|
||||
res = self.api.events.get_task_log(
|
||||
task=task, batch_size=batch_size, scroll_id=scroll_id, **extra_params,
|
||||
task=task, batch_size=batch_size, from_timestamp=timestamp, **extra_params,
|
||||
)
|
||||
self.assertEqual(res.total, expected_total)
|
||||
expected_events = max(
|
||||
@@ -266,7 +262,10 @@ class TestTaskEvents(TestService):
|
||||
self.assertEqual(len(res.events), unique_events)
|
||||
if res.events:
|
||||
cmp_operator = operator.ge
|
||||
if not extra_params.get("navigate_earlier", True):
|
||||
if (
|
||||
not extra_params.get("navigate_earlier", True)
|
||||
or extra_params.get("order", None) == "asc"
|
||||
):
|
||||
cmp_operator = operator.le
|
||||
self.assertTrue(
|
||||
all(
|
||||
@@ -274,7 +273,12 @@ class TestTaskEvents(TestService):
|
||||
for first, second in zip(res.events, res.events[1:])
|
||||
)
|
||||
)
|
||||
return res.scroll_id
|
||||
|
||||
return (
|
||||
(res.events[0].timestamp, res.events[-1].timestamp)
|
||||
if res.events
|
||||
else (None, None)
|
||||
)
|
||||
|
||||
def test_task_metric_value_intervals_keys(self):
|
||||
metric = "Metric1"
|
||||
|
||||
281
server/tests/automated/test_task_hyperparams.py
Normal file
281
server/tests/automated/test_task_hyperparams.py
Normal file
@@ -0,0 +1,281 @@
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, List, Tuple
|
||||
|
||||
from boltons import iterutils
|
||||
|
||||
from apierrors.errors.bad_request import InvalidTaskStatus
|
||||
from tests.api_client import APIClient
|
||||
from tests.automated import TestService
|
||||
|
||||
|
||||
class TestTasksHyperparams(TestService):
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(version="2.9")
|
||||
|
||||
def new_task(self, **kwargs) -> Tuple[str, str]:
|
||||
if "project" not in kwargs:
|
||||
kwargs["project"] = self.create_temp(
|
||||
"projects",
|
||||
name="Test hyperparams",
|
||||
description="test",
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
self.update_missing(
|
||||
kwargs,
|
||||
type="testing",
|
||||
name="test hyperparams",
|
||||
input=dict(view=dict()),
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
return self.create_temp("tasks", **kwargs), kwargs["project"]
|
||||
|
||||
def test_hyperparams(self):
|
||||
legacy_params = {"legacy$1": "val1", "legacy2/name": "val2"}
|
||||
new_params = [
|
||||
dict(section="1/1", name="param1/1", type="type1", value="10"),
|
||||
dict(section="1/1", name="param2", type="type1", value="20"),
|
||||
dict(section="2", name="param2", type="type2", value="xxx"),
|
||||
]
|
||||
new_params_dict = self._param_dict_from_list(new_params)
|
||||
task, project = self.new_task(
|
||||
execution={"parameters": legacy_params}, hyperparams=new_params_dict,
|
||||
)
|
||||
# both params and hyper params are set correctly
|
||||
old_params = self._new_params_from_legacy(legacy_params)
|
||||
params_dict = new_params_dict.copy()
|
||||
params_dict["Args"] = {p["name"]: p for p in old_params}
|
||||
res = self.api.tasks.get_by_id(task=task).task
|
||||
self.assertEqual(params_dict, res.hyperparams)
|
||||
|
||||
# returned as one list with params in the _legacy section
|
||||
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
|
||||
self.assertEqual(new_params + old_params, res.hyperparams)
|
||||
|
||||
# replace section
|
||||
replace_params = [
|
||||
dict(section="1/1", name="param1", type="type1", value="40"),
|
||||
dict(section="2", name="param5", type="type1", value="11"),
|
||||
]
|
||||
self.api.tasks.edit_hyper_params(
|
||||
task=task, hyperparams=replace_params, replace_hyperparams="section"
|
||||
)
|
||||
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
|
||||
self.assertEqual(replace_params + old_params, res.hyperparams)
|
||||
|
||||
# replace all
|
||||
replace_params = [
|
||||
dict(section="1/1", name="param1/1", type="type1", value="30"),
|
||||
dict(section="Args", name="legacy$1", value="123", type="legacy"),
|
||||
]
|
||||
self.api.tasks.edit_hyper_params(
|
||||
task=task, hyperparams=replace_params, replace_hyperparams="all"
|
||||
)
|
||||
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
|
||||
self.assertEqual(replace_params, res.hyperparams)
|
||||
|
||||
# add and update
|
||||
self.api.tasks.edit_hyper_params(task=task, hyperparams=new_params + old_params)
|
||||
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
|
||||
self.assertEqual(new_params + old_params, res.hyperparams)
|
||||
|
||||
# delete
|
||||
new_to_delete = self._get_param_keys(new_params[1:])
|
||||
old_to_delete = self._get_param_keys(old_params[:1])
|
||||
self.api.tasks.delete_hyper_params(
|
||||
task=task, hyperparams=new_to_delete + old_to_delete
|
||||
)
|
||||
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
|
||||
self.assertEqual(new_params[:1] + old_params[1:], res.hyperparams)
|
||||
|
||||
# delete section
|
||||
self.api.tasks.delete_hyper_params(
|
||||
task=task, hyperparams=[{"section": "1/1"}, {"section": "2"}]
|
||||
)
|
||||
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
|
||||
self.assertEqual(old_params[1:], res.hyperparams)
|
||||
|
||||
# project hyperparams
|
||||
res = self.api.projects.get_hyper_parameters(project=project)
|
||||
self.assertEqual(
|
||||
[
|
||||
{k: v for k, v in p.items() if k in ("section", "name")}
|
||||
for p in old_params[1:]
|
||||
],
|
||||
res.parameters,
|
||||
)
|
||||
|
||||
# clone task
|
||||
new_task = self.api.tasks.clone(task=task, new_hyperparams=new_params_dict).id
|
||||
try:
|
||||
res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0]
|
||||
self.assertEqual(new_params, res.hyperparams)
|
||||
finally:
|
||||
self.api.tasks.delete(task=new_task, force=True)
|
||||
|
||||
# editing of started task
|
||||
self.api.tasks.started(task=task)
|
||||
with self.api.raises(InvalidTaskStatus):
|
||||
self.api.tasks.edit_hyper_params(
|
||||
task=task, hyperparams=[dict(section="test", name="x", value="123")]
|
||||
)
|
||||
self.api.tasks.edit_hyper_params(
|
||||
task=task, hyperparams=[dict(section="properties", name="x", value="123")]
|
||||
)
|
||||
self.api.tasks.delete_hyper_params(
|
||||
task=task, hyperparams=[dict(section="Properties")]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_param_keys(params: Sequence[dict]) -> List[dict]:
|
||||
return [{k: p[k] for k in ("name", "section")} for p in params]
|
||||
|
||||
@staticmethod
|
||||
def _new_params_from_legacy(legacy: dict) -> List[dict]:
|
||||
return [
|
||||
dict(section="Args", name=k, value=str(v), type="legacy")
|
||||
if not k.startswith("TF_DEFINE/")
|
||||
else dict(section="TF_DEFINE", name=k[len("TF_DEFINE/"):], value=str(v), type="legacy")
|
||||
for k, v in legacy.items()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _param_dict_from_list(params: Sequence[dict]) -> dict:
|
||||
return {
|
||||
k: {v["name"]: v for v in values}
|
||||
for k, values in iterutils.bucketize(
|
||||
params, key=itemgetter("section")
|
||||
).items()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _config_dict_from_list(config: Sequence[dict]) -> dict:
|
||||
return {c["name"]: c for c in config}
|
||||
|
||||
def test_configuration(self):
|
||||
legacy_config = {"design": "hello"}
|
||||
new_config = [
|
||||
dict(name="param$1", type="type1", value="10"),
|
||||
dict(name="param/2", type="type1", value="20"),
|
||||
]
|
||||
new_config_dict = self._config_dict_from_list(new_config)
|
||||
task, _ = self.new_task(
|
||||
execution={"model_desc": legacy_config}, configuration=new_config_dict
|
||||
)
|
||||
|
||||
# both params and hyper params are set correctly
|
||||
old_config = self._new_config_from_legacy(legacy_config)
|
||||
config_dict = new_config_dict.copy()
|
||||
config_dict["design"] = old_config[0]
|
||||
res = self.api.tasks.get_by_id(task=task).task
|
||||
self.assertEqual(config_dict, res.configuration)
|
||||
|
||||
# returned as one list
|
||||
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
|
||||
self.assertEqual(old_config + new_config, res.configuration)
|
||||
|
||||
# names
|
||||
res = self.api.tasks.get_configuration_names(tasks=[task]).configurations[0]
|
||||
self.assertEqual(task, res.task)
|
||||
self.assertEqual(["design", "param$1", "param/2"], res.names)
|
||||
|
||||
# returned as one list with names filtering
|
||||
res = self.api.tasks.get_configurations(
|
||||
tasks=[task], names=[new_config[1]["name"]]
|
||||
).configurations[0]
|
||||
self.assertEqual([new_config[1]], res.configuration)
|
||||
|
||||
# replace all
|
||||
replace_configs = [
|
||||
dict(name="design", value="123", type="legacy"),
|
||||
dict(name="param/2", type="type1", value="30"),
|
||||
]
|
||||
self.api.tasks.edit_configuration(
|
||||
task=task, configuration=replace_configs, replace_configuration=True
|
||||
)
|
||||
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
|
||||
self.assertEqual(replace_configs, res.configuration)
|
||||
|
||||
# add and update
|
||||
self.api.tasks.edit_configuration(
|
||||
task=task, configuration=new_config + old_config
|
||||
)
|
||||
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
|
||||
self.assertEqual(old_config + new_config, res.configuration)
|
||||
|
||||
# delete
|
||||
new_to_delete = self._get_config_keys(new_config[1:])
|
||||
res = self.api.tasks.delete_configuration(
|
||||
task=task, configuration=new_to_delete
|
||||
)
|
||||
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
|
||||
self.assertEqual(old_config + new_config[:1], res.configuration)
|
||||
|
||||
# clone task
|
||||
new_task = self.api.tasks.clone(task=task, new_configuration=new_config_dict).id
|
||||
try:
|
||||
res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0]
|
||||
self.assertEqual(new_config, res.configuration)
|
||||
finally:
|
||||
self.api.tasks.delete(task=new_task, force=True)
|
||||
|
||||
@staticmethod
|
||||
def _get_config_keys(config: Sequence[dict]) -> List[dict]:
|
||||
return [c["name"] for c in config]
|
||||
|
||||
@staticmethod
|
||||
def _new_config_from_legacy(legacy: dict) -> List[dict]:
|
||||
return [dict(name=k, value=str(v), type="legacy") for k, v in legacy.items()]
|
||||
|
||||
def test_hyperparams_projection(self):
|
||||
legacy_param = {"legacy.1": "val1"}
|
||||
new_params1 = [
|
||||
dict(section="sec.tion1", name="param1", type="type1", value="10")
|
||||
]
|
||||
new_params_dict1 = self._param_dict_from_list(new_params1)
|
||||
task1, project = self.new_task(
|
||||
execution={"parameters": legacy_param}, hyperparams=new_params_dict1,
|
||||
)
|
||||
|
||||
new_params2 = [
|
||||
dict(section="sec.tion1", name="param1", type="type1", value="20")
|
||||
]
|
||||
new_params_dict2 = self._param_dict_from_list(new_params2)
|
||||
task2, _ = self.new_task(hyperparams=new_params_dict2, project=project)
|
||||
|
||||
old_params = self._new_params_from_legacy(legacy_param)
|
||||
params_dict = new_params_dict1.copy()
|
||||
params_dict["Args"] = {p["name"]: p for p in old_params}
|
||||
res = self.api.tasks.get_all_ex(id=[task1], only_fields=["hyperparams"]).tasks[
|
||||
0
|
||||
]
|
||||
self.assertEqual(params_dict, res.hyperparams)
|
||||
|
||||
res = self.api.tasks.get_all_ex(
|
||||
project=[project],
|
||||
only_fields=["hyperparams.sec%2Etion1"],
|
||||
order_by=["-hyperparams.sec%2Etion1"],
|
||||
).tasks[0]
|
||||
self.assertEqual(new_params_dict2, res.hyperparams)
|
||||
|
||||
def test_old_api(self):
|
||||
legacy_params = {"legacy.1": "val1", "TF_DEFINE/param2": "val2"}
|
||||
legacy_config = {"design": "hello"}
|
||||
task_id, _ = self.new_task(
|
||||
execution={"parameters": legacy_params, "model_desc": legacy_config}
|
||||
)
|
||||
config = self._config_dict_from_list(self._new_config_from_legacy(legacy_config))
|
||||
params = self._param_dict_from_list(self._new_params_from_legacy(legacy_params))
|
||||
|
||||
old_api = APIClient(base_url="http://localhost:8008/v2.8")
|
||||
task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0]
|
||||
self.assertEqual(legacy_params, task.execution.parameters)
|
||||
self.assertEqual(legacy_config, task.execution.model_desc)
|
||||
self.assertEqual(params, task.hyperparams)
|
||||
self.assertEqual(config, task.configuration)
|
||||
|
||||
modified_params = {"legacy.2": "val2"}
|
||||
modified_config = {"design": "by"}
|
||||
old_api.tasks.edit(task=task_id, execution=dict(parameters=modified_params, model_desc=modified_config))
|
||||
task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0]
|
||||
self.assertEqual(modified_params, task.execution.parameters)
|
||||
self.assertEqual(modified_config, task.execution.model_desc)
|
||||
@@ -5,7 +5,6 @@ log = config.logger(__file__)
|
||||
|
||||
|
||||
class TestTasksDiff(TestService):
|
||||
|
||||
def setUp(self, version="2.0"):
|
||||
super(TestTasksDiff, self).setUp(version=version)
|
||||
|
||||
@@ -17,7 +16,14 @@ class TestTasksDiff(TestService):
|
||||
def _compare_script(self, task_id, script):
|
||||
task = self.api.tasks.get_by_id(task=task_id).task
|
||||
if not script:
|
||||
self.assertFalse(task.get("script", None))
|
||||
self.assertTrue(
|
||||
task.get(
|
||||
"script",
|
||||
dict(
|
||||
binary="python", repository="", entry_point="", requirements={}
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
for key, value in script.items():
|
||||
self.assertEqual(task.script[key], value)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from apierrors.errors.bad_request import InvalidModelId, ValidationError
|
||||
from apierrors.errors.bad_request import InvalidModelId, ValidationError, InvalidTaskId
|
||||
from apierrors.errors.forbidden import NoWritePermission
|
||||
from config import config
|
||||
from tests.automated import TestService
|
||||
|
||||
@@ -8,7 +9,7 @@ log = config.logger(__file__)
|
||||
|
||||
class TestTasksEdit(TestService):
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(version=2.5)
|
||||
super().setUp(version="2.9")
|
||||
|
||||
def new_task(self, **kwargs):
|
||||
self.update_missing(
|
||||
@@ -113,7 +114,7 @@ class TestTasksEdit(TestService):
|
||||
self.assertEqual(new_task.status, "created")
|
||||
self.assertEqual(new_task.script, script)
|
||||
self.assertEqual(new_task.parent, task)
|
||||
self.assertEqual(new_task.execution.parameters, execution["parameters"])
|
||||
# self.assertEqual(new_task.execution.parameters, execution["parameters"])
|
||||
self.assertEqual(new_task.execution.framework, execution_overrides["framework"])
|
||||
self.assertEqual(new_task.system_tags, [])
|
||||
|
||||
@@ -145,3 +146,28 @@ class TestTasksEdit(TestService):
|
||||
self.api.tasks.delete, task=new_task, move_to_trash=False, force=True
|
||||
)
|
||||
return new_task
|
||||
|
||||
def test_make_public(self):
|
||||
task = self.new_task()
|
||||
|
||||
# task is created as private and can be updated
|
||||
self.api.tasks.started(task=task)
|
||||
|
||||
# task with company_origin not set to the current company cannot be converted to private
|
||||
with self.api.raises(InvalidTaskId):
|
||||
self.api.tasks.make_private(ids=[task])
|
||||
|
||||
# public task can be retrieved but not updated
|
||||
res = self.api.tasks.make_public(ids=[task])
|
||||
self.assertEqual(res.updated, 1)
|
||||
res = self.api.tasks.get_all_ex(id=[task])
|
||||
self.assertEqual([t.id for t in res.tasks], [task])
|
||||
with self.api.raises(NoWritePermission):
|
||||
self.api.tasks.stopped(task=task)
|
||||
|
||||
# task made private again and can be both retrieved and updated
|
||||
res = self.api.tasks.make_private(ids=[task])
|
||||
self.assertEqual(res.updated, 1)
|
||||
res = self.api.tasks.get_all_ex(id=[task])
|
||||
self.assertEqual([t.id for t in res.tasks], [task])
|
||||
self.api.tasks.stopped(task=task)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Sequence
|
||||
from uuid import uuid4
|
||||
|
||||
from apierrors import errors
|
||||
from config import config
|
||||
from tests.automated import TestService
|
||||
|
||||
|
||||
@@ -1,12 +1,2 @@
|
||||
import dpath
|
||||
|
||||
|
||||
def strict_map(*args, **kwargs):
|
||||
return list(map(*args, **kwargs))
|
||||
|
||||
|
||||
def safe_get(obj, glob, default=None):
|
||||
try:
|
||||
return dpath.get(obj, glob)
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
46
server/utilities/parameter_key_escaper.py
Normal file
46
server/utilities/parameter_key_escaper.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from boltons.dictutils import OneToOne
|
||||
|
||||
from apierrors import errors
|
||||
|
||||
|
||||
class ParameterKeyEscaper:
|
||||
"""
|
||||
Makes the fields name ready for use with MongoDB and Mongoengine
|
||||
. and $ are replaced with their codes
|
||||
__ and leading _ are escaped
|
||||
Since % is used as an escape character the % is also escaped
|
||||
"""
|
||||
|
||||
_mapping = OneToOne({".": "%2E", "$": "%24", "__": "%_%_"})
|
||||
|
||||
@classmethod
|
||||
def escape(cls, value):
|
||||
""" Quote a parameter key """
|
||||
if value is None:
|
||||
raise errors.bad_request.ValidationError("Key cannot be empty")
|
||||
|
||||
value = value.strip().replace("%", "%%")
|
||||
|
||||
for c, r in cls._mapping.items():
|
||||
value = value.replace(c, r)
|
||||
|
||||
if value.startswith("_"):
|
||||
value = "%_" + value[1:]
|
||||
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _unescape(cls, value):
|
||||
for c, r in cls._mapping.inv.items():
|
||||
value = value.replace(c, r)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def unescape(cls, value):
|
||||
""" Unquote a quoted parameter key """
|
||||
value = "%".join(map(cls._unescape, value.split("%%")))
|
||||
|
||||
if value.startswith("%_"):
|
||||
value = "_" + value[2:]
|
||||
|
||||
return value
|
||||
@@ -1 +1 @@
|
||||
__version__ = "0.15.1"
|
||||
__version__ = "0.16.1"
|
||||
|
||||
Reference in New Issue
Block a user