mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
55 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
11d76e7d8c | ||
|
|
e76c0fbc63 | ||
|
|
fdc9956da3 | ||
|
|
f4addaa653 | ||
|
|
667964cc82 | ||
|
|
e1309e30b7 | ||
|
|
9403942ef7 | ||
|
|
84a75d9e70 | ||
|
|
c85ab66ae6 | ||
|
|
bf7f0f646b | ||
|
|
dcdf2a3d58 | ||
|
|
f8d8fc40a6 | ||
|
|
45d434a123 | ||
|
|
1834abe5bc | ||
|
|
d6321588f3 | ||
|
|
c17b10ff1d | ||
|
|
b125a56f86 | ||
|
|
c43ce3a17b | ||
|
|
b0b09616a8 | ||
|
|
ede5586ccc | ||
|
|
a1dcdffa53 | ||
|
|
35a11db58e | ||
|
|
d9bdebefc7 | ||
|
|
f29884f05a | ||
|
|
0f72d662f8 | ||
|
|
6202219034 | ||
|
|
bb3218f65d | ||
|
|
cbcaa7c789 | ||
|
|
427322a424 | ||
|
|
0e7d7d36a9 | ||
|
|
06032a6d66 | ||
|
|
b48f4eb2eb | ||
|
|
383b2666c4 | ||
|
|
50c373cf0d | ||
|
|
394a9de5fa | ||
|
|
fb5c06e9c3 | ||
|
|
1a9bbc9420 | ||
|
|
294da32401 | ||
|
|
7f00672010 | ||
|
|
99bf89a360 | ||
|
|
6c8508eb7f | ||
|
|
69714d5b5c | ||
|
|
f9516ec7d3 | ||
|
|
6fdde93dee | ||
|
|
7afc71ec91 | ||
|
|
4595117d91 | ||
|
|
8630cc1021 | ||
|
|
135885b609 | ||
|
|
eb0865662c | ||
|
|
b7b94e7ae5 | ||
|
|
72be8bee19 | ||
|
|
0722b20c1c | ||
|
|
a392a0e6ff | ||
|
|
e22fa2f478 | ||
|
|
8b49c1ac06 |
@@ -7,6 +7,8 @@
|
||||
[](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
|
||||
[](https://img.shields.io/badge/status-beta-yellow.svg)
|
||||
|
||||
### Help improve Trains by filling our 2-min [user survey](https://allegro.ai/lp/trains-user-survey/)
|
||||
|
||||
## Introduction
|
||||
|
||||
The **trains-server** is the backend service infrastructure for [Trains](https://github.com/allegroai/trains).
|
||||
@@ -61,6 +63,7 @@ For example, to see if port `8080` is in use:
|
||||
Launch **trains-server** in any of the following formats:
|
||||
|
||||
- Pre-built [AWS EC2 AMI](https://github.com/allegroai/trains-server/blob/master/docs/install_aws.md)
|
||||
- Pre-built [GCP Custom Image](https://github.com/allegroai/trains-server/blob/master/docs/install_gcp.md)
|
||||
- Pre-built Docker Image
|
||||
- [Linux](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
|
||||
- [macOS](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
|
||||
|
||||
@@ -22,6 +22,8 @@ services:
|
||||
TRAINS_MONGODB_SERVICE_PORT: 27017
|
||||
TRAINS_REDIS_SERVICE_HOST: redis
|
||||
TRAINS_REDIS_SERVICE_PORT: 6379
|
||||
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
|
||||
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
|
||||
ports:
|
||||
- "8008:8008"
|
||||
networks:
|
||||
|
||||
@@ -22,6 +22,8 @@ services:
|
||||
TRAINS_MONGODB_SERVICE_PORT: 27017
|
||||
TRAINS_REDIS_SERVICE_HOST: redis
|
||||
TRAINS_REDIS_SERVICE_PORT: 6379
|
||||
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
|
||||
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
|
||||
ports:
|
||||
- "8008:8008"
|
||||
networks:
|
||||
@@ -113,6 +115,36 @@ services:
|
||||
ports:
|
||||
- "8080:80"
|
||||
|
||||
agent-services:
|
||||
container_name: trains-agent-services
|
||||
image: allegroai/trains-agent-services:latest
|
||||
restart: unless-stopped
|
||||
privileged: true
|
||||
environment:
|
||||
TRAINS_HOST_IP: ${TRAINS_HOST_IP}
|
||||
TRAINS_WEB_HOST: ${TRAINS_WEB_HOST:-}
|
||||
TRAINS_API_HOST: ${TRAINS_API_HOST:-}
|
||||
TRAINS_FILES_HOST: ${TRAINS_FILES_HOST:-}
|
||||
TRAINS_API_ACCESS_KEY: ${TRAINS_API_ACCESS_KEY:-}
|
||||
TRAINS_API_SECRET_KEY: ${TRAINS_API_SECRET_KEY:-}
|
||||
TRAINS_AGENT_GIT_USER: ${TRAINS_AGENT_GIT_USER}
|
||||
TRAINS_AGENT_GIT_PASS: ${TRAINS_AGENT_GIT_PASS}
|
||||
TRAINS_AGENT_UPDATE_VERSION: ${TRAINS_AGENT_UPDATE_VERSION:->=0.15.0}
|
||||
TRAINS_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
|
||||
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
|
||||
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
|
||||
AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION:-}
|
||||
AZURE_STORAGE_ACCOUNT: ${AZURE_STORAGE_ACCOUNT:-}
|
||||
AZURE_STORAGE_KEY: ${AZURE_STORAGE_KEY:-}
|
||||
GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
|
||||
TRAINS_WORKER_ID: "trains-services"
|
||||
TRAINS_AGENT_DOCKER_HOST_MOUNT: "/opt/trains/agent:/root/.trains"
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- /opt/trains/agent:/root/.trains
|
||||
depends_on:
|
||||
- apiserver
|
||||
|
||||
networks:
|
||||
backend:
|
||||
driver: bridge
|
||||
|
||||
@@ -26,7 +26,7 @@ The minimum recommended amount of RAM is 8GB. For example, **t3.large** or **t3a
|
||||
|
||||
To upgrade **trains-server** on an existing EC2 instance based on one of these AMIs, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
|
||||
|
||||
### Upgrading AMIs to v0.12
|
||||
### Note on upgrading AMIs to v0.12
|
||||
|
||||
This upgrade includes the automatically updated AMI in Version 0.12. It also includes an additional REDIS docker to the **trains-server** setup.
|
||||
|
||||
@@ -50,44 +50,119 @@ To upgrade the AMI:
|
||||
|
||||
The following sections contain lists of AMI Image IDs, per region, for each released **trains-server** version.
|
||||
|
||||
### Latest version AMI - v0.13.0 (auto update)<a name="autoupdate"></a>
|
||||
### Latest version AMI - v0.15.0 (auto update)<a name="autoupdate"></a>
|
||||
|
||||
For easier upgrades, the following AMIs automatically update to the latest release every reboot:
|
||||
|
||||
* **eu-north-1** : ami-003024b7b575d3f2a
|
||||
* **ap-south-1** : ami-0d784c7ac2ab4cc72
|
||||
* **eu-west-3** : ami-091d745be445b69db
|
||||
* **eu-west-2** : ami-0a4ebf5d45c672411
|
||||
* **eu-west-1** : ami-021e3421c50d1482c
|
||||
* **ap-northeast-2** : ami-0d0a25ec610d6d122
|
||||
* **ap-northeast-1** : ami-01d896f9ae5d87890
|
||||
* **sa-east-1** : ami-09bcb93835428a412
|
||||
* **ca-central-1** : ami-077fa58c9f73690c7
|
||||
* **ap-southeast-1** : ami-046fe4832b077b517
|
||||
* **ap-southeast-2** : ami-0ab9acb41f8abbba7
|
||||
* **eu-central-1** : ami-079be664aae12db00
|
||||
* **us-east-2** : ami-0d48555f80cb7993a
|
||||
* **us-west-1** : ami-0ed85ab91a7bb5a8a
|
||||
* **us-west-2** : ami-0b4fe4ca18e9b1227
|
||||
* **us-east-1** : ami-043b95dd034e581e6
|
||||
* **eu-north-1** : ami-0a05eb5b384a84609
|
||||
* **ap-south-1** : ami-00f190b50e60b1eb5
|
||||
* **eu-west-3** : ami-044fad585e1d1798e
|
||||
* **eu-west-2** : ami-04ab930416a4af8c5
|
||||
* **eu-west-1** : ami-00c022f333417e78e
|
||||
* **ap-northeast-2** : ami-0c436e94f461a9a22
|
||||
* **ap-northeast-1** : ami-018e761ad0009d5d4
|
||||
* **sa-east-1** : ami-0b6c0e8e93b6ebbdd
|
||||
* **ca-central-1** : ami-0cf12aab70c14237d
|
||||
* **ap-southeast-1** : ami-0fe7840b9bde05581
|
||||
* **ap-southeast-2** : ami-00f230e86e1afda91
|
||||
* **eu-central-1** : ami-0635d13b79f76e04f
|
||||
* **us-east-2** : ami-0b323078d0206db0e
|
||||
* **us-west-1** : ami-07fdc1d461906f957
|
||||
* **us-west-2** : ami-0a5cac167c3ebdedb
|
||||
* **us-east-1** : ami-0d03956bea3aa5a44
|
||||
|
||||
### v0.15.0 (static update)
|
||||
|
||||
* **eu-north-1** : ami-0475a5068d615769b
|
||||
* **ap-south-1** : ami-00c7e642badaa2ebf
|
||||
* **eu-west-3** : ami-0655f769c28843e25
|
||||
* **eu-west-2** : ami-04d82f48f09e2b846
|
||||
* **eu-west-1** : ami-07a2aab2dc7b4ec5f
|
||||
* **ap-northeast-2** : ami-0257ab220a8bc7a52
|
||||
* **ap-northeast-1** : ami-0c4900af758b91dde
|
||||
* **sa-east-1** : ami-021f758a4a21d5725
|
||||
* **ca-central-1** : ami-0ce9703b3b47cfe70
|
||||
* **ap-southeast-1** : ami-0b38689fdb8f71b74
|
||||
* **ap-southeast-2** : ami-0c2b3a171e7ae4b00
|
||||
* **eu-central-1** : ami-0fdd3420d6e6b4a1f
|
||||
* **us-east-2** : ami-0288e9654da36ed1c
|
||||
* **us-west-1** : ami-0f1d6ee0b73fe9ca2
|
||||
* **us-west-2** : ami-025f0c5bfeacbf390
|
||||
* **us-east-1** : ami-0b17b0bfa8b91f805
|
||||
|
||||
### v0.14.2 (static update)
|
||||
|
||||
* **eu-north-1** : ami-006d491e9e8869248
|
||||
* **ap-south-1** : ami-0e55ec221687f98e7
|
||||
* **eu-west-3** : ami-06ad9cf3c05c83e91
|
||||
* **eu-west-2** : ami-0d05839268e748cff
|
||||
* **eu-west-1** : ami-0d14c297789ce0d7a
|
||||
* **ap-northeast-2** : ami-0d7fd775f0e76cc6f
|
||||
* **ap-northeast-1** : ami-0c0a6e1daeb3f7a9c
|
||||
* **sa-east-1** : ami-01e0c5e30e94ec887
|
||||
* **ca-central-1** : ami-07a31896832734897
|
||||
* **ap-southeast-1** : ami-0886d5b2d4b7fccd5
|
||||
* **ap-southeast-2** : ami-0397d5a2db3c356fe
|
||||
* **eu-central-1** : ami-0629f26eea22f5c17
|
||||
* **us-east-2** : ami-0499c3d7bb45a1a6e
|
||||
* **us-west-1** : ami-02fa8a961a4daf9f0
|
||||
* **us-west-2** : ami-05c711cfab4342468
|
||||
* **us-east-1** : ami-0b97d99a08012c726
|
||||
|
||||
### v0.14.1 (static update)
|
||||
|
||||
* **eu-north-1** : ami-036defe1885dced2e
|
||||
* **ap-south-1** : ami-0b403aa1da6a5dc17
|
||||
* **eu-west-3** : ami-0d30c2d330d1255c4
|
||||
* **eu-west-2** : ami-06f0e8d075e50a029
|
||||
* **eu-west-1** : ami-0da721d874f282b6d
|
||||
* **ap-northeast-2** : ami-03bffe94675dd5f8c
|
||||
* **ap-northeast-1** : ami-0f96520d646423673
|
||||
* **sa-east-1** : ami-0c2f706a3b7d97282
|
||||
* **ca-central-1** : ami-0da74525dcfd74e32
|
||||
* **ap-southeast-1** : ami-066368a21cf6d232b
|
||||
* **ap-southeast-2** : ami-0bfd09170067f7318
|
||||
* **eu-central-1** : ami-06aa99b1c41492986
|
||||
* **us-east-2** : ami-065c1880f59d03272
|
||||
* **us-west-1** : ami-0b7f6b896f5058eba
|
||||
* **us-west-2** : ami-0041e10ca68eef29a
|
||||
* **us-east-1** : ami-0b7125e4305bbd7eb
|
||||
|
||||
### v0.14.0 (static update)
|
||||
* **eu-north-1** : ami-02de71586ec496e38
|
||||
* **ap-south-1** : ami-074b03849b51852e5
|
||||
* **eu-west-3** : ami-022c388835e0eeb03
|
||||
* **eu-west-2** : ami-0a151c236c6b27707
|
||||
* **eu-west-1** : ami-06de69b06b4e73312
|
||||
* **ap-northeast-2** : ami-0ee821b72d9f669b1
|
||||
* **ap-northeast-1** : ami-03687ae215e64e100
|
||||
* **sa-east-1** : ami-01eb83364b7f667af
|
||||
* **ca-central-1** : ami-02e9b35f9c90377e6
|
||||
* **ap-southeast-1** : ami-0d3ab5ab0048fea51
|
||||
* **ap-southeast-2** : ami-0bd39d908fe3a9e06
|
||||
* **eu-central-1** : ami-0b8638701311b35c4
|
||||
* **us-east-2** : ami-02ff039693fc3a614
|
||||
* **us-west-1** : ami-08634f7dfb608a9a7
|
||||
* **us-west-2** : ami-034d693ef742b9333
|
||||
* **us-east-1** : ami-0b828b05c323dde7f
|
||||
|
||||
### v0.13.0 (static update)
|
||||
* **eu-north-1** : ami-0e26c3af1663428dc
|
||||
* **ap-south-1** : ami-07451eb44f51380a8
|
||||
* **eu-west-3** : ami-0108e506c6e0be8d8
|
||||
* **eu-west-2** : ami-0fc1fdbc7699f0dde
|
||||
* **eu-west-1** : ami-0efbf8d2f580a9cee
|
||||
* **ap-northeast-2** : ami-08f0bbd7e08d0603e
|
||||
* **ap-northeast-1** : ami-024522bea34dbe3ce
|
||||
* **sa-east-1** : ami-0fe5b6e0ddc1553d9
|
||||
* **ca-central-1** : ami-0037c26178a584ade
|
||||
* **ap-southeast-1** : ami-049dbcc0f0a6dba20
|
||||
* **ap-southeast-2** : ami-02d1ce8d31c27f187
|
||||
* **eu-central-1** : ami-0550b14b40371182a
|
||||
* **us-east-2** : ami-040a1f16ceda8f255
|
||||
* **us-west-1** : ami-003b5673c08d68cdb
|
||||
* **us-west-2** : ami-0fec951d8043da62d
|
||||
* **us-east-1** : ami-049694de0137fdea4
|
||||
* **eu-north-1** : ami-0d9c74a015e7510d8
|
||||
* **ap-south-1** : ami-02acd6dd0659bb5c1
|
||||
* **eu-west-3** : ami-0f0cc5cb6d9afd194
|
||||
* **eu-west-2** : ami-0298fdc0860206ed9
|
||||
* **eu-west-1** : ami-0cdc072e528401d5e
|
||||
* **ap-northeast-2** : ami-0055579cc95b0e53e
|
||||
* **ap-northeast-1** : ami-0ced7becb9b83b5d0
|
||||
* **sa-east-1** : ami-033345d0f16a1b5e4
|
||||
* **ca-central-1** : ami-06c63b05aed47ae67
|
||||
* **ap-southeast-1** : ami-09f0355f367f30602
|
||||
* **ap-southeast-2** : ami-0bd2314163ce0fba0
|
||||
* **eu-central-1** : ami-05fbae957df63e366
|
||||
* **us-east-2** : ami-050c51b5b4074d3fc
|
||||
* **us-west-1** : ami-06ad513073d4e5a19
|
||||
* **us-west-2** : ami-0c96e1361d1d4ca94
|
||||
* **us-east-1** : ami-07b669040d1eea213
|
||||
|
||||
### v0.12.1 (static update)
|
||||
* **eu-north-1** : ami-003118a8103286d84
|
||||
|
||||
58
docs/install_gcp.md
Normal file
58
docs/install_gcp.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# Deploying Trains Server on Google Cloud Platform
|
||||
|
||||
To easily deploy Trains Server on GCP, use one of our pre-built GCP Custom Images.
|
||||
We provide Custom Images for each released version of Trains Server, see [Released versions](#released-versions) below.
|
||||
|
||||
Once your GCP instance is up and running using our Custom Image, [configure the Trains client](https://github.com/allegroai/trains/blob/master/README.md#configuration) to use your **trains-server**.
|
||||
The service port numbers on our Trains Server GCP Custom Image are:
|
||||
|
||||
- Web application: `8080`
|
||||
- API Server: `8008`
|
||||
- File Server: `8081`
|
||||
|
||||
The persistent storage configuration:
|
||||
|
||||
- MongoDB: `/opt/trains/data/mongo/`
|
||||
- ElasticSearch: `/opt/trains/data/elastic/`
|
||||
- File Server: `/mnt/fileserver/`
|
||||
|
||||
For examples and use cases, check the [Trains usage examples](https://github.com/allegroai/trains/blob/master/docs/trains_examples.md).
|
||||
|
||||
## Importing the Custom Image to your GCP account
|
||||
|
||||
In order to launch an instance using the Trains Server GCP Custom Image, you'll need to import the image to your custom images list.
|
||||
|
||||
**Note:** there's **no need** to upload the image file to Google Cloud Storage - we already provide links to image files stored in Google Storage
|
||||
|
||||
To import the image to your custom images list:
|
||||
1. In the Cloud Console, go to the [Images](https://console.cloud.google.com/compute/images) page.
|
||||
1. At the top of the page, click **Create image**.
|
||||
1. In the **Name** field, specify a unique name for the image.
|
||||
1. Optionally, specify an image family for your new image, or configure specific encryption settings for the image.
|
||||
1. Click the **Source** menu and select **Cloud Storage file**.
|
||||
1. Enter the Trains Server image bucket path (see [Trains Server GCP Custom Image](#released-versions)), for example:
|
||||
`allegro-files/trains-server/trains-server.tar.gz`
|
||||
1. Click the **Create** button to import the image. The process can take several minutes depending on the size of the boot disk image.
|
||||
|
||||
For more information see [Import the image to your custom images list](https://cloud.google.com/compute/docs/import/import-existing-image#import_image) in the [Compute Engine Documentation](https://cloud.google.com/compute/docs).
|
||||
|
||||
## Launching an instance with a Custom Image
|
||||
|
||||
For instructions on launching an instance using a GCP Custom Image, see the [Manually importing virtual disks](https://cloud.google.com/compute/docs/import/import-existing-image#overview) in the [Compute Engine Documentation](https://cloud.google.com/compute/docs).
|
||||
For more information on Custom Images, see [Custom Images](https://cloud.google.com/compute/docs/images#custom_images) in the Compute Engine Documentation.
|
||||
|
||||
The minimum recommended requirements for Trains Server are:
|
||||
- 2 vCPUs
|
||||
- 7.5GB RAM
|
||||
|
||||
## Upgrading
|
||||
|
||||
To upgrade **trains-server** on an existing GCP instance based on one of these Custom Images, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
|
||||
|
||||
## Released versions
|
||||
|
||||
The following sections contain lists of Custom Image URLs (exported in different formats) for each released **trains-server** version.
|
||||
|
||||
### Latest version image (v0.14.1)
|
||||
|
||||
- https://storage.googleapis.com/allegro-files/trains-server/trains-server.tar.gz
|
||||
@@ -10,10 +10,14 @@ from flask_cors import CORS
|
||||
|
||||
from config import config
|
||||
|
||||
DEFAULT_UPLOAD_FOLDER = "/mnt/fileserver"
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app, **config.get("fileserver.cors"))
|
||||
Compress(app)
|
||||
|
||||
app.config["UPLOAD_FOLDER"] = os.environ.get("TRAINS_UPLOAD_FOLDER") or DEFAULT_UPLOAD_FOLDER
|
||||
|
||||
|
||||
@app.route("/", methods=["POST"])
|
||||
def upload():
|
||||
@@ -54,12 +58,13 @@ def main():
|
||||
parser.add_argument(
|
||||
"--upload-folder",
|
||||
"-u",
|
||||
default="/mnt/fileserver",
|
||||
default=DEFAULT_UPLOAD_FOLDER,
|
||||
help="Upload folder (default %(default)s)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
app.config["UPLOAD_FOLDER"] = args.upload_folder
|
||||
if app.config.get("UPLOAD_FOLDER") is None:
|
||||
app.config["UPLOAD_FOLDER"] = args.upload_folder
|
||||
|
||||
app.run(debug=args.debug, host=args.ip, port=args.port, threaded=True)
|
||||
|
||||
|
||||
1
server/api_version.py
Normal file
1
server/api_version.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "2.8.0"
|
||||
@@ -47,6 +47,7 @@ _error_codes = {
|
||||
128: ('invalid_task_output', 'invalid task output'),
|
||||
129: ('task_publish_in_progress', 'Task publish in progress'),
|
||||
130: ('task_not_found', 'task not found'),
|
||||
131: ('events_not_added', 'events not added'),
|
||||
|
||||
# Models
|
||||
200: ('model_error', 'general task error'),
|
||||
@@ -89,6 +90,8 @@ _error_codes = {
|
||||
1003: ('worker_registered', 'worker is already registered'),
|
||||
1004: ('worker_not_registered', 'worker is not registered'),
|
||||
1005: ('worker_stats_not_found', 'worker stats not found'),
|
||||
|
||||
1104: ('invalid_scroll_id', 'Invalid scroll id'),
|
||||
},
|
||||
|
||||
(401, 'unauthorized'): {
|
||||
|
||||
@@ -13,6 +13,7 @@ from luqum.parser import parser, ParseError
|
||||
from validators import email as email_validator, domain as domain_validator
|
||||
|
||||
from apierrors import errors
|
||||
from utilities.json import loads, dumps
|
||||
|
||||
|
||||
def make_default(field_cls, default_value):
|
||||
@@ -168,7 +169,7 @@ class ActualEnumField(fields.StringField):
|
||||
validator_cls = EnumValidator if required else NullableEnumValidator
|
||||
validators = [*(validators or []), validator_cls(*choices)]
|
||||
super().__init__(
|
||||
default=default and self.parse_value(default),
|
||||
default=self.parse_value(default) if default else NotSet,
|
||||
*args,
|
||||
required=required,
|
||||
validators=validators,
|
||||
@@ -206,10 +207,10 @@ class DomainField(fields.StringField):
|
||||
raise errors.bad_request.InvalidDomainName()
|
||||
|
||||
|
||||
class StringEnum(Enum):
|
||||
def __str__(self):
|
||||
return self.value
|
||||
class JsonSerializableMixin:
|
||||
def to_json(self: ModelBase):
|
||||
return dumps(self.to_struct())
|
||||
|
||||
# noinspection PyMethodParameters
|
||||
def _generate_next_value_(name, start, count, last_values):
|
||||
return name
|
||||
@classmethod
|
||||
def from_json(cls: Type[ModelBase], s):
|
||||
return cls(**loads(s))
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels.fields import StringField
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length
|
||||
|
||||
from apimodels import ListField, IntField, ActualEnumField
|
||||
from bll.event.event_metrics import EventType
|
||||
from bll.event.scalar_key import ScalarKeyEnum
|
||||
|
||||
|
||||
@@ -17,4 +20,52 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
tasks: Sequence[str] = ListField(items_types=str)
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
|
||||
|
||||
class TaskMetric(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
|
||||
|
||||
class DebugImagesRequest(Base):
|
||||
metrics: Sequence[TaskMetric] = ListField(
|
||||
items_types=TaskMetric, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
iters: int = IntField(default=1, validators=validators.Min(1))
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class LogEventsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
batch_size: int = IntField(default=500)
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class IterationEvents(Base):
|
||||
iter: int = IntField()
|
||||
events: Sequence[dict] = ListField(items_types=dict)
|
||||
|
||||
|
||||
class MetricEvents(Base):
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
|
||||
|
||||
|
||||
class DebugImageResponse(Base):
|
||||
metrics: Sequence[MetricEvents] = ListField(items_types=MetricEvents)
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class TaskMetricsRequest(Base):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
event_type: EventType = ActualEnumField(EventType, required=True)
|
||||
|
||||
10
server/apimodels/organization.py
Normal file
10
server/apimodels/organization.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from jsonmodels import fields, models
|
||||
|
||||
|
||||
class Filter(models.Base):
|
||||
system_tags = fields.ListField([str])
|
||||
|
||||
|
||||
class TagsRequest(models.Base):
|
||||
include_system = fields.BoolField(default=False)
|
||||
filter = fields.EmbeddedField(Filter)
|
||||
@@ -92,6 +92,10 @@ class PingRequest(TaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class GetTypesRequest(models.Base):
|
||||
projects = ListField(items_types=[str])
|
||||
|
||||
|
||||
class CloneRequest(TaskRequest):
|
||||
new_task_name = StringField()
|
||||
new_task_comment = StringField()
|
||||
@@ -100,6 +104,7 @@ class CloneRequest(TaskRequest):
|
||||
new_task_parent = StringField()
|
||||
new_task_project = StringField()
|
||||
execution_overrides = DictField()
|
||||
validate_references = BoolField(default=False)
|
||||
|
||||
|
||||
class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||
@@ -109,3 +114,7 @@ class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||
class AddOrUpdateArtifactsResponse(models.Base):
|
||||
added = ListField([str])
|
||||
updated = ListField([str])
|
||||
|
||||
|
||||
class ResetRequest(UpdateRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
import six
|
||||
@@ -13,7 +12,7 @@ from jsonmodels.fields import (
|
||||
)
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apimodels import make_default, ListField, EnumField
|
||||
from apimodels import make_default, ListField, EnumField, JsonSerializableMixin
|
||||
|
||||
DEFAULT_TIMEOUT = 10 * 60
|
||||
|
||||
@@ -61,7 +60,7 @@ class IdNameEntry(Base):
|
||||
name = StringField()
|
||||
|
||||
|
||||
class WorkerEntry(Base):
|
||||
class WorkerEntry(Base, JsonSerializableMixin):
|
||||
key = StringField() # not required due to migration issues
|
||||
id = StringField(required=True)
|
||||
user = EmbeddedField(IdNameEntry)
|
||||
@@ -75,13 +74,6 @@ class WorkerEntry(Base):
|
||||
last_activity_time = DateTimeField(required=True)
|
||||
last_report_time = DateTimeField()
|
||||
|
||||
def to_json(self):
|
||||
return json.dumps(self.to_struct())
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, s):
|
||||
return cls(**json.loads(s))
|
||||
|
||||
|
||||
class CurrentTaskEntry(IdNameEntry):
|
||||
running_time = IntField()
|
||||
|
||||
462
server/bll/event/debug_images_iterator.py
Normal file
462
server/bll/event/debug_images_iterator.py
Normal file
@@ -0,0 +1,462 @@
|
||||
from collections import defaultdict
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from operator import attrgetter, itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apierrors import errors
|
||||
from apimodels import JsonSerializableMixin
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
from bll.redis_cache_manager import RedisCacheManager
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.metrics import MetricEventStats
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
|
||||
|
||||
class VariantScrollState(Base):
|
||||
name: str = StringField(required=True)
|
||||
recycle_url_marker: str = StringField()
|
||||
last_invalid_iteration: int = IntField()
|
||||
|
||||
|
||||
class MetricScrollState(Base):
|
||||
task: str = StringField(required=True)
|
||||
name: str = StringField(required=True)
|
||||
last_min_iter: Optional[int] = IntField()
|
||||
last_max_iter: Optional[int] = IntField()
|
||||
timestamp: int = IntField(default=0)
|
||||
variants: Sequence[VariantScrollState] = ListField([VariantScrollState])
|
||||
|
||||
def reset(self):
|
||||
"""Reset the scrolling state for the metric"""
|
||||
self.last_min_iter = self.last_max_iter = None
|
||||
|
||||
|
||||
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class DebugImagesResult(object):
|
||||
metric_events: Sequence[tuple] = []
|
||||
next_scroll_id: str = None
|
||||
|
||||
|
||||
class DebugImagesIterator:
|
||||
EVENT_TYPE = "training_debug_image"
|
||||
|
||||
@property
|
||||
def state_expiration_sec(self):
|
||||
return config.get(
|
||||
f"services.events.events_retrieval.state_expiration_sec", 3600
|
||||
)
|
||||
|
||||
@property
|
||||
def _max_workers(self):
|
||||
return config.get("services.events.max_metrics_concurrency", 4)
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugImageEventsScrollState,
|
||||
redis=redis,
|
||||
expiration_interval=self.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> DebugImagesResult:
|
||||
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return DebugImagesResult()
|
||||
|
||||
def init_state(state_: DebugImageEventsScrollState):
|
||||
unique_metrics = set(metrics)
|
||||
state_.metrics = self._init_metric_states(es_index, list(unique_metrics))
|
||||
|
||||
def validate_state(state_: DebugImageEventsScrollState):
|
||||
"""
|
||||
Validate that the metrics stored in the state are the same
|
||||
as requested in the current call.
|
||||
Refresh the state if requested
|
||||
"""
|
||||
state_metrics = set((m.task, m.name) for m in state_.metrics)
|
||||
if state_metrics != set(metrics):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task metrics stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reinit_outdated_metric_states(company_id, es_index, state_)
|
||||
for metric_state in state_.metrics:
|
||||
metric_state.reset()
|
||||
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
) as state:
|
||||
res = DebugImagesResult(next_scroll_id=state.id)
|
||||
with ThreadPoolExecutor(self._max_workers) as pool:
|
||||
res.metric_events = list(
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_task_metric_events,
|
||||
es_index=es_index,
|
||||
iter_count=iter_count,
|
||||
navigate_earlier=navigate_earlier,
|
||||
),
|
||||
state.metrics,
|
||||
)
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def _reinit_outdated_metric_states(
|
||||
self, company_id, es_index, state: DebugImageEventsScrollState
|
||||
):
|
||||
"""
|
||||
Determines the metrics for which new debug image events were added
|
||||
since their states were initialized and reinits these states
|
||||
"""
|
||||
task_ids = set(metric.task for metric in state.metrics)
|
||||
tasks = Task.objects(id__in=list(task_ids), company=company_id).only(
|
||||
"id", "metric_stats"
|
||||
)
|
||||
|
||||
def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
|
||||
"""For metrics that reported debug image events get tuples of task_id/metric_name and last update times"""
|
||||
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
|
||||
if not metric_stats:
|
||||
return []
|
||||
|
||||
return [
|
||||
(
|
||||
(task.id, stats.metric),
|
||||
stats.event_stats_by_type[self.EVENT_TYPE].last_update,
|
||||
)
|
||||
for stats in metric_stats.values()
|
||||
if self.EVENT_TYPE in stats.event_stats_by_type
|
||||
]
|
||||
|
||||
update_times = dict(
|
||||
chain.from_iterable(
|
||||
get_last_update_times_for_task_metrics(task) for task in tasks
|
||||
)
|
||||
)
|
||||
outdated_metrics = [
|
||||
metric
|
||||
for metric in state.metrics
|
||||
if (metric.task, metric.name) in update_times
|
||||
and update_times[metric.task, metric.name] > metric.timestamp
|
||||
]
|
||||
state.metrics = [
|
||||
*(metric for metric in state.metrics if metric not in outdated_metrics),
|
||||
*(
|
||||
self._init_metric_states(
|
||||
es_index,
|
||||
[(metric.task, metric.name) for metric in outdated_metrics],
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
def _init_metric_states(
|
||||
self, es_index, metrics: Sequence[Tuple[str, str]]
|
||||
) -> Sequence[MetricScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
"""
|
||||
tasks = defaultdict(list)
|
||||
for (task, metric) in metrics:
|
||||
tasks[task].append(metric)
|
||||
|
||||
with ThreadPoolExecutor(self._max_workers) as pool:
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
pool.map(
|
||||
partial(self._init_metric_states_for_task, es_index=es_index),
|
||||
tasks.items(),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, Sequence[str]], es_index
|
||||
) -> Sequence[MetricScrollState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
for the variants that reported any debug images
|
||||
"""
|
||||
task, metrics = task_metrics
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}]
|
||||
}
|
||||
},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"last_event_timestamp": {"max": {"field": "timestamp"}},
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"urls": {
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "desc"},
|
||||
"size": 1, # we need only one url from the most recent iteration
|
||||
},
|
||||
"aggs": {
|
||||
"max_iter": {"max": {"field": "iter"}},
|
||||
"iters": {
|
||||
"top_hits": {
|
||||
"sort": {"iter": {"order": "desc"}},
|
||||
"size": 2, # need two last iterations so that we can take
|
||||
# the second one as invalid
|
||||
"_source": "iter",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
def init_variant_scroll_state(variant: dict):
|
||||
"""
|
||||
Return new variant scroll state for the passed variant bucket
|
||||
If the image urls get recycled then fill the last_invalid_iteration field
|
||||
"""
|
||||
state = VariantScrollState(name=variant["key"])
|
||||
top_iter_url = dpath.get(variant, "urls/buckets")[0]
|
||||
iters = dpath.get(top_iter_url, "iters/hits/hits")
|
||||
if len(iters) > 1:
|
||||
state.last_invalid_iteration = dpath.get(iters[1], "_source/iter")
|
||||
return state
|
||||
|
||||
return [
|
||||
MetricScrollState(
|
||||
task=task,
|
||||
name=metric["key"],
|
||||
variants=[
|
||||
init_variant_scroll_state(variant)
|
||||
for variant in dpath.get(metric, "variants/buckets")
|
||||
],
|
||||
timestamp=dpath.get(metric, "last_event_timestamp/value"),
|
||||
)
|
||||
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
|
||||
]
|
||||
|
||||
def _get_task_metric_events(
|
||||
self,
|
||||
metric: MetricScrollState,
|
||||
es_index: str,
|
||||
iter_count: int,
|
||||
navigate_earlier: bool,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Return task metric events grouped by iterations
|
||||
Update metric scroll state
|
||||
"""
|
||||
if metric.last_max_iter is None:
|
||||
# the first fetch is always from the latest iteration to the earlier ones
|
||||
navigate_earlier = True
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": metric.task}},
|
||||
{"term": {"metric": metric.name}},
|
||||
]
|
||||
must_not_conditions = []
|
||||
|
||||
range_condition = None
|
||||
if navigate_earlier and metric.last_min_iter is not None:
|
||||
range_condition = {"lt": metric.last_min_iter}
|
||||
elif not navigate_earlier and metric.last_max_iter is not None:
|
||||
range_condition = {"gt": metric.last_max_iter}
|
||||
if range_condition:
|
||||
must_conditions.append({"range": {"iter": range_condition}})
|
||||
|
||||
if navigate_earlier:
|
||||
"""
|
||||
When navigating to earlier iterations consider only
|
||||
variants whose invalid iterations border is lower than
|
||||
our starting iteration. For these variants make sure
|
||||
that only events from the valid iterations are returned
|
||||
"""
|
||||
if not metric.last_min_iter:
|
||||
variants = metric.variants
|
||||
else:
|
||||
variants = list(
|
||||
v
|
||||
for v in metric.variants
|
||||
if v.last_invalid_iteration is None
|
||||
or v.last_invalid_iteration < metric.last_min_iter
|
||||
)
|
||||
if not variants:
|
||||
return metric.task, metric.name, []
|
||||
must_conditions.append(
|
||||
{"terms": {"variant": list(v.name for v in variants)}}
|
||||
)
|
||||
else:
|
||||
"""
|
||||
When navigating to later iterations all variants may be relevant.
|
||||
For the variants whose invalid border is higher than our starting
|
||||
iteration make sure that only events from valid iterations are returned
|
||||
"""
|
||||
variants = list(
|
||||
v
|
||||
for v in metric.variants
|
||||
if v.last_invalid_iteration is not None
|
||||
and v.last_invalid_iteration > metric.last_max_iter
|
||||
)
|
||||
|
||||
variants_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"variant": v.name}},
|
||||
{"range": {"iter": {"lte": v.last_invalid_iteration}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for v in variants
|
||||
if v.last_invalid_iteration is not None
|
||||
]
|
||||
if variants_conditions:
|
||||
must_not_conditions.append({"bool": {"should": variants_conditions}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {"must": must_conditions, "must_not": must_not_conditions}
|
||||
},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iter_count,
|
||||
"order": {"_term": "desc" if navigate_earlier else "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {"sort": {"url": {"order": "desc"}}}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=metric.task)
|
||||
if "aggregations" not in es_res:
|
||||
return metric.task, metric.name, []
|
||||
|
||||
def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
|
||||
return [
|
||||
ev["_source"]
|
||||
for v in variant_buckets
|
||||
for ev in dpath.get(v, "events/hits/hits")
|
||||
]
|
||||
|
||||
iterations = [
|
||||
{
|
||||
"iter": it["key"],
|
||||
"events": get_iteration_events(dpath.get(it, "variants/buckets")),
|
||||
}
|
||||
for it in dpath.get(es_res, "aggregations/iters/buckets")
|
||||
]
|
||||
if not navigate_earlier:
|
||||
iterations.sort(key=itemgetter("iter"), reverse=True)
|
||||
if iterations:
|
||||
metric.last_max_iter = iterations[0]["iter"]
|
||||
metric.last_min_iter = iterations[-1]["iter"]
|
||||
|
||||
# Commented for now since the last invalid iteration is calculated in the beginning
|
||||
# if navigate_earlier and any(
|
||||
# variant.last_invalid_iteration is None for variant in variants
|
||||
# ):
|
||||
# """
|
||||
# Variants validation flags due to recycling can
|
||||
# be set only on navigation to earlier frames
|
||||
# """
|
||||
# iterations = self._update_variants_invalid_iterations(variants, iterations)
|
||||
|
||||
return metric.task, metric.name, iterations
|
||||
|
||||
@staticmethod
|
||||
def _update_variants_invalid_iterations(
|
||||
variants: Sequence[VariantScrollState], iterations: Sequence[dict]
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
This code is currently not in used since the invalid iterations
|
||||
are calculated during MetricState initialization
|
||||
For variants that do not have recycle url marker set it from the
|
||||
first event
|
||||
For variants that do not have last_invalid_iteration set check if the
|
||||
recycle marker was reached on a certain iteration and set it to the
|
||||
corresponding iteration
|
||||
For variants that have a newly set last_invalid_iteration remove
|
||||
events from the invalid iterations
|
||||
Return the updated iterations list
|
||||
"""
|
||||
variants_lookup = bucketize(variants, attrgetter("name"))
|
||||
for it in iterations:
|
||||
iteration = it["iter"]
|
||||
events_to_remove = []
|
||||
for event in it["events"]:
|
||||
variant = variants_lookup[event["variant"]][0]
|
||||
if (
|
||||
variant.last_invalid_iteration
|
||||
and variant.last_invalid_iteration >= iteration
|
||||
):
|
||||
events_to_remove.append(event)
|
||||
continue
|
||||
event_url = event.get("url")
|
||||
if not variant.recycle_url_marker:
|
||||
variant.recycle_url_marker = event_url
|
||||
elif variant.recycle_url_marker == event_url:
|
||||
variant.last_invalid_iteration = iteration
|
||||
events_to_remove.append(event)
|
||||
if events_to_remove:
|
||||
it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
|
||||
return [it for it in iterations if it["events"]]
|
||||
@@ -2,11 +2,9 @@ import hashlib
|
||||
from collections import defaultdict
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from operator import attrgetter
|
||||
from typing import Sequence
|
||||
from typing import Sequence, Set, Tuple
|
||||
|
||||
import attr
|
||||
import six
|
||||
from elasticsearch import helpers
|
||||
from mongoengine import Q
|
||||
@@ -15,69 +13,92 @@ from nested_dict import nested_dict
|
||||
import database.utils as dbutils
|
||||
import es_factory
|
||||
from apierrors import errors
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
from bll.event.debug_images_iterator import DebugImagesIterator
|
||||
from bll.event.event_metrics import EventMetrics, EventType
|
||||
from bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
|
||||
from bll.task import TaskBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.task import Task, TaskStatus
|
||||
from redis_manager import redman
|
||||
from timing_context import TimingContext
|
||||
from utilities.dicts import flatten_nested_items
|
||||
|
||||
|
||||
class EventType(Enum):
|
||||
metrics_scalar = "training_stats_scalar"
|
||||
metrics_vector = "training_stats_vector"
|
||||
metrics_image = "training_debug_image"
|
||||
metrics_plot = "plot"
|
||||
task_log = "log"
|
||||
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
EVENT_TYPES = set(map(attrgetter("value"), EventType))
|
||||
|
||||
|
||||
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
|
||||
|
||||
|
||||
@attr.s
|
||||
class TaskEventsResult(object):
|
||||
events = attr.ib(type=list, default=attr.Factory(list))
|
||||
total_events = attr.ib(type=int, default=0)
|
||||
next_scroll_id = attr.ib(type=str, default=None)
|
||||
|
||||
|
||||
class EventBLL(object):
|
||||
id_fields = ("task", "iter", "metric", "variant", "key")
|
||||
|
||||
def __init__(self, events_es=None):
|
||||
def __init__(self, events_es=None, redis=None):
|
||||
self.es = events_es or es_factory.connect("events")
|
||||
self._metrics = EventMetrics(self.es)
|
||||
self._skip_iteration_for_metric = set(config.get("services.events.ignore_iteration.metrics", []))
|
||||
self._skip_iteration_for_metric = set(
|
||||
config.get("services.events.ignore_iteration.metrics", [])
|
||||
)
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
|
||||
self.log_events_iterator = LogEventsIterator(es=self.es, redis=self.redis)
|
||||
|
||||
@property
|
||||
def metrics(self) -> EventMetrics:
|
||||
return self._metrics
|
||||
|
||||
def add_events(self, company_id, events, worker, allow_locked_tasks=False):
|
||||
@staticmethod
|
||||
def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
|
||||
"""Verify that task exists and can be updated"""
|
||||
if not task_ids:
|
||||
return set()
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
|
||||
query = Q(id__in=task_ids, company=company_id)
|
||||
if not allow_locked_tasks:
|
||||
query &= Q(status__nin=LOCKED_TASK_STATUSES)
|
||||
res = Task.objects(query).only("id")
|
||||
return {r.id for r in res}
|
||||
|
||||
def add_events(
|
||||
self, company_id, events, worker, allow_locked_tasks=False
|
||||
) -> Tuple[int, int, dict]:
|
||||
actions = []
|
||||
task_ids = set()
|
||||
task_iteration = defaultdict(lambda: 0)
|
||||
task_last_events = nested_dict(
|
||||
task_last_scalar_events = nested_dict(
|
||||
3, dict
|
||||
) # task_id -> metric_hash -> variant_hash -> MetricEvent
|
||||
|
||||
task_last_events = nested_dict(
|
||||
3, dict
|
||||
) # task_id -> metric_hash -> event_type -> MetricEvent
|
||||
errors_per_type = defaultdict(int)
|
||||
valid_tasks = self._get_valid_tasks(
|
||||
company_id,
|
||||
task_ids={
|
||||
event["task"] for event in events if event.get("task") is not None
|
||||
},
|
||||
allow_locked_tasks=allow_locked_tasks,
|
||||
)
|
||||
for event in events:
|
||||
# remove spaces from event type
|
||||
if "type" not in event:
|
||||
raise errors.BadRequest("Event must have a 'type' field", event=event)
|
||||
event_type = event.get("type")
|
||||
if event_type is None:
|
||||
errors_per_type["Event must have a 'type' field"] += 1
|
||||
continue
|
||||
|
||||
event_type = event["type"].replace(" ", "_")
|
||||
event_type = event_type.replace(" ", "_")
|
||||
if event_type not in EVENT_TYPES:
|
||||
raise errors.BadRequest(
|
||||
"Invalid event type {}".format(event_type),
|
||||
event=event,
|
||||
types=EVENT_TYPES,
|
||||
)
|
||||
errors_per_type[f"Invalid event type {event_type}"] += 1
|
||||
continue
|
||||
|
||||
task_id = event.get("task")
|
||||
if task_id is None:
|
||||
errors_per_type["Event must have a 'task' field"] += 1
|
||||
continue
|
||||
|
||||
if task_id not in valid_tasks:
|
||||
errors_per_type["Invalid task id"] += 1
|
||||
continue
|
||||
|
||||
event["type"] = event_type
|
||||
|
||||
@@ -106,6 +127,9 @@ class EventBLL(object):
|
||||
event["value"] = event["values"]
|
||||
del event["values"]
|
||||
|
||||
event["metric"] = event.get("metric") or ""
|
||||
event["variant"] = event.get("variant") or ""
|
||||
|
||||
index_name = EventMetrics.get_index_name(company_id, event_type)
|
||||
es_action = {
|
||||
"_op_type": "index", # overwrite if exists with same ID
|
||||
@@ -120,89 +144,82 @@ class EventBLL(object):
|
||||
else:
|
||||
es_action["_id"] = dbutils.id()
|
||||
|
||||
task_id = event.get("task")
|
||||
if task_id is not None:
|
||||
es_action["_routing"] = task_id
|
||||
task_ids.add(task_id)
|
||||
if iter is not None and event.get("metric") not in self._skip_iteration_for_metric:
|
||||
task_iteration[task_id] = max(iter, task_iteration[task_id])
|
||||
es_action["_routing"] = task_id
|
||||
task_ids.add(task_id)
|
||||
if (
|
||||
iter is not None
|
||||
and event.get("metric") not in self._skip_iteration_for_metric
|
||||
):
|
||||
task_iteration[task_id] = max(iter, task_iteration[task_id])
|
||||
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_metric_event_for_task(
|
||||
task_last_events=task_last_events, task_id=task_id, event=event
|
||||
)
|
||||
else:
|
||||
es_action["_routing"] = task_id
|
||||
self._update_last_metric_events_for_task(
|
||||
last_events=task_last_events[task_id], event=event,
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_id], event=event
|
||||
)
|
||||
|
||||
actions.append(es_action)
|
||||
|
||||
if task_ids:
|
||||
# verify task_ids
|
||||
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
|
||||
extra_msg = None
|
||||
query = Q(id__in=task_ids, company=company_id)
|
||||
if not allow_locked_tasks:
|
||||
query &= Q(status__nin=LOCKED_TASK_STATUSES)
|
||||
extra_msg = "or task published"
|
||||
res = Task.objects(query).only("id")
|
||||
if len(res) < len(task_ids):
|
||||
invalid_task_ids = tuple(set(task_ids) - set(r.id for r in res))
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
extra_msg, company=company_id, ids=invalid_task_ids
|
||||
added = 0
|
||||
if actions:
|
||||
chunk_size = 500
|
||||
with translate_errors_context(), TimingContext("es", "events_add_batch"):
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
# thread_count=8,
|
||||
refresh=True,
|
||||
)
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += chunk_size
|
||||
else:
|
||||
errors_per_type["Error when indexing events batch"] += 1
|
||||
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in task_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update
|
||||
# all of them and not only those who's events were successful
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
now=now,
|
||||
iter_max=task_iteration.get(task_id),
|
||||
last_scalar_events=task_last_scalar_events.get(task_id),
|
||||
last_events=task_last_events.get(task_id),
|
||||
)
|
||||
|
||||
errors_in_bulk = []
|
||||
added = 0
|
||||
chunk_size = 500
|
||||
with translate_errors_context(), TimingContext("es", "events_add_batch"):
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
# thread_count=8,
|
||||
refresh=True,
|
||||
)
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += chunk_size
|
||||
else:
|
||||
errors_in_bulk.append(info)
|
||||
if not updated:
|
||||
remaining_tasks.add(task_id)
|
||||
continue
|
||||
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in task_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update all of them and not only those
|
||||
# who's events were successful
|
||||
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
now=now,
|
||||
iter_max=task_iteration.get(task_id),
|
||||
last_events=task_last_events.get(task_id),
|
||||
)
|
||||
|
||||
if not updated:
|
||||
remaining_tasks.add(task_id)
|
||||
continue
|
||||
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(
|
||||
remaining_tasks, company_id, last_update=now
|
||||
)
|
||||
|
||||
# Compensate for always adding chunk_size on success (last chunk is probably smaller)
|
||||
added = min(added, len(actions))
|
||||
|
||||
return added, errors_in_bulk
|
||||
if not added:
|
||||
raise errors.bad_request.EventsNotAdded(**errors_per_type)
|
||||
|
||||
def _update_last_metric_event_for_task(self, task_last_events, task_id, event):
|
||||
errors_count = sum(errors_per_type.values())
|
||||
return added, errors_count, errors_per_type
|
||||
|
||||
def _update_last_scalar_events_for_task(self, last_events, event):
|
||||
"""
|
||||
Update task_last_events structure for the provided task_id with the provided event details if this event is more
|
||||
Update last_events structure with the provided event details if this event is more
|
||||
recent than the currently stored event for its metric/variant combination.
|
||||
|
||||
task_last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
|
||||
last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
|
||||
key conflicts due to invalid characters and/or long field names.
|
||||
"""
|
||||
metric = event.get("metric")
|
||||
@@ -213,13 +230,34 @@ class EventBLL(object):
|
||||
metric_hash = dbutils.hash_field_name(metric)
|
||||
variant_hash = dbutils.hash_field_name(variant)
|
||||
|
||||
last_events = task_last_events[task_id]
|
||||
|
||||
timestamp = last_events[metric_hash][variant_hash].get("timestamp", None)
|
||||
if timestamp is None or timestamp < event["timestamp"]:
|
||||
last_events[metric_hash][variant_hash] = event
|
||||
|
||||
def _update_task(self, company_id, task_id, now, iter_max=None, last_events=None):
|
||||
def _update_last_metric_events_for_task(self, last_events, event):
|
||||
"""
|
||||
Update last_events structure with the provided event details if this event is more
|
||||
recent than the currently stored event for its metric/event_type combination.
|
||||
last_events contains [metric_name -> event_type -> event]
|
||||
"""
|
||||
metric = event.get("metric")
|
||||
event_type = event.get("type")
|
||||
if not (metric and event_type):
|
||||
return
|
||||
|
||||
timestamp = last_events[metric][event_type].get("timestamp", None)
|
||||
if timestamp is None or timestamp < event["timestamp"]:
|
||||
last_events[metric][event_type] = event
|
||||
|
||||
def _update_task(
|
||||
self,
|
||||
company_id,
|
||||
task_id,
|
||||
now,
|
||||
iter_max=None,
|
||||
last_scalar_events=None,
|
||||
last_events=None,
|
||||
):
|
||||
"""
|
||||
Update task information in DB with aggregated results after handling event(s) related to this task.
|
||||
|
||||
@@ -232,15 +270,18 @@ class EventBLL(object):
|
||||
if iter_max is not None:
|
||||
fields["last_iteration_max"] = iter_max
|
||||
|
||||
if last_events:
|
||||
fields["last_values"] = list(
|
||||
if last_scalar_events:
|
||||
fields["last_scalar_values"] = list(
|
||||
flatten_nested_items(
|
||||
last_events,
|
||||
last_scalar_events,
|
||||
nesting=2,
|
||||
include_leaves=["value", "metric", "variant"],
|
||||
)
|
||||
)
|
||||
|
||||
if last_events:
|
||||
fields["last_events"] = last_events
|
||||
|
||||
if not fields:
|
||||
return False
|
||||
|
||||
@@ -279,7 +320,9 @@ class EventBLL(object):
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, scroll="1h", routing=task_id
|
||||
)
|
||||
|
||||
events = [hit["_source"] for hit in es_res["hits"]["hits"]]
|
||||
next_scroll_id = es_res["_scroll_id"]
|
||||
@@ -297,10 +340,16 @@ class EventBLL(object):
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric"},
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {"field": "variant"},
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
@@ -499,8 +548,18 @@ class EventBLL(object):
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric", "size": 200},
|
||||
"aggs": {"variants": {"terms": {"field": "variant", "size": 200}}},
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
@@ -540,14 +599,14 @@ class EventBLL(object):
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": 1000,
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
"order": {"_term": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": 1000,
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
"order": {"_term": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Callable, Iterable
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from typing import Sequence, Tuple, Callable
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from apierrors import errors
|
||||
@@ -20,10 +21,19 @@ from utilities import safe_get
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class EventType(Enum):
|
||||
metrics_scalar = "training_stats_scalar"
|
||||
metrics_vector = "training_stats_vector"
|
||||
metrics_image = "training_debug_image"
|
||||
metrics_plot = "plot"
|
||||
task_log = "log"
|
||||
|
||||
|
||||
class EventMetrics:
|
||||
MAX_TASKS_COUNT = 100
|
||||
MAX_TASKS_COUNT = 50
|
||||
MAX_METRICS_COUNT = 200
|
||||
MAX_VARIANTS_COUNT = 500
|
||||
MAX_AGGS_ELEMENTS_COUNT = 50
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
@@ -62,6 +72,12 @@ class EventMetrics:
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
if len(task_ids) > self.MAX_TASKS_COUNT:
|
||||
raise errors.BadRequest(
|
||||
f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
|
||||
len(task_ids),
|
||||
)
|
||||
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
@@ -97,6 +113,31 @@ class EventMetrics:
|
||||
MetricInterval = Tuple[int, Sequence[TaskMetric]]
|
||||
MetricData = Tuple[str, dict]
|
||||
|
||||
def _split_metrics_by_max_aggs_count(
|
||||
self, task_metrics: Sequence[TaskMetric]
|
||||
) -> Iterable[Sequence[TaskMetric]]:
|
||||
"""
|
||||
Return task metrics in groups where amount of task metrics in each group
|
||||
is roughly limited by MAX_AGGS_ELEMENTS_COUNT. The split is done on metrics and
|
||||
variants while always preserving all their tasks in the same group
|
||||
"""
|
||||
if len(task_metrics) < self.MAX_AGGS_ELEMENTS_COUNT:
|
||||
yield task_metrics
|
||||
return
|
||||
|
||||
tm_grouped = bucketize(task_metrics, key=itemgetter(1, 2))
|
||||
groups = []
|
||||
for group in tm_grouped.values():
|
||||
groups.append(group)
|
||||
if sum(map(len, groups)) >= self.MAX_AGGS_ELEMENTS_COUNT:
|
||||
yield list(itertools.chain(*groups))
|
||||
groups = []
|
||||
|
||||
if groups:
|
||||
yield list(itertools.chain(*groups))
|
||||
|
||||
return
|
||||
|
||||
def _run_get_scalar_metrics_as_parallel(
|
||||
self,
|
||||
company_id: str,
|
||||
@@ -126,21 +167,25 @@ class EventMetrics:
|
||||
if not intervals:
|
||||
return {}
|
||||
|
||||
with ThreadPoolExecutor(len(intervals)) as pool:
|
||||
metrics = list(
|
||||
itertools.chain.from_iterable(
|
||||
pool.map(
|
||||
partial(
|
||||
get_func, task_ids=task_ids, es_index=es_index, key=key
|
||||
),
|
||||
intervals,
|
||||
)
|
||||
intervals = list(
|
||||
itertools.chain.from_iterable(
|
||||
zip(itertools.repeat(i), self._split_metrics_by_max_aggs_count(tms))
|
||||
for i, tms in intervals
|
||||
)
|
||||
)
|
||||
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
pool.map(
|
||||
partial(get_func, task_ids=task_ids, es_index=es_index, key=key),
|
||||
intervals,
|
||||
)
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
for metric_key, metric_values in metrics:
|
||||
ret[metric_key].update(metric_values)
|
||||
|
||||
return ret
|
||||
|
||||
def _get_metric_intervals(
|
||||
@@ -310,7 +355,13 @@ class EventMetrics:
|
||||
"variants": {
|
||||
"terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
|
||||
"aggs": {
|
||||
"tasks": {"terms": {"field": "task"}, "aggs": aggregation}
|
||||
"tasks": {
|
||||
"terms": {
|
||||
"field": "task",
|
||||
"size": self.MAX_TASKS_COUNT,
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
@@ -396,3 +447,50 @@ class EventMetrics:
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
def get_tasks_metrics(
|
||||
self, company_id, task_ids: Sequence, event_type: EventType
|
||||
) -> Sequence[Tuple]:
|
||||
"""
|
||||
For the requested tasks return all the metrics that
|
||||
reported events of the requested types
|
||||
"""
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type.value)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return [(tid, []) for tid in task_ids]
|
||||
|
||||
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
|
||||
with ThreadPoolExecutor(max_concurrency) as pool:
|
||||
res = pool.map(
|
||||
partial(
|
||||
self._get_task_metrics, es_index=es_index, event_type=event_type,
|
||||
),
|
||||
task_ids,
|
||||
)
|
||||
return list(zip(task_ids, res))
|
||||
|
||||
def _get_task_metrics(self, task_id, es_index, event_type: EventType) -> Sequence:
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"type": event_type.value}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
|
||||
return [
|
||||
metric["key"]
|
||||
for metric in safe_get(es_res, "aggregations/metrics/buckets", default=[])
|
||||
]
|
||||
|
||||
169
server/bll/event/log_events_iterator.py
Normal file
169
server/bll/event/log_events_iterator.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from typing import Optional, Tuple, Sequence
|
||||
|
||||
import attr
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apierrors import errors
|
||||
from apimodels import JsonSerializableMixin
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
from bll.redis_cache_manager import RedisCacheManager
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from timing_context import TimingContext
|
||||
|
||||
|
||||
class LogEventsScrollState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
task: str = StringField(required=True)
|
||||
last_min_timestamp: Optional[int] = IntField()
|
||||
last_max_timestamp: Optional[int] = IntField()
|
||||
|
||||
def reset(self):
|
||||
"""Reset the scrolling state """
|
||||
self.last_min_timestamp = self.last_max_timestamp = None
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.Factory(list)
|
||||
|
||||
|
||||
class LogEventsIterator:
|
||||
EVENT_TYPE = "log"
|
||||
|
||||
@property
|
||||
def state_expiration_sec(self):
|
||||
return config.get(
|
||||
f"services.events.events_retrieval.state_expiration_sec", 3600
|
||||
)
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=LogEventsScrollState,
|
||||
redis=redis,
|
||||
expiration_interval=self.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> TaskEventsResult:
|
||||
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
def init_state(state_: LogEventsScrollState):
|
||||
state_.task = task_id
|
||||
|
||||
def validate_state(state_: LogEventsScrollState):
|
||||
"""
|
||||
Checks that the task id stored in the state
|
||||
is equal to the one passed with the current call
|
||||
Refresh the state if requested
|
||||
"""
|
||||
if state_.task != task_id:
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task stored in the state does not match the passed one",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
state_.reset()
|
||||
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res = TaskEventsResult(next_scroll_id=state.id)
|
||||
res.events, res.total_events = self._get_events(
|
||||
es_index=es_index,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
state=state,
|
||||
)
|
||||
return res
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
es_index,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
state: LogEventsScrollState,
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous timestamp either in the
|
||||
direction of earlier events (navigate_earlier=True) or in the direction of later events.
|
||||
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
|
||||
For the last timestamp all the events are brought (even if the resulting size
|
||||
exceeds batch_size) so that this timestamp events will not be lost between the calls.
|
||||
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
|
||||
"""
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": {"term": {"task": state.task}},
|
||||
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if navigate_earlier and state.last_min_timestamp is not None:
|
||||
es_req["search_after"] = [state.last_min_timestamp]
|
||||
elif not navigate_earlier and state.last_max_timestamp is not None:
|
||||
es_req["search_after"] = [state.last_max_timestamp]
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
if navigate_earlier:
|
||||
state.last_max_timestamp = events[0]["timestamp"]
|
||||
state.last_min_timestamp = events[-1]["timestamp"]
|
||||
else:
|
||||
state.last_min_timestamp = events[0]["timestamp"]
|
||||
state.last_max_timestamp = events[-1]["timestamp"]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"timestamp": events[-1]["timestamp"]}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
|
||||
hits = es_result["hits"]["hits"]
|
||||
if not hits or len(hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
last_events = [hit["_source"] for hit in es_result["hits"]["hits"]]
|
||||
already_present_ids = set(ev["_id"] for ev in events)
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[
|
||||
*events,
|
||||
*(ev for ev in last_events if ev["_id"] not in already_present_ids),
|
||||
],
|
||||
hits_total,
|
||||
)
|
||||
@@ -4,7 +4,7 @@ Module for polymorphism over different types of X axes in scalar aggregations
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import auto
|
||||
|
||||
from apimodels import StringEnum
|
||||
from utilities.stringenum import StringEnum
|
||||
from bll.util import extract_properties_to_lists
|
||||
from config import config
|
||||
|
||||
@@ -111,7 +111,7 @@ class TimestampKey(ScalarKey):
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": interval,
|
||||
"interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
}
|
||||
}
|
||||
@@ -150,7 +150,7 @@ class ISOTimeKey(ScalarKey):
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": interval,
|
||||
"interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
"format": "strict_date_time",
|
||||
}
|
||||
|
||||
85
server/bll/organization/__init__.py
Normal file
85
server/bll/organization/__init__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from config import config
|
||||
from database.model.base import GetMixin
|
||||
from database.model.model import Model
|
||||
from database.model.task.task import Task
|
||||
from redis_manager import redman
|
||||
from utilities import json
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class OrgBLL:
|
||||
_tags_field = "tags"
|
||||
_system_tags_field = "system_tags"
|
||||
_settings_prefix = "services.organization"
|
||||
|
||||
def __init__(self, redis=None):
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
|
||||
@property
|
||||
def _tags_cache_expiration_seconds(self):
|
||||
return config.get(
|
||||
f"{self._settings_prefix}.tags_cache.expiration_seconds", 3600
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_tags_cache_key(company, field: str, filter_: Sequence[str] = None):
|
||||
filter_str = "_".join(filter_) if filter_ else ""
|
||||
return f"{field}_{company}_{filter_str}"
|
||||
|
||||
@staticmethod
|
||||
def _get_tags_from_db(company, field, filter_: Sequence[str] = None) -> set:
|
||||
query = Q(company=company)
|
||||
if filter_:
|
||||
query &= GetMixin.get_list_field_query("system_tags", filter_)
|
||||
|
||||
tags = set()
|
||||
for cls_ in (Task, Model):
|
||||
tags |= set(cls_.objects(query).distinct(field))
|
||||
return tags
|
||||
|
||||
def get_tags(
|
||||
self, company, include_system: bool = False, filter_: Sequence[str] = None
|
||||
) -> dict:
|
||||
"""
|
||||
Get tags and optionally system tags for the company
|
||||
Return the dictionary of tags per tags field name
|
||||
The function retrieves both cached values from Redis in one call
|
||||
and re calculates any of them if missing in Redis
|
||||
"""
|
||||
fields = [
|
||||
self._tags_field,
|
||||
*([self._system_tags_field] if include_system else []),
|
||||
]
|
||||
redis_keys = [self._get_tags_cache_key(company, f, filter_) for f in fields]
|
||||
cached = self.redis.mget(redis_keys)
|
||||
ret = {}
|
||||
for field, tag_data, key in zip(fields, cached, redis_keys):
|
||||
if tag_data is not None:
|
||||
tags = json.loads(tag_data)
|
||||
else:
|
||||
tags = list(self._get_tags_from_db(company, field, filter_))
|
||||
self.redis.setex(
|
||||
key,
|
||||
time=self._tags_cache_expiration_seconds,
|
||||
value=json.dumps(tags),
|
||||
)
|
||||
ret[field] = tags
|
||||
|
||||
return ret
|
||||
|
||||
def update_org_tags(self, company, tags=None, system_tags=None, reset=False):
|
||||
"""
|
||||
Updates system tags. If reset is set then both tags and system_tags
|
||||
are recalculated. Otherwise only those that are not 'None'
|
||||
"""
|
||||
if reset or tags is not None:
|
||||
self.redis.delete(self._get_tags_cache_key(company, self._tags_field))
|
||||
if reset or system_tags is not None:
|
||||
self.redis.delete(
|
||||
self._get_tags_cache_key(company, self._system_tags_field)
|
||||
)
|
||||
1
server/bll/project/__init__.py
Normal file
1
server/bll/project/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .project_bll import ProjectBLL
|
||||
33
server/bll/project/project_bll.py
Normal file
33
server/bll/project/project_bll.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from config import config
|
||||
from database.model.model import Model
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ProjectBLL:
|
||||
@classmethod
|
||||
def get_active_users(
|
||||
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
|
||||
) -> set:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
"""
|
||||
with TimingContext("mongo", "active_users_in_projects"):
|
||||
res = set()
|
||||
query = Q(company=company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
|
||||
return res
|
||||
@@ -9,9 +9,12 @@ import es_factory
|
||||
from apierrors import errors
|
||||
from bll.queue.queue_metrics import QueueMetrics
|
||||
from bll.workers import WorkerBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.queue import Queue, Entry
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class QueueBLL(object):
|
||||
def __init__(self, worker_bll: WorkerBLL = None, es: Elasticsearch = None):
|
||||
@@ -189,9 +192,7 @@ class QueueBLL(object):
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
queue = Queue.objects(**query).modify(
|
||||
pop__entries=-1, last_update=datetime.utcnow(), upsert=False
|
||||
)
|
||||
queue = Queue.objects(**query).modify(pop__entries=-1, upsert=False)
|
||||
if not queue:
|
||||
raise errors.bad_request.InvalidQueueId(**query)
|
||||
|
||||
@@ -200,6 +201,11 @@ class QueueBLL(object):
|
||||
if not queue.entries:
|
||||
return
|
||||
|
||||
try:
|
||||
Queue.objects(**query).update(last_update=datetime.utcnow())
|
||||
except Exception:
|
||||
log.exception("Error while updating Queue.last_update")
|
||||
|
||||
return queue.entries[0]
|
||||
|
||||
def remove_task(self, company_id: str, queue_id: str, task_id: str) -> int:
|
||||
|
||||
79
server/bll/redis_cache_manager.py
Normal file
79
server/bll/redis_cache_manager.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, TypeVar, Generic, Type, Callable
|
||||
|
||||
from redis import StrictRedis
|
||||
|
||||
import database
|
||||
from timing_context import TimingContext
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _do_nothing(_: T):
|
||||
return
|
||||
|
||||
|
||||
class RedisCacheManager(Generic[T]):
|
||||
"""
|
||||
Class for store/retrieve of state objects from redis
|
||||
|
||||
self.state_class - class of the state
|
||||
self.redis - instance of redis
|
||||
self.expiration_interval - expiration interval in seconds
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, state_class: Type[T], redis: StrictRedis, expiration_interval: int
|
||||
):
|
||||
self.state_class = state_class
|
||||
self.redis = redis
|
||||
self.expiration_interval = expiration_interval
|
||||
|
||||
def set_state(self, state: T) -> None:
|
||||
redis_key = self._get_redis_key(state.id)
|
||||
with TimingContext("redis", "cache_set_state"):
|
||||
self.redis.set(redis_key, state.to_json())
|
||||
self.redis.expire(redis_key, self.expiration_interval)
|
||||
|
||||
def get_state(self, state_id) -> Optional[T]:
|
||||
redis_key = self._get_redis_key(state_id)
|
||||
with TimingContext("redis", "cache_get_state"):
|
||||
response = self.redis.get(redis_key)
|
||||
if response:
|
||||
return self.state_class.from_json(response)
|
||||
|
||||
def delete_state(self, state_id) -> None:
|
||||
with TimingContext("redis", "cache_delete_state"):
|
||||
self.redis.delete(self._get_redis_key(state_id))
|
||||
|
||||
def _get_redis_key(self, state_id):
|
||||
return f"{self.state_class}/{state_id}"
|
||||
|
||||
@contextmanager
|
||||
def get_or_create_state(
|
||||
self,
|
||||
state_id=None,
|
||||
init_state: Callable[[T], None] = _do_nothing,
|
||||
validate_state: Callable[[T], None] = _do_nothing,
|
||||
):
|
||||
"""
|
||||
Try to retrieve state with the given id from the Redis cache if yes then validates it
|
||||
If no then create a new one with randomly generated id
|
||||
Yield the state and write it back to redis once the user code block exits
|
||||
:param state_id: id of the state to retrieve
|
||||
:param init_state: user callback to init the newly created state
|
||||
If not passed then no init except for the id generation is done
|
||||
:param validate_state: user callback to validate the state if retrieved from cache
|
||||
Should throw an exception if the state is not valid. If not passed then no validation is done
|
||||
"""
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
|
||||
try:
|
||||
yield state
|
||||
finally:
|
||||
self.set_state(state)
|
||||
@@ -280,7 +280,7 @@ class StatisticsReporter:
|
||||
]
|
||||
return {
|
||||
group["_id"]: {k: v for k, v in group.items() if k != "_id"}
|
||||
for group in Task.aggregate(*pipeline)
|
||||
for group in Task.aggregate(pipeline)
|
||||
}
|
||||
|
||||
|
||||
|
||||
89
server/bll/task/non_responsive_tasks_watchdog.py
Normal file
89
server/bll/task/non_responsive_tasks_watchdog.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from datetime import timedelta, datetime
|
||||
from time import sleep
|
||||
|
||||
from apierrors import errors
|
||||
from bll.task import ChangeStatusRequest
|
||||
from config import config
|
||||
from database.model.task.task import TaskStatus, Task
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class NonResponsiveTasksWatchdog:
|
||||
threads = ThreadsManager()
|
||||
|
||||
class _Settings:
|
||||
"""
|
||||
Retrieves watchdog settings from the config file
|
||||
The properties are not cached so that the updates in
|
||||
the config file are reflected
|
||||
"""
|
||||
|
||||
_prefix = "services.tasks.non_responsive_tasks_watchdog"
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return config.get(f"{self._prefix}.enabled", True)
|
||||
|
||||
@property
|
||||
def watch_interval_sec(self):
|
||||
return config.get(f"{self._prefix}.watch_interval_sec", 900)
|
||||
|
||||
@property
|
||||
def threshold_sec(self):
|
||||
return config.get(f"{self._prefix}.threshold_sec", 7200)
|
||||
|
||||
settings = _Settings()
|
||||
|
||||
@classmethod
|
||||
@threads.register("non_responsive_tasks_watchdog", daemon=True)
|
||||
def start(cls):
|
||||
sleep(cls.settings.watch_interval_sec)
|
||||
while not ThreadsManager.terminating:
|
||||
watch_interval = cls.settings.watch_interval_sec
|
||||
if cls.settings.enabled:
|
||||
try:
|
||||
stopped = cls.cleanup_tasks(
|
||||
threshold_sec=cls.settings.threshold_sec
|
||||
)
|
||||
log.info(f"{stopped} non-responsive tasks stopped")
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed stopping tasks: {str(ex)}")
|
||||
sleep(watch_interval)
|
||||
|
||||
@classmethod
|
||||
def cleanup_tasks(cls, threshold_sec):
|
||||
relevant_status = (TaskStatus.in_progress,)
|
||||
threshold = timedelta(seconds=threshold_sec)
|
||||
ref_time = datetime.utcnow() - threshold
|
||||
log.info(
|
||||
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
|
||||
)
|
||||
|
||||
tasks = list(
|
||||
Task.objects(status__in=relevant_status, last_update__lt=ref_time).only(
|
||||
"id", "name", "status", "project", "last_update"
|
||||
)
|
||||
)
|
||||
log.info(f"{len(tasks)} non-responsive tasks found")
|
||||
if not tasks:
|
||||
return 0
|
||||
|
||||
err_count = 0
|
||||
for task in tasks:
|
||||
log.info(
|
||||
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
|
||||
)
|
||||
try:
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.stopped,
|
||||
status_reason="Forced stop (non-responsive)",
|
||||
status_message="Forced stop (non-responsive)",
|
||||
force=True,
|
||||
).execute()
|
||||
except errors.bad_request.FailedChangingTaskStatus:
|
||||
err_count += 1
|
||||
|
||||
return len(tasks) - err_count
|
||||
@@ -1,22 +1,25 @@
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from random import random
|
||||
from time import sleep
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, List
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
|
||||
|
||||
import pymongo.results
|
||||
import six
|
||||
from mongoengine import Q
|
||||
from six import string_types
|
||||
|
||||
import database.utils as dbutils
|
||||
import es_factory
|
||||
from apierrors import errors
|
||||
from apimodels.tasks import Artifact as ApiArtifact
|
||||
from bll.organization import OrgBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.model import Model
|
||||
from database.model.project import Project
|
||||
from database.model.task.metrics import EventStats, MetricEventStats
|
||||
from database.model.task.output import Output
|
||||
from database.model.task.task import (
|
||||
Task,
|
||||
@@ -25,25 +28,37 @@ from database.model.task.task import (
|
||||
TaskSystemTags,
|
||||
ArtifactModes,
|
||||
Artifact,
|
||||
external_task_types,
|
||||
)
|
||||
from database.utils import get_company_or_none_constraint, id as create_id
|
||||
from service_repo import APICall
|
||||
from services.utils import validate_tags
|
||||
from timing_context import TimingContext
|
||||
from utilities.dicts import deep_merge
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
|
||||
|
||||
class TaskBLL(object):
|
||||
threads = ThreadsManager("TaskBLL")
|
||||
|
||||
def __init__(self, events_es=None):
|
||||
self.events_es = (
|
||||
events_es if events_es is not None else es_factory.connect("events")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
|
||||
"""
|
||||
Return the list of unique task types used by company and public tasks
|
||||
If project ids passed then only tasks from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
res = Task.objects(query).distinct(field="type")
|
||||
return set(res).intersection(external_task_types)
|
||||
|
||||
@staticmethod
|
||||
def get_task_with_access(
|
||||
task_id, company_id, only=None, allow_public=False, requires_write_access=False
|
||||
@@ -165,9 +180,12 @@ class TaskBLL(object):
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
execution_overrides: Optional[dict] = None,
|
||||
validate_references: bool = False,
|
||||
) -> Task:
|
||||
validate_tags(tags, system_tags)
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id)
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
execution_model_overriden = False
|
||||
if execution_overrides:
|
||||
parameters = execution_overrides.get("parameters")
|
||||
if parameters is not None:
|
||||
@@ -175,6 +193,8 @@ class TaskBLL(object):
|
||||
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
|
||||
}
|
||||
execution_dict = deep_merge(execution_dict, execution_overrides)
|
||||
execution_model_overriden = execution_overrides.get("model") is not None
|
||||
|
||||
artifacts = execution_dict.get("artifacts")
|
||||
if artifacts:
|
||||
execution_dict["artifacts"] = [
|
||||
@@ -197,29 +217,47 @@ class TaskBLL(object):
|
||||
system_tags=system_tags or [],
|
||||
type=task.type,
|
||||
script=task.script,
|
||||
output=Output(destination=task.output.destination) if task.output else None,
|
||||
output=Output(destination=task.output.destination)
|
||||
if task.output
|
||||
else None,
|
||||
execution=execution_dict,
|
||||
)
|
||||
cls.validate(new_task)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_model=validate_references or execution_model_overriden,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
new_task.save()
|
||||
|
||||
org_bll.update_org_tags(company_id, tags=tags, system_tags=system_tags)
|
||||
return new_task
|
||||
|
||||
@classmethod
|
||||
def validate(cls, task: Task):
|
||||
assert isinstance(task, Task)
|
||||
|
||||
if task.parent and not Task.get(
|
||||
company=task.company, id=task.parent, _only=("id",), include_public=True
|
||||
def validate(
|
||||
cls,
|
||||
task: Task,
|
||||
validate_model=True,
|
||||
validate_parent=True,
|
||||
validate_project=True,
|
||||
):
|
||||
if (
|
||||
validate_parent
|
||||
and task.parent
|
||||
and not Task.get(
|
||||
company=task.company, id=task.parent, _only=("id",), include_public=True
|
||||
)
|
||||
):
|
||||
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
|
||||
|
||||
if task.project and not Project.get_for_writing(
|
||||
company=task.company, id=task.project
|
||||
if (
|
||||
validate_project
|
||||
and task.project
|
||||
and not Project.get_for_writing(company=task.company, id=task.project)
|
||||
):
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
|
||||
cls.validate_execution_model(task)
|
||||
if validate_model:
|
||||
cls.validate_execution_model(task)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(company_id, project_ids=None):
|
||||
@@ -259,7 +297,7 @@ class TaskBLL(object):
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = Task.aggregate(*pipeline)
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@staticmethod
|
||||
@@ -277,7 +315,8 @@ class TaskBLL(object):
|
||||
last_update: datetime = None,
|
||||
last_iteration: int = None,
|
||||
last_iteration_max: int = None,
|
||||
last_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
|
||||
last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
|
||||
last_events: Dict[str, Dict[str, dict]] = None,
|
||||
**extra_updates,
|
||||
):
|
||||
"""
|
||||
@@ -289,7 +328,8 @@ class TaskBLL(object):
|
||||
task's last iteration value.
|
||||
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
|
||||
if the current task's last iteration value is smaller than the provided value.
|
||||
:param last_values: Last reported metrics summary (value, metric, variant).
|
||||
:param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant).
|
||||
:param last_events: Last reported metrics summary (value, metric, event type).
|
||||
:param extra_updates: Extra task updates to include in this update call.
|
||||
:return:
|
||||
"""
|
||||
@@ -300,17 +340,33 @@ class TaskBLL(object):
|
||||
elif last_iteration_max is not None:
|
||||
extra_updates.update(max__last_iteration=last_iteration_max)
|
||||
|
||||
if last_values is not None:
|
||||
if last_scalar_values is not None:
|
||||
|
||||
def op_path(op, *path):
|
||||
return "__".join((op, "last_metrics") + path)
|
||||
|
||||
for path, value in last_values:
|
||||
for path, value in last_scalar_values:
|
||||
extra_updates[op_path("set", *path)] = value
|
||||
if path[-1] == "value":
|
||||
extra_updates[op_path("min", *path[:-1], "min_value")] = value
|
||||
extra_updates[op_path("max", *path[:-1], "max_value")] = value
|
||||
|
||||
if last_events is not None:
|
||||
|
||||
def events_per_type(metric_data: Dict[str, dict]) -> Dict[str, EventStats]:
|
||||
return {
|
||||
event_type: EventStats(last_update=event["timestamp"])
|
||||
for event_type, event in metric_data.items()
|
||||
}
|
||||
|
||||
metric_stats = {
|
||||
dbutils.hash_field_name(metric_key): MetricEventStats(
|
||||
metric=metric_key, event_stats_by_type=events_per_type(metric_data)
|
||||
)
|
||||
for metric_key, metric_data in last_events.items()
|
||||
}
|
||||
extra_updates["metric_stats"] = metric_stats
|
||||
|
||||
Task.objects(id=task_id, company=company_id).update(
|
||||
upsert=False, last_update=last_update, **extra_updates
|
||||
)
|
||||
@@ -553,58 +609,6 @@ class TaskBLL(object):
|
||||
|
||||
return [a.key for a in added], [a.key for a in updated]
|
||||
|
||||
@classmethod
|
||||
@threads.register("non_responsive_tasks_watchdog", daemon=True)
|
||||
def start_non_responsive_tasks_watchdog(cls):
|
||||
log = config.logger("non_responsive_tasks_watchdog")
|
||||
relevant_status = (TaskStatus.in_progress,)
|
||||
threshold = timedelta(
|
||||
seconds=config.get(
|
||||
"services.tasks.non_responsive_tasks_watchdog.threshold_sec", 7200
|
||||
)
|
||||
)
|
||||
watch_interval = config.get(
|
||||
"services.tasks.non_responsive_tasks_watchdog.watch_interval_sec", 900
|
||||
)
|
||||
sleep(watch_interval)
|
||||
while not ThreadsManager.terminating:
|
||||
try:
|
||||
|
||||
ref_time = datetime.utcnow() - threshold
|
||||
|
||||
log.info(
|
||||
f"Starting cleanup cycle for running tasks last updated before {ref_time}"
|
||||
)
|
||||
|
||||
tasks = list(
|
||||
Task.objects(
|
||||
status__in=relevant_status, last_update__lt=ref_time
|
||||
).only("id", "name", "status", "project", "last_update")
|
||||
)
|
||||
|
||||
if tasks:
|
||||
|
||||
log.info(f"Stopping {len(tasks)} non-responsive tasks")
|
||||
|
||||
for task in tasks:
|
||||
log.info(
|
||||
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
|
||||
)
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.stopped,
|
||||
status_reason="Forced stop (non-responsive)",
|
||||
status_message="Forced stop (non-responsive)",
|
||||
force=True,
|
||||
).execute()
|
||||
|
||||
log.info(f"Done")
|
||||
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed stopping tasks: {str(ex)}")
|
||||
|
||||
sleep(watch_interval)
|
||||
|
||||
@staticmethod
|
||||
def get_aggregated_project_execution_parameters(
|
||||
company_id,
|
||||
@@ -644,7 +648,7 @@ class TaskBLL(object):
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = next(Task.aggregate(*pipeline), None)
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
|
||||
@@ -33,8 +33,8 @@ log = config.logger(__file__)
|
||||
|
||||
class WorkerBLL:
|
||||
def __init__(self, es=None, redis=None):
|
||||
self.es_client = es if es is not None else es_factory.connect("workers")
|
||||
self.redis = redis if redis is not None else redman.connection("workers")
|
||||
self.es_client = es or es_factory.connect("workers")
|
||||
self.redis = redis or redman.connection("workers")
|
||||
self._stats = WorkerStats(self.es_client)
|
||||
|
||||
@property
|
||||
@@ -223,7 +223,7 @@ class WorkerBLL:
|
||||
},
|
||||
]
|
||||
queues_info = {
|
||||
res["_id"]: res for res in Queue.objects.aggregate(*projection)
|
||||
res["_id"]: res for res in Queue.objects.aggregate(projection)
|
||||
}
|
||||
task_ids = task_ids.union(
|
||||
filter(
|
||||
|
||||
@@ -47,7 +47,7 @@ class BasicConfig:
|
||||
def logger(self, name):
|
||||
if Path(name).is_file():
|
||||
name = Path(name).stem
|
||||
path = ".".join((self.prefix, Path(name).stem))
|
||||
path = ".".join((self.prefix, name))
|
||||
return logging.getLogger(path)
|
||||
|
||||
def _read_extra_env_config_values(self):
|
||||
@@ -57,7 +57,7 @@ class BasicConfig:
|
||||
|
||||
keys = sorted(k for k in os.environ if k.startswith(prefix))
|
||||
for key in keys:
|
||||
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".")
|
||||
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".").lower()
|
||||
result = ConfigTree.merge_configs(
|
||||
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
|
||||
)
|
||||
@@ -77,7 +77,7 @@ class BasicConfig:
|
||||
if not path.is_dir() and str(path) != DEFAULT_EXTRA_CONFIG_PATH
|
||||
]
|
||||
if invalid:
|
||||
print(f"WARNING: Invalid paths in {key} env var: {' '.join(invalid)}")
|
||||
print(f"WARNING: Invalid paths in {key} env var: {' '.join(map(str, invalid))}")
|
||||
return [path for path in paths if path.is_dir()]
|
||||
|
||||
def _load(self, verbose=True):
|
||||
|
||||
@@ -34,6 +34,12 @@
|
||||
aggregate {
|
||||
allow_disk_use: true
|
||||
}
|
||||
|
||||
pre_populate {
|
||||
enabled: false
|
||||
zip_file: "/path/to/export.zip"
|
||||
fail_on_error: false
|
||||
}
|
||||
}
|
||||
|
||||
auth {
|
||||
|
||||
@@ -32,6 +32,11 @@ mongo {
|
||||
}
|
||||
|
||||
redis {
|
||||
apiserver {
|
||||
host: "127.0.0.1"
|
||||
port: 6379
|
||||
db: 0
|
||||
}
|
||||
workers {
|
||||
host: "127.0.0.1"
|
||||
port: 6379
|
||||
|
||||
@@ -13,17 +13,21 @@
|
||||
credentials {
|
||||
# system credentials as they appear in the auth DB, used for intra-service communications
|
||||
apiserver {
|
||||
role: "system"
|
||||
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
|
||||
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
|
||||
}
|
||||
webserver {
|
||||
role: "system"
|
||||
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
|
||||
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
|
||||
revoke_in_fixed_mode: true
|
||||
}
|
||||
tests {
|
||||
role: "user"
|
||||
display_name: "Default User"
|
||||
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
|
||||
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,4 +2,12 @@ es_index_prefix: "events"
|
||||
|
||||
ignore_iteration {
|
||||
metrics: [":monitor:machine", ":monitor:gpu"]
|
||||
}
|
||||
}
|
||||
|
||||
# max number of concurrent queries to ES when calculating events metrics
|
||||
# should not exceed the amount of concurrent connections set in the ES driver
|
||||
max_metrics_concurrency: 4
|
||||
|
||||
events_retrieval {
|
||||
state_expiration_sec: 3600
|
||||
}
|
||||
|
||||
3
server/config/default/services/organization.conf
Normal file
3
server/config/default/services/organization.conf
Normal file
@@ -0,0 +1,3 @@
|
||||
tags_cache {
|
||||
expiration_seconds: 3600
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
non_responsive_tasks_watchdog {
|
||||
enabled: true
|
||||
|
||||
# In-progress tasks older than this value in seconds will be stopped by the watchdog
|
||||
threshold_sec: 7200
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ from mongoengine import (
|
||||
DictField,
|
||||
DynamicField,
|
||||
)
|
||||
from mongoengine.fields import key_not_string, key_starts_with_dollar
|
||||
|
||||
NoneType = type(None)
|
||||
|
||||
|
||||
class LengthRangeListField(ListField):
|
||||
@@ -125,17 +128,39 @@ def contains_empty_key(d):
|
||||
return True
|
||||
|
||||
|
||||
class SafeMapField(MapField):
|
||||
class DictValidationMixin:
|
||||
"""
|
||||
DictField validation in MongoEngine requires default alias and permissions to access DB version:
|
||||
https://github.com/MongoEngine/mongoengine/issues/2239
|
||||
This is a stripped down implementation that does not require any of the above and implies Mongo ver 3.6+
|
||||
"""
|
||||
|
||||
def _safe_validate(self: DictField, value):
|
||||
if not isinstance(value, dict):
|
||||
self.error("Only dictionaries may be used in a DictField")
|
||||
|
||||
if key_not_string(value):
|
||||
msg = "Invalid dictionary key - documents must have only string keys"
|
||||
self.error(msg)
|
||||
|
||||
if key_starts_with_dollar(value):
|
||||
self.error(
|
||||
'Invalid dictionary key name - keys may not startswith "$" characters'
|
||||
)
|
||||
super(DictField, self).validate(value)
|
||||
|
||||
|
||||
class SafeMapField(MapField, DictValidationMixin):
|
||||
def validate(self, value):
|
||||
super(SafeMapField, self).validate(value)
|
||||
self._safe_validate(value)
|
||||
|
||||
if contains_empty_key(value):
|
||||
self.error("Empty keys are not allowed in a MapField")
|
||||
|
||||
|
||||
class SafeDictField(DictField):
|
||||
class SafeDictField(DictField, DictValidationMixin):
|
||||
def validate(self, value):
|
||||
super(SafeDictField, self).validate(value)
|
||||
self._safe_validate(value)
|
||||
|
||||
if contains_empty_key(value):
|
||||
self.error("Empty keys are not allowed in a DictField")
|
||||
@@ -146,6 +171,7 @@ class SafeSortedListField(SortedListField):
|
||||
SortedListField that does not raise an error in case items are not comparable
|
||||
(in which case they will be sorted by their string representation)
|
||||
"""
|
||||
|
||||
def to_mongo(self, *args, **kwargs):
|
||||
try:
|
||||
return super(SafeSortedListField, self).to_mongo(*args, **kwargs)
|
||||
@@ -155,7 +181,10 @@ class SafeSortedListField(SortedListField):
|
||||
def _safe_to_mongo(self, value, use_db_field=True, fields=None):
|
||||
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
|
||||
if self._ordering is not None:
|
||||
def key(v): return str(itemgetter(self._ordering)(v))
|
||||
|
||||
def key(v):
|
||||
return str(itemgetter(self._ordering)(v))
|
||||
|
||||
else:
|
||||
key = str
|
||||
return sorted(value, key=key, reverse=self._order_reverse)
|
||||
|
||||
@@ -43,6 +43,7 @@ class Role(object):
|
||||
|
||||
|
||||
class Credentials(EmbeddedDocument):
|
||||
meta = {"strict": False}
|
||||
key = StringField(required=True)
|
||||
secret = StringField(required=True)
|
||||
last_used = DateTimeField()
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional
|
||||
|
||||
from boltons.iterutils import first
|
||||
from boltons.iterutils import first, bucketize
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
@@ -34,7 +34,12 @@ class AuthDocument(Document):
|
||||
|
||||
|
||||
class ProperDictMixin(object):
|
||||
def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict:
|
||||
def to_proper_dict(
|
||||
self: Union["ProperDictMixin", Document],
|
||||
strip_private=True,
|
||||
only=None,
|
||||
extra_dict=None,
|
||||
) -> dict:
|
||||
return self.properize_dict(
|
||||
self.to_mongo(use_db_field=False).to_dict(),
|
||||
strip_private=strip_private,
|
||||
@@ -71,6 +76,8 @@ class GetMixin(PropsMixin):
|
||||
}
|
||||
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
|
||||
|
||||
_field_collation_overrides = {}
|
||||
|
||||
class QueryParameterOptions(object):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -91,11 +98,48 @@ class GetMixin(PropsMixin):
|
||||
self.list_fields = list_fields
|
||||
self.pattern_fields = pattern_fields
|
||||
|
||||
class ListFieldBucketHelper:
|
||||
op_prefix = "__$"
|
||||
legacy_exclude_prefix = "-"
|
||||
|
||||
_default = "in"
|
||||
_ops = {"not": "nin"}
|
||||
_next = _default
|
||||
|
||||
def __init__(self, legacy=False):
|
||||
self._legacy = legacy
|
||||
|
||||
def key(self, v):
|
||||
if v is None:
|
||||
self._next = self._default
|
||||
return self._default
|
||||
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
|
||||
self._next = self._default
|
||||
return self._ops["not"]
|
||||
elif v.startswith(self.op_prefix):
|
||||
self._next = self._ops.get(v[len(self.op_prefix) :], self._default)
|
||||
return None
|
||||
|
||||
next_ = self._next
|
||||
self._next = self._default
|
||||
return next_
|
||||
|
||||
def value_transform(self, v):
|
||||
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
|
||||
return v[len(self.legacy_exclude_prefix) :]
|
||||
return v
|
||||
|
||||
get_all_query_options = QueryParameterOptions()
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
cls, company, id, *, _only=None, include_public=False, **kwargs
|
||||
cls: Union["GetMixin", Document],
|
||||
company,
|
||||
id,
|
||||
*,
|
||||
_only=None,
|
||||
include_public=False,
|
||||
**kwargs,
|
||||
) -> "GetMixin":
|
||||
q = cls.objects(
|
||||
cls._prepare_perm_query(company, allow_public=include_public)
|
||||
@@ -162,17 +206,7 @@ class GetMixin(PropsMixin):
|
||||
for field in tuple(opts.list_fields or ()):
|
||||
data = parameters.pop(field, None)
|
||||
if data:
|
||||
if not isinstance(data, (list, tuple)):
|
||||
raise MakeGetAllQueryError("expected list", field)
|
||||
exclude = [t for t in data if t.startswith("-")]
|
||||
include = list(set(data).difference(exclude))
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
if include:
|
||||
dict_query[f"{mongoengine_field}__in"] = include
|
||||
if exclude:
|
||||
dict_query[f"{mongoengine_field}__nin"] = [
|
||||
t[1:] for t in exclude
|
||||
]
|
||||
query &= cls.get_list_field_query(field, data)
|
||||
|
||||
for field in opts.fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
@@ -216,6 +250,47 @@ class GetMixin(PropsMixin):
|
||||
|
||||
return query & RegexQ(**dict_query)
|
||||
|
||||
@classmethod
|
||||
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
||||
"""
|
||||
Get a proper mongoengine Q object that represents an "or" query for the provided values
|
||||
with respect to the given list field, with support for "none of empty" in case a None value
|
||||
is included.
|
||||
|
||||
- Exclusion can be specified by a leading "-" for each value (API versions <2.8)
|
||||
or by a preceding "__$not" value (operator)
|
||||
"""
|
||||
if not isinstance(data, (list, tuple)):
|
||||
raise MakeGetAllQueryError("expected list", field)
|
||||
|
||||
# TODO: backwards compatibility only for older API versions
|
||||
helper = cls.ListFieldBucketHelper(legacy=True)
|
||||
actions = bucketize(
|
||||
data, key=helper.key, value_transform=helper.value_transform
|
||||
)
|
||||
|
||||
allow_empty = None in actions.get("in", {})
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
|
||||
q = RegexQ()
|
||||
for action in filter(None, actions):
|
||||
q &= RegexQ(
|
||||
**{
|
||||
f"{mongoengine_field}__{action}": list(
|
||||
set(filter(None, actions[action]))
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
if not allow_empty:
|
||||
return q
|
||||
|
||||
return (
|
||||
q
|
||||
| Q(**{f"{mongoengine_field}__exists": False})
|
||||
| Q(**{mongoengine_field: []})
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _prepare_perm_query(cls, company, allow_public=False):
|
||||
if allow_public:
|
||||
@@ -409,7 +484,12 @@ class GetMixin(PropsMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_many_no_company(cls, query, parameters=None, override_projection=None):
|
||||
def _get_many_no_company(
|
||||
cls: Union["GetMixin", Document],
|
||||
query,
|
||||
parameters=None,
|
||||
override_projection=None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query.
|
||||
This is a company-less version for internal uses. We assume the caller has either added any necessary
|
||||
@@ -460,6 +540,8 @@ class GetMixin(PropsMixin):
|
||||
"""
|
||||
Fetch all documents matching a provided query. For the first order by field
|
||||
the None values are sorted in the end regardless of the sorting order.
|
||||
If the first order field is a user defined parameter (either from execution.parameters,
|
||||
or from last_metrics) then the collation is set that sorts strings in numeric order where possible.
|
||||
This is a company-less version for internal uses. We assume the caller has either added any necessary
|
||||
constraints to the query or that no constraints are required.
|
||||
|
||||
@@ -500,6 +582,16 @@ class GetMixin(PropsMixin):
|
||||
query_sets = [cls.objects(non_empty), cls.objects(empty)]
|
||||
|
||||
query_sets = [qs.order_by(*order_by) for qs in query_sets]
|
||||
if order_field:
|
||||
collation_override = first(
|
||||
v
|
||||
for k, v in cls._field_collation_overrides.items()
|
||||
if order_field.startswith(k)
|
||||
)
|
||||
if collation_override:
|
||||
query_sets = [
|
||||
qs.collation(collation=collation_override) for qs in query_sets
|
||||
]
|
||||
|
||||
if search_text:
|
||||
query_sets = [qs.search_text(search_text) for qs in query_sets]
|
||||
@@ -593,7 +685,13 @@ class UpdateMixin(object):
|
||||
return update_dict
|
||||
|
||||
@classmethod
|
||||
def safe_update(cls, company_id, id, partial_update_dict, injected_update=None):
|
||||
def safe_update(
|
||||
cls: Union["UpdateMixin", Document],
|
||||
company_id,
|
||||
id,
|
||||
partial_update_dict,
|
||||
injected_update=None,
|
||||
):
|
||||
update_dict = cls.get_safe_update_dict(partial_update_dict)
|
||||
if not update_dict:
|
||||
return 0, {}
|
||||
@@ -610,7 +708,10 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
|
||||
@classmethod
|
||||
def aggregate(
|
||||
cls: Document, *pipeline: dict, allow_disk_use=None, **kwargs
|
||||
cls: Union["DbModelMixin", Document],
|
||||
pipeline: Sequence[dict],
|
||||
allow_disk_use=None,
|
||||
**kwargs,
|
||||
) -> CommandCursor:
|
||||
"""
|
||||
Aggregate objects of this document class according to the provided pipeline.
|
||||
@@ -625,7 +726,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
if allow_disk_use is not None
|
||||
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
|
||||
)
|
||||
return cls.objects.aggregate(*pipeline, **kwargs)
|
||||
return cls.objects.aggregate(pipeline, **kwargs)
|
||||
|
||||
|
||||
def validate_id(cls, company, **kwargs):
|
||||
@@ -647,5 +748,5 @@ def validate_id(cls, company, **kwargs):
|
||||
id_to_name.setdefault(obj_id, []).append(name)
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Invalid {} ids".format(cls.__name__.lower()),
|
||||
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
|
||||
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]},
|
||||
)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField
|
||||
from mongoengine import Document, StringField, DateTimeField, BooleanField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeDictField
|
||||
from database.fields import StrippedStringField, SafeDictField, SafeSortedListField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import GetMixin
|
||||
from database.model.model_labels import ModelLabels
|
||||
from database.model.company import Company
|
||||
from database.model.project import Project
|
||||
@@ -12,46 +13,61 @@ from database.model.user import User
|
||||
|
||||
class Model(DbModelMixin, Document):
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
'indexes': [
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
"parent",
|
||||
"project",
|
||||
"task",
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
{
|
||||
'name': '%s.model.main_text_index' % Database.backend,
|
||||
'fields': [
|
||||
'$name',
|
||||
'$id',
|
||||
'$comment',
|
||||
'$parent',
|
||||
'$task',
|
||||
'$project',
|
||||
],
|
||||
'default_language': 'english',
|
||||
'weights': {
|
||||
'name': 10,
|
||||
'id': 10,
|
||||
'comment': 10,
|
||||
'parent': 5,
|
||||
'task': 3,
|
||||
'project': 3,
|
||||
}
|
||||
}
|
||||
"name": "%s.model.main_text_index" % Database.backend,
|
||||
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
|
||||
"default_language": "english",
|
||||
"weights": {
|
||||
"name": 10,
|
||||
"id": 10,
|
||||
"comment": 10,
|
||||
"parent": 5,
|
||||
"task": 3,
|
||||
"project": 3,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "comment"),
|
||||
fields=("ready",),
|
||||
list_fields=(
|
||||
"tags",
|
||||
"system_tags",
|
||||
"framework",
|
||||
"uri",
|
||||
"id",
|
||||
"user",
|
||||
"project",
|
||||
"task",
|
||||
"parent",
|
||||
),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(user_set_allowed=True, min_length=3)
|
||||
parent = StringField(reference_field='Model', required=False)
|
||||
parent = StringField(reference_field="Model", required=False)
|
||||
user = StringField(required=True, reference_field=User)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
task = StringField(reference_field=Task)
|
||||
comment = StringField(user_set_allowed=True)
|
||||
tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
uri = StrippedStringField(default='', user_set_allowed=True)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
uri = StrippedStringField(default="", user_set_allowed=True)
|
||||
framework = StringField()
|
||||
design = SafeDictField()
|
||||
labels = ModelLabels()
|
||||
ready = BooleanField(required=True)
|
||||
ui_cache = SafeDictField(default=dict, user_set_allowed=True, exclude_by_default=True)
|
||||
ui_cache = SafeDictField(
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from mongoengine import MapField, IntField
|
||||
from database.fields import NoneType, UnionField, SafeMapField
|
||||
|
||||
|
||||
class ModelLabels(MapField):
|
||||
class ModelLabels(SafeMapField):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs)
|
||||
super(ModelLabels, self).__init__(
|
||||
field=UnionField(types=(int, NoneType)), *args, **kwargs
|
||||
)
|
||||
|
||||
def validate(self, value):
|
||||
super(ModelLabels, self).validate(value)
|
||||
if value and (len(set(value.values())) < len(value)):
|
||||
non_empty_values = list(filter(None, value.values()))
|
||||
if non_empty_values and len(set(non_empty_values)) < len(non_empty_values):
|
||||
self.error("Same label id appears more than once in model labels")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from mongoengine import StringField, DateTimeField, ListField
|
||||
from mongoengine import StringField, DateTimeField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
from database.model import AttributedDocument
|
||||
from database.model.base import GetMixin
|
||||
|
||||
@@ -17,12 +17,13 @@ class Project(AttributedDocument):
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
("company", "name"),
|
||||
{
|
||||
"name": "%s.project.main_text_index" % Database.backend,
|
||||
"fields": ["$name", "$id", "$description"],
|
||||
"default_language": "english",
|
||||
"weights": {"name": 10, "id": 10, "description": 10},
|
||||
}
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
@@ -35,7 +36,7 @@ class Project(AttributedDocument):
|
||||
)
|
||||
description = StringField(required=True)
|
||||
created = DateTimeField(required=True)
|
||||
tags = ListField(StringField(required=True))
|
||||
system_tags = ListField(StringField(required=True))
|
||||
tags = SafeSortedListField(StringField(required=True))
|
||||
system_tags = SafeSortedListField(StringField(required=True))
|
||||
default_output_destination = StrippedStringField()
|
||||
last_update = DateTimeField()
|
||||
|
||||
@@ -4,11 +4,10 @@ from mongoengine import (
|
||||
StringField,
|
||||
DateTimeField,
|
||||
EmbeddedDocumentListField,
|
||||
ListField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import ProperDictMixin, GetMixin
|
||||
from database.model.company import Company
|
||||
@@ -41,7 +40,7 @@ class Queue(DbModelMixin, Document):
|
||||
)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
created = DateTimeField(required=True)
|
||||
tags = ListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
system_tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||
last_update = DateTimeField()
|
||||
|
||||
@@ -7,6 +7,10 @@ from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
|
||||
|
||||
class SettingKeys:
|
||||
server__uuid = "server.uuid"
|
||||
|
||||
|
||||
class Settings(DbModelMixin, Document):
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
@@ -47,7 +51,7 @@ class Settings(DbModelMixin, Document):
|
||||
""" Adds a new key/value settings. Fails if key already exists. """
|
||||
key = key.strip(sep)
|
||||
try:
|
||||
res = Settings(key=key, value=value).save(force_insert=True)
|
||||
res = cls(key=key, value=value).save(force_insert=True)
|
||||
return bool(res)
|
||||
except NotUniqueError:
|
||||
return False
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
from mongoengine import EmbeddedDocument, StringField, DynamicField
|
||||
from mongoengine import (
|
||||
EmbeddedDocument,
|
||||
StringField,
|
||||
DynamicField,
|
||||
LongField,
|
||||
EmbeddedDocumentField,
|
||||
)
|
||||
|
||||
from database.fields import SafeMapField
|
||||
|
||||
|
||||
class MetricEvent(EmbeddedDocument):
|
||||
meta = {
|
||||
# For backwards compatibility reasons
|
||||
'strict': False,
|
||||
"strict": False,
|
||||
}
|
||||
|
||||
metric = StringField(required=True)
|
||||
@@ -12,3 +20,20 @@ class MetricEvent(EmbeddedDocument):
|
||||
value = DynamicField(required=True)
|
||||
min_value = DynamicField() # for backwards compatibility reasons
|
||||
max_value = DynamicField() # for backwards compatibility reasons
|
||||
|
||||
|
||||
class EventStats(EmbeddedDocument):
|
||||
meta = {
|
||||
# For backwards compatibility reasons
|
||||
"strict": False,
|
||||
}
|
||||
last_update = LongField()
|
||||
|
||||
|
||||
class MetricEventStats(EmbeddedDocument):
|
||||
meta = {
|
||||
# For backwards compatibility reasons
|
||||
"strict": False,
|
||||
}
|
||||
metric = StringField(required=True)
|
||||
event_stats_by_type = SafeMapField(field=EmbeddedDocumentField(EventStats))
|
||||
|
||||
@@ -18,11 +18,11 @@ from database.fields import (
|
||||
SafeSortedListField,
|
||||
)
|
||||
from database.model import AttributedDocument
|
||||
from database.model.base import ProperDictMixin
|
||||
from database.model.base import ProperDictMixin, GetMixin
|
||||
from database.model.model_labels import ModelLabels
|
||||
from database.model.project import Project
|
||||
from database.utils import get_options
|
||||
from .metrics import MetricEvent
|
||||
from .metrics import MetricEvent, MetricEventStats
|
||||
from .output import Output
|
||||
|
||||
DEFAULT_LAST_ITERATION = 0
|
||||
@@ -100,9 +100,26 @@ class Execution(EmbeddedDocument, ProperDictMixin):
|
||||
class TaskType(object):
|
||||
training = "training"
|
||||
testing = "testing"
|
||||
inference = "inference"
|
||||
data_processing = "data_processing"
|
||||
application = "application"
|
||||
monitor = "monitor"
|
||||
controller = "controller"
|
||||
optimizer = "optimizer"
|
||||
service = "service"
|
||||
qc = "qc"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
external_task_types = set(get_options(TaskType))
|
||||
|
||||
|
||||
class Task(AttributedDocument):
|
||||
_field_collation_overrides = {
|
||||
"execution.parameters.": {"locale": "en_US", "numericOrdering": True},
|
||||
"last_metrics.": {"locale": "en_US", "numericOrdering": True}
|
||||
}
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
@@ -110,6 +127,13 @@ class Task(AttributedDocument):
|
||||
"created",
|
||||
"started",
|
||||
"completed",
|
||||
"parent",
|
||||
"project",
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
("company", "type", "system_tags", "status"),
|
||||
("company", "project", "type", "system_tags", "status"),
|
||||
("status", "last_update"), # for maintenance tasks
|
||||
{
|
||||
"name": "%s.task.main_text_index" % Database.backend,
|
||||
"fields": [
|
||||
@@ -134,6 +158,12 @@ class Task(AttributedDocument):
|
||||
},
|
||||
],
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
|
||||
datetime_fields=("status_changed",),
|
||||
pattern_fields=("name", "comment"),
|
||||
fields=("parent",),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(
|
||||
@@ -152,13 +182,14 @@ class Task(AttributedDocument):
|
||||
published = DateTimeField()
|
||||
parent = StringField()
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
output = EmbeddedDocumentField(Output, default=Output)
|
||||
output: Output = EmbeddedDocumentField(Output, default=Output)
|
||||
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
|
||||
tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
script = EmbeddedDocumentField(Script)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
script: Script = EmbeddedDocumentField(Script)
|
||||
last_worker = StringField()
|
||||
last_worker_report = DateTimeField()
|
||||
last_update = DateTimeField()
|
||||
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
|
||||
|
||||
@@ -2,14 +2,16 @@ from mongoengine import Document, StringField, DynamicField
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import GetMixin
|
||||
from database.model.company import Company
|
||||
|
||||
|
||||
class User(DbModelMixin, Document):
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(list_fields=("id",))
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
import copy
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
from mongoengine import Q
|
||||
from mongoengine.queryset.visitor import QueryCompilerVisitor, SimplificationVisitor, QCombination
|
||||
from mongoengine.queryset.visitor import (
|
||||
QueryCompilerVisitor,
|
||||
SimplificationVisitor,
|
||||
QCombination,
|
||||
QNode,
|
||||
)
|
||||
|
||||
|
||||
class RegexWrapper(object):
|
||||
@@ -17,17 +23,16 @@ class RegexWrapper(object):
|
||||
|
||||
|
||||
class RegexMixin(object):
|
||||
|
||||
def to_query(self, document):
|
||||
def to_query(self: Union["RegexMixin", QNode], document):
|
||||
query = self.accept(SimplificationVisitor())
|
||||
query = query.accept(RegexQueryCompilerVisitor(document))
|
||||
return query
|
||||
|
||||
def _combine(self, other, operation):
|
||||
def _combine(self: Union["RegexMixin", QNode], other, operation):
|
||||
"""Combine this node with another node into a QCombination
|
||||
object.
|
||||
"""
|
||||
if getattr(other, 'empty', True):
|
||||
if getattr(other, "empty", True):
|
||||
return self
|
||||
|
||||
if self.empty:
|
||||
|
||||
@@ -95,21 +95,18 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
|
||||
res[field] = None
|
||||
continue
|
||||
if desc:
|
||||
if callable(desc):
|
||||
desc(value)
|
||||
else:
|
||||
if issubclass(desc, (list, tuple, dict)) and not isinstance(
|
||||
value, desc
|
||||
):
|
||||
raise ParseCallError(
|
||||
"expecting %s" % desc.__name__, field=field
|
||||
)
|
||||
if issubclass(desc, Document) and not desc.objects(id=value).only(
|
||||
"id"
|
||||
):
|
||||
if issubclass(desc, Document):
|
||||
if not desc.objects(id=value).only("id"):
|
||||
raise ParseCallError(
|
||||
"expecting %s id" % desc.__name__, id=value, field=field
|
||||
)
|
||||
elif callable(desc):
|
||||
try:
|
||||
desc(value)
|
||||
except TypeError:
|
||||
raise ParseCallError(f"expecting {desc.__name__}", field=field)
|
||||
except Exception as ex:
|
||||
raise ParseCallError(str(ex), field=field)
|
||||
res[field] = value
|
||||
return res
|
||||
|
||||
|
||||
27
server/elastic/initialize.py
Normal file
27
server/elastic/initialize.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from furl import furl
|
||||
|
||||
from config import config
|
||||
from elastic.apply_mappings import apply_mappings_to_host
|
||||
from es_factory import get_cluster_config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class MissingElasticConfiguration(Exception):
|
||||
"""
|
||||
Exception when cluster configuration is not found in config files
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def init_es_data():
|
||||
hosts_config = get_cluster_config("events").get("hosts")
|
||||
if not hosts_config:
|
||||
raise MissingElasticConfiguration("for cluster 'events'")
|
||||
|
||||
for conf in hosts_config:
|
||||
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
|
||||
log.info(f"Applying mappings to host: {host}")
|
||||
res = apply_mappings_to_host(host)
|
||||
log.info(res)
|
||||
@@ -1,222 +0,0 @@
|
||||
import importlib.util
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import attr
|
||||
from furl import furl
|
||||
from mongoengine.connection import get_db
|
||||
from semantic_version import Version
|
||||
|
||||
import database.utils
|
||||
from bll.queue import QueueBLL
|
||||
from config import config
|
||||
from config.info import get_default_company
|
||||
from database import Database
|
||||
from database.model.auth import Role
|
||||
from database.model.auth import User as AuthUser, Credentials
|
||||
from database.model.company import Company
|
||||
from database.model.queue import Queue
|
||||
from database.model.settings import Settings
|
||||
from database.model.user import User
|
||||
from database.model.version import Version as DatabaseVersion
|
||||
from elastic.apply_mappings import apply_mappings_to_host
|
||||
from es_factory import get_cluster_config
|
||||
from service_repo.auth.fixed_user import FixedUser
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
migration_dir = Path(__file__).resolve().parent / "mongo" / "migrations"
|
||||
|
||||
|
||||
class MissingElasticConfiguration(Exception):
|
||||
"""
|
||||
Exception when cluster configuration is not found in config files
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def init_es_data():
|
||||
hosts_config = get_cluster_config("events").get("hosts")
|
||||
if not hosts_config:
|
||||
raise MissingElasticConfiguration("for cluster 'events'")
|
||||
|
||||
for conf in hosts_config:
|
||||
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
|
||||
log.info(f"Applying mappings to host: {host}")
|
||||
res = apply_mappings_to_host(host)
|
||||
log.info(res)
|
||||
|
||||
|
||||
def _ensure_company():
|
||||
company_id = get_default_company()
|
||||
company = Company.objects(id=company_id).only("id").first()
|
||||
if company:
|
||||
return company_id
|
||||
|
||||
company_name = "trains"
|
||||
log.info(f"Creating company: {company_name}")
|
||||
company = Company(id=company_id, name=company_name)
|
||||
company.save()
|
||||
return company_id
|
||||
|
||||
|
||||
def _ensure_default_queue(company):
|
||||
"""
|
||||
If no queue is present for the company then
|
||||
create a new one and mark it as a default
|
||||
"""
|
||||
queue = Queue.objects(company=company).only("id").first()
|
||||
if queue:
|
||||
return
|
||||
|
||||
QueueBLL.create(company, name="default", system_tags=["default"])
|
||||
|
||||
|
||||
def _ensure_auth_user(user_data, company_id):
|
||||
ensure_credentials = {"key", "secret"}.issubset(user_data.keys())
|
||||
if ensure_credentials:
|
||||
user = AuthUser.objects(
|
||||
credentials__match=Credentials(
|
||||
key=user_data["key"], secret=user_data["secret"]
|
||||
)
|
||||
).first()
|
||||
if user:
|
||||
return user.id
|
||||
|
||||
log.info(f"Creating user: {user_data['name']}")
|
||||
user = AuthUser(
|
||||
id=user_data.get("id", f"__{user_data['name']}__"),
|
||||
name=user_data["name"],
|
||||
company=company_id,
|
||||
role=user_data["role"],
|
||||
email=user_data["email"],
|
||||
created=datetime.utcnow(),
|
||||
credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])]
|
||||
if ensure_credentials
|
||||
else None,
|
||||
)
|
||||
|
||||
user.save()
|
||||
|
||||
return user.id
|
||||
|
||||
|
||||
def _ensure_user(user: FixedUser, company_id: str):
|
||||
if User.objects(id=user.user_id).first():
|
||||
return
|
||||
|
||||
data = attr.asdict(user)
|
||||
data["id"] = user.user_id
|
||||
data["email"] = f"{user.user_id}@example.com"
|
||||
data["role"] = Role.user
|
||||
|
||||
_ensure_auth_user(user_data=data, company_id=company_id)
|
||||
|
||||
given_name, _, family_name = user.name.partition(" ")
|
||||
|
||||
User(
|
||||
id=user.user_id,
|
||||
company=company_id,
|
||||
name=user.name,
|
||||
given_name=given_name,
|
||||
family_name=family_name,
|
||||
).save()
|
||||
|
||||
|
||||
def _apply_migrations():
|
||||
if not migration_dir.is_dir():
|
||||
raise ValueError(f"Invalid migration dir {migration_dir}")
|
||||
|
||||
try:
|
||||
previous_versions = sorted(
|
||||
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
|
||||
reverse=True,
|
||||
)
|
||||
except ValueError as ex:
|
||||
raise ValueError(f"Invalid database version number encountered: {ex}")
|
||||
|
||||
last_version = previous_versions[0] if previous_versions else Version("0.0.0")
|
||||
|
||||
try:
|
||||
new_scripts = {
|
||||
ver: path
|
||||
for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py"))
|
||||
if ver > last_version
|
||||
}
|
||||
except ValueError as ex:
|
||||
raise ValueError(f"Failed parsing migration version from file: {ex}")
|
||||
|
||||
dbs = {Database.auth: "migrate_auth", Database.backend: "migrate_backend"}
|
||||
|
||||
migration_log = log.getChild("mongodb_migration")
|
||||
|
||||
for script_version in sorted(new_scripts.keys()):
|
||||
script = new_scripts[script_version]
|
||||
spec = importlib.util.spec_from_file_location(script.stem, str(script))
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
for alias, func_name in dbs.items():
|
||||
func = getattr(module, func_name, None)
|
||||
if not func:
|
||||
continue
|
||||
try:
|
||||
migration_log.info(f"Applying {script.stem}/{func_name}()")
|
||||
func(get_db(alias))
|
||||
except Exception:
|
||||
migration_log.exception(f"Failed applying {script}:{func_name}()")
|
||||
raise ValueError("Migration failed, aborting. Please restore backup.")
|
||||
|
||||
DatabaseVersion(
|
||||
id=database.utils.id(),
|
||||
num=script.stem,
|
||||
created=datetime.utcnow(),
|
||||
desc="Applied on server startup",
|
||||
).save()
|
||||
|
||||
|
||||
def _ensure_uuid():
|
||||
Settings.add_value("server.uuid", str(uuid4()))
|
||||
|
||||
|
||||
def init_mongo_data():
|
||||
try:
|
||||
_apply_migrations()
|
||||
|
||||
_ensure_uuid()
|
||||
|
||||
company_id = _ensure_company()
|
||||
_ensure_default_queue(company_id)
|
||||
|
||||
users = [
|
||||
{
|
||||
"name": "apiserver",
|
||||
"role": Role.system,
|
||||
"email": "apiserver@example.com",
|
||||
},
|
||||
{
|
||||
"name": "webserver",
|
||||
"role": Role.system,
|
||||
"email": "webserver@example.com",
|
||||
},
|
||||
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
|
||||
]
|
||||
|
||||
for user in users:
|
||||
credentials = config.get(f"secure.credentials.{user['name']}")
|
||||
user["key"] = credentials.user_key
|
||||
user["secret"] = credentials.user_secret
|
||||
_ensure_auth_user(user, company_id)
|
||||
|
||||
if FixedUser.enabled():
|
||||
log.info("Fixed users mode is enabled")
|
||||
FixedUser.validate()
|
||||
for user in FixedUser.from_config():
|
||||
try:
|
||||
_ensure_user(user, company_id)
|
||||
except Exception as ex:
|
||||
log.error(f"Failed creating fixed user {user.name}: {ex}")
|
||||
except Exception as ex:
|
||||
log.exception("Failed initializing mongodb")
|
||||
65
server/mongo/initialize/__init__.py
Normal file
65
server/mongo/initialize/__init__.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from pathlib import Path
|
||||
|
||||
from config import config
|
||||
from database.model.auth import Role
|
||||
from service_repo.auth.fixed_user import FixedUser
|
||||
from .migration import _apply_migrations
|
||||
from .pre_populate import PrePopulate
|
||||
from .user import ensure_fixed_user, _ensure_auth_user, _ensure_backend_user
|
||||
from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
|
||||
|
||||
log = config.logger(__package__)
|
||||
|
||||
|
||||
def init_mongo_data():
|
||||
try:
|
||||
empty_dbs = _apply_migrations(log)
|
||||
|
||||
_ensure_uuid()
|
||||
|
||||
company_id = _ensure_company(log)
|
||||
|
||||
_ensure_default_queue(company_id)
|
||||
|
||||
if empty_dbs and config.get("apiserver.mongo.pre_populate.enabled", False):
|
||||
zip_file = config.get("apiserver.mongo.pre_populate.zip_file")
|
||||
if not zip_file or not Path(zip_file).is_file():
|
||||
msg = f"Failed pre-populating database: invalid zip file {zip_file}"
|
||||
if config.get("apiserver.mongo.pre_populate.fail_on_error", False):
|
||||
log.error(msg)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
log.warning(msg)
|
||||
else:
|
||||
|
||||
user_id = _ensure_backend_user(
|
||||
"__allegroai__", company_id, "Allegro.ai"
|
||||
)
|
||||
|
||||
PrePopulate.import_from_zip(zip_file, user_id=user_id)
|
||||
|
||||
fixed_mode = FixedUser.enabled()
|
||||
|
||||
for user, credentials in config.get("secure.credentials", {}).items():
|
||||
user_data = {
|
||||
"name": user,
|
||||
"role": credentials.role,
|
||||
"email": f"{user}@example.com",
|
||||
"key": credentials.user_key,
|
||||
"secret": credentials.user_secret,
|
||||
}
|
||||
revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False)
|
||||
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke)
|
||||
if credentials.role == Role.user:
|
||||
_ensure_backend_user(user_id, company_id, credentials.display_name)
|
||||
|
||||
if fixed_mode:
|
||||
log.info("Fixed users mode is enabled")
|
||||
FixedUser.validate()
|
||||
for user in FixedUser.from_config():
|
||||
try:
|
||||
ensure_fixed_user(user, company_id, log=log)
|
||||
except Exception as ex:
|
||||
log.error(f"Failed creating fixed user {user.name}: {ex}")
|
||||
except Exception as ex:
|
||||
log.exception("Failed initializing mongodb")
|
||||
86
server/mongo/initialize/migration.py
Normal file
86
server/mongo/initialize/migration.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import importlib.util
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
from mongoengine.connection import get_db
|
||||
from semantic_version import Version
|
||||
|
||||
import database.utils
|
||||
from database import Database
|
||||
from database.model.version import Version as DatabaseVersion
|
||||
|
||||
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
|
||||
|
||||
|
||||
def _apply_migrations(log: Logger) -> bool:
|
||||
"""
|
||||
Apply migrations as found in the migration dir.
|
||||
Returns a boolean indicating whether the database was empty prior to migration.
|
||||
"""
|
||||
log = log.getChild(Path(__file__).stem)
|
||||
|
||||
log.info(f"Started mongodb migrations")
|
||||
|
||||
if not migration_dir.is_dir():
|
||||
raise ValueError(f"Invalid migration dir {migration_dir}")
|
||||
|
||||
empty_dbs = not any(
|
||||
get_db(alias).collection_names()
|
||||
for alias in database.utils.get_options(Database)
|
||||
)
|
||||
|
||||
try:
|
||||
previous_versions = sorted(
|
||||
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
|
||||
reverse=True,
|
||||
)
|
||||
except ValueError as ex:
|
||||
raise ValueError(f"Invalid database version number encountered: {ex}")
|
||||
|
||||
last_version = previous_versions[0] if previous_versions else Version("0.0.0")
|
||||
|
||||
try:
|
||||
new_scripts = {
|
||||
ver: path
|
||||
for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py"))
|
||||
if ver > last_version
|
||||
}
|
||||
except ValueError as ex:
|
||||
raise ValueError(f"Failed parsing migration version from file: {ex}")
|
||||
|
||||
dbs = {Database.auth: "migrate_auth", Database.backend: "migrate_backend"}
|
||||
|
||||
for script_version in sorted(new_scripts):
|
||||
script = new_scripts[script_version]
|
||||
|
||||
if empty_dbs:
|
||||
log.info(f"Skipping migration {script.name} (empty databases)")
|
||||
else:
|
||||
spec = importlib.util.spec_from_file_location(script.stem, str(script))
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
for alias, func_name in dbs.items():
|
||||
func = getattr(module, func_name, None)
|
||||
if not func:
|
||||
continue
|
||||
try:
|
||||
log.info(f"Applying {script.stem}/{func_name}()")
|
||||
func(get_db(alias))
|
||||
except Exception:
|
||||
log.exception(f"Failed applying {script}:{func_name}()")
|
||||
raise ValueError(
|
||||
"Migration failed, aborting. Please restore backup."
|
||||
)
|
||||
|
||||
DatabaseVersion(
|
||||
id=database.utils.id(),
|
||||
num=script.stem,
|
||||
created=datetime.utcnow(),
|
||||
desc="Applied on server startup",
|
||||
).save()
|
||||
|
||||
log.info("Finished mongodb migrations")
|
||||
|
||||
return empty_dbs
|
||||
153
server/mongo/initialize/pre_populate.py
Normal file
153
server/mongo/initialize/pre_populate.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import importlib
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from os.path import splitext
|
||||
from typing import List, Optional, Any, Type, Set, Dict
|
||||
from zipfile import ZipFile, ZIP_BZIP2
|
||||
|
||||
import mongoengine
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class PrePopulate:
|
||||
@classmethod
|
||||
def export_to_zip(
|
||||
cls, filename: str, experiments: List[str] = None, projects: List[str] = None
|
||||
):
|
||||
with ZipFile(filename, mode="w", compression=ZIP_BZIP2) as zfile:
|
||||
cls._export(zfile, experiments, projects)
|
||||
|
||||
@classmethod
|
||||
def import_from_zip(cls, filename: str, user_id: str = None):
|
||||
with ZipFile(filename) as zfile:
|
||||
cls._import(zfile, user_id)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_type(
|
||||
cls: Type[mongoengine.Document], ids: Optional[List[str]]
|
||||
) -> List[Any]:
|
||||
ids = set(ids)
|
||||
items = list(cls.objects(id__in=list(ids)))
|
||||
resolved = {i.id for i in items}
|
||||
missing = ids - resolved
|
||||
for name_candidate in missing:
|
||||
results = list(cls.objects(name=name_candidate))
|
||||
if not results:
|
||||
print(f"ERROR: no match for `{name_candidate}`")
|
||||
exit(1)
|
||||
elif len(results) > 1:
|
||||
print(f"ERROR: more than one match for `{name_candidate}`")
|
||||
exit(1)
|
||||
items.append(results[0])
|
||||
return items
|
||||
|
||||
@classmethod
|
||||
def _resolve_entities(
|
||||
cls, experiments: List[str] = None, projects: List[str] = None
|
||||
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
|
||||
from database.model.project import Project
|
||||
from database.model.task.task import Task
|
||||
|
||||
entities = defaultdict(set)
|
||||
|
||||
if projects:
|
||||
print("Reading projects...")
|
||||
entities[Project].update(cls._resolve_type(Project, projects))
|
||||
print("--> Reading project experiments...")
|
||||
objs = Task.objects(
|
||||
project__in=list(set(filter(None, (p.id for p in entities[Project]))))
|
||||
)
|
||||
entities[Task].update(o for o in objs if o.id not in (experiments or []))
|
||||
|
||||
if experiments:
|
||||
print("Reading experiments...")
|
||||
entities[Task].update(cls._resolve_type(Task, experiments))
|
||||
print("--> Reading experiments projects...")
|
||||
objs = Project.objects(
|
||||
id__in=list(set(filter(None, (p.project for p in entities[Task]))))
|
||||
)
|
||||
project_ids = {p.id for p in entities[Project]}
|
||||
entities[Project].update(o for o in objs if o.id not in project_ids)
|
||||
|
||||
return entities
|
||||
|
||||
@classmethod
|
||||
def _cleanup_task(cls, task):
|
||||
from database.model.task.task import TaskStatus
|
||||
|
||||
task.completed = None
|
||||
task.started = None
|
||||
if task.execution:
|
||||
task.execution.model = None
|
||||
task.execution.model_desc = None
|
||||
task.execution.model_labels = None
|
||||
if task.output:
|
||||
task.output.model = None
|
||||
|
||||
task.status = TaskStatus.created
|
||||
task.comment = "Auto generated by Allegro.ai"
|
||||
task.created = datetime.utcnow()
|
||||
task.last_iteration = 0
|
||||
task.last_update = task.created
|
||||
task.status_changed = task.created
|
||||
task.status_message = ""
|
||||
task.status_reason = ""
|
||||
task.user = ""
|
||||
|
||||
@classmethod
|
||||
def _cleanup_entity(cls, entity_cls, entity):
|
||||
from database.model.task.task import Task
|
||||
if entity_cls == Task:
|
||||
cls._cleanup_task(entity)
|
||||
|
||||
@classmethod
|
||||
def _export(
|
||||
cls, writer: ZipFile, experiments: List[str] = None, projects: List[str] = None
|
||||
):
|
||||
entities = cls._resolve_entities(experiments, projects)
|
||||
|
||||
for cls_, items in entities.items():
|
||||
if not items:
|
||||
continue
|
||||
filename = f"{cls_.__module__}.{cls_.__name__}.json"
|
||||
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
|
||||
with writer.open(filename, "w") as f:
|
||||
f.write("[\n".encode("utf-8"))
|
||||
last = len(items) - 1
|
||||
for i, item in enumerate(items):
|
||||
cls._cleanup_entity(cls_, item)
|
||||
f.write(item.to_json().encode("utf-8"))
|
||||
if i != last:
|
||||
f.write(",".encode("utf-8"))
|
||||
f.write("\n".encode("utf-8"))
|
||||
f.write("]\n".encode("utf-8"))
|
||||
|
||||
@staticmethod
|
||||
def _import(reader: ZipFile, user_id: str = None):
|
||||
for file_info in reader.filelist:
|
||||
full_name = splitext(file_info.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
module_name, _, class_name = full_name.rpartition(".")
|
||||
module = importlib.import_module(module_name)
|
||||
cls_: Type[mongoengine.Document] = getattr(module, class_name)
|
||||
|
||||
with reader.open(file_info) as f:
|
||||
for item in tqdm(
|
||||
f.readlines(),
|
||||
desc=f"Writing {cls_.__name__.lower()}s into database",
|
||||
unit="doc",
|
||||
):
|
||||
item = (
|
||||
item.decode("utf-8")
|
||||
.strip()
|
||||
.lstrip("[")
|
||||
.rstrip("]")
|
||||
.rstrip(",")
|
||||
.strip()
|
||||
)
|
||||
if not item:
|
||||
continue
|
||||
doc = cls_.from_json(item)
|
||||
if user_id is not None and hasattr(doc, "user"):
|
||||
doc.user = user_id
|
||||
doc.save(force_insert=True)
|
||||
79
server/mongo/initialize/user.py
Normal file
79
server/mongo/initialize/user.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
|
||||
import attr
|
||||
|
||||
from database.model.auth import Role
|
||||
from database.model.auth import User as AuthUser, Credentials
|
||||
from database.model.user import User
|
||||
from service_repo.auth.fixed_user import FixedUser
|
||||
|
||||
|
||||
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: bool = False):
|
||||
ensure_credentials = {"key", "secret"}.issubset(user_data)
|
||||
if ensure_credentials:
|
||||
user = AuthUser.objects(
|
||||
credentials__match=Credentials(
|
||||
key=user_data["key"], secret=user_data["secret"]
|
||||
)
|
||||
).first()
|
||||
if user:
|
||||
if revoke:
|
||||
user.credentials = []
|
||||
user.save()
|
||||
return user.id
|
||||
|
||||
user_id = user_data.get("id", f"__{user_data['name']}__")
|
||||
|
||||
log.info(f"Creating user: {user_data['name']}")
|
||||
user = AuthUser(
|
||||
id=user_id,
|
||||
name=user_data["name"],
|
||||
company=company_id,
|
||||
role=user_data["role"],
|
||||
email=user_data["email"],
|
||||
created=datetime.utcnow(),
|
||||
credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])] if not revoke else []
|
||||
if ensure_credentials
|
||||
else None,
|
||||
)
|
||||
|
||||
user.save()
|
||||
|
||||
return user.id
|
||||
|
||||
|
||||
def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
|
||||
given_name, _, family_name = user_name.partition(" ")
|
||||
|
||||
User(
|
||||
id=user_id,
|
||||
company=company_id,
|
||||
name=user_name,
|
||||
given_name=given_name,
|
||||
family_name=family_name,
|
||||
).save()
|
||||
|
||||
return user_id
|
||||
|
||||
|
||||
def ensure_fixed_user(user: FixedUser, company_id: str, log: Logger):
|
||||
if User.objects(id=user.user_id).first():
|
||||
return
|
||||
|
||||
data = attr.asdict(user)
|
||||
data["id"] = user.user_id
|
||||
data["email"] = f"{user.user_id}@example.com"
|
||||
data["role"] = Role.user
|
||||
|
||||
_ensure_auth_user(user_data=data, company_id=company_id, log=log)
|
||||
|
||||
given_name, _, family_name = user.name.partition(" ")
|
||||
|
||||
User(
|
||||
id=user.user_id,
|
||||
company=company_id,
|
||||
name=user.name,
|
||||
given_name=given_name,
|
||||
family_name=family_name,
|
||||
).save()
|
||||
40
server/mongo/initialize/util.py
Normal file
40
server/mongo/initialize/util.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from logging import Logger
|
||||
from uuid import uuid4
|
||||
|
||||
from bll.queue import QueueBLL
|
||||
from config import config
|
||||
from config.info import get_default_company
|
||||
from database.model.company import Company
|
||||
from database.model.queue import Queue
|
||||
from database.model.settings import Settings, SettingKeys
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
def _ensure_company(log: Logger):
|
||||
company_id = get_default_company()
|
||||
company = Company.objects(id=company_id).only("id").first()
|
||||
if company:
|
||||
return company_id
|
||||
|
||||
company_name = "trains"
|
||||
log.info(f"Creating company: {company_name}")
|
||||
company = Company(id=company_id, name=company_name)
|
||||
company.save()
|
||||
return company_id
|
||||
|
||||
|
||||
def _ensure_default_queue(company):
|
||||
"""
|
||||
If no queue is present for the company then
|
||||
create a new one and mark it as a default
|
||||
"""
|
||||
queue = Queue.objects(company=company).only("id").first()
|
||||
if queue:
|
||||
return
|
||||
|
||||
QueueBLL.create(company, name="default", system_tags=["default"])
|
||||
|
||||
|
||||
def _ensure_uuid():
|
||||
Settings.add_value(SettingKeys.server__uuid, str(uuid4()))
|
||||
46
server/mongo/migrations/0.14.0.py
Normal file
46
server/mongo/migrations/0.14.0.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import hashlib
|
||||
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
from service_repo.auth.fixed_user import FixedUser
|
||||
|
||||
|
||||
def _get_ids():
|
||||
if not FixedUser.enabled():
|
||||
return
|
||||
|
||||
return {
|
||||
hashlib.md5(f"{user.username}:{user.password}".encode()).hexdigest(): user.user_id
|
||||
for user in FixedUser.from_config()
|
||||
}
|
||||
|
||||
|
||||
def _switch_uuid(collection: Collection, uuid_field: str, uuids: dict):
|
||||
docs = list(collection.find({uuid_field: {"$in": [uuids]}}))
|
||||
if not docs:
|
||||
return
|
||||
replaced_uuids = [doc[uuid_field] for doc in docs]
|
||||
for doc in docs:
|
||||
doc[uuid_field] = uuids[doc[uuid_field]]
|
||||
collection.insert_many(docs)
|
||||
collection.delete_many({uuid_field: {"$in": replaced_uuids}})
|
||||
|
||||
|
||||
def migrate_auth(db: Database):
|
||||
uuids = _get_ids()
|
||||
if not uuids:
|
||||
return
|
||||
|
||||
collection = db["user"]
|
||||
collection.drop_index("name_1_company_1")
|
||||
|
||||
_switch_uuid(collection=collection, uuid_field="_id", uuids=uuids)
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
uuids = _get_ids()
|
||||
if not uuids:
|
||||
return
|
||||
|
||||
for name in ("project", "task", "model"):
|
||||
_switch_uuid(collection=db[name], uuid_field="user", uuids=uuids)
|
||||
58
server/mongo/migrations/0.15.0.py
Normal file
58
server/mongo/migrations/0.15.0.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from collections import Collection
|
||||
from typing import Sequence
|
||||
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
|
||||
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
|
||||
for collection_name in db.list_collection_names():
|
||||
if collection_name not in names:
|
||||
continue
|
||||
collection: Collection = db[collection_name]
|
||||
collection.drop_indexes()
|
||||
|
||||
|
||||
def migrate_auth(db: Database):
|
||||
"""
|
||||
Remove the old indices from the collections since
|
||||
they may come out of sync with the latest changes
|
||||
in the code and mongo libraries update
|
||||
"""
|
||||
_drop_all_indices_from_collections(db, ["user"])
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
"""
|
||||
1. Sort tags and system tags
|
||||
2. Remove the old indices from the collections since
|
||||
they may come out of sync with the latest changes
|
||||
in the code and mongo libraries update
|
||||
"""
|
||||
|
||||
fields = ("tags", "system_tags")
|
||||
query = {"$or": [{field: {"$exists": True, "$ne": []}} for field in fields]}
|
||||
for collection_name in ("task", "model", "project", "queue"):
|
||||
collection = db[collection_name]
|
||||
for doc in collection.find(filter=query, projection=fields):
|
||||
update = {
|
||||
field: sorted(doc[field])
|
||||
for field in fields
|
||||
if doc.get(field)
|
||||
}
|
||||
if update:
|
||||
collection.update_one({"_id": doc["_id"]}, {"$set": update})
|
||||
|
||||
_drop_all_indices_from_collections(
|
||||
db,
|
||||
[
|
||||
"company",
|
||||
"model",
|
||||
"project",
|
||||
"queue",
|
||||
"settings",
|
||||
"task",
|
||||
"task__trash",
|
||||
"user",
|
||||
"versions",
|
||||
],
|
||||
)
|
||||
@@ -1,29 +1,30 @@
|
||||
six
|
||||
Flask>=0.12.2
|
||||
elasticsearch>=5.0.0,<6.0.0
|
||||
pyhocon>=0.3.35
|
||||
requests>=2.13.0
|
||||
pymongo==3.6.1 # 3.7 has a bug multiple users logged in
|
||||
Flask-Cors>=3.0.5
|
||||
Flask-Compress>=1.4.0
|
||||
mongoengine==0.16.2
|
||||
jsonmodels>=2.3
|
||||
pyjwt>=1.3.0
|
||||
gunicorn>=19.7.1
|
||||
Jinja2==2.10
|
||||
python-rapidjson>=0.6.3
|
||||
jsonschema>=2.6.0
|
||||
dpath>=1.4.2
|
||||
funcsigs==1.0.2
|
||||
luqum>=0.7.2
|
||||
attrs>=19.1.0
|
||||
nested_dict>=1.61
|
||||
related>=0.7.2
|
||||
validators>=0.12.4
|
||||
fastjsonschema>=2.8
|
||||
boltons>=19.1.0
|
||||
semantic_version>=2.6.0,<3
|
||||
dpath>=1.4.2,<2.0
|
||||
elasticsearch>=5.0.0,<6.0.0
|
||||
fastjsonschema>=2.8
|
||||
Flask-Compress>=1.4.0
|
||||
Flask-Cors>=3.0.5
|
||||
Flask>=0.12.2
|
||||
funcsigs==1.0.2
|
||||
furl>=2.0.0
|
||||
redis>=2.10.5
|
||||
gunicorn>=19.7.1
|
||||
humanfriendly==4.18
|
||||
Jinja2==2.10
|
||||
jsonmodels>=2.3
|
||||
jsonschema>=2.6.0
|
||||
luqum>=0.7.2
|
||||
mongoengine==0.19.1
|
||||
nested_dict>=1.61
|
||||
psutil>=5.6.5
|
||||
pyhocon>=0.3.35
|
||||
pyjwt>=1.3.0
|
||||
pymongo==3.10.1
|
||||
python-rapidjson>=0.6.3
|
||||
redis>=2.10.5
|
||||
related>=0.7.2
|
||||
requests>=2.13.0
|
||||
semantic_version>=2.8.0,<3
|
||||
six
|
||||
tqdm
|
||||
validators>=0.12.4
|
||||
@@ -171,6 +171,30 @@
|
||||
critical
|
||||
]
|
||||
}
|
||||
event_type_enum {
|
||||
type: string
|
||||
enum: [
|
||||
training_stats_scalar
|
||||
training_stats_vector
|
||||
training_debug_image
|
||||
plot
|
||||
log
|
||||
]
|
||||
}
|
||||
task_metric {
|
||||
type: object
|
||||
required: [task, metric]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
metric {
|
||||
description: "Metric name"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
task_log_event {
|
||||
description: """A log event associated with a task."""
|
||||
type: object
|
||||
@@ -234,6 +258,7 @@
|
||||
properties {
|
||||
added { type: integer }
|
||||
errors { type: integer }
|
||||
errors_info { type: object }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -319,6 +344,84 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.7" {
|
||||
description: "Get the debug image events for the requested amount of iterations per each task's metric"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
metrics
|
||||
]
|
||||
properties {
|
||||
metrics {
|
||||
type: array
|
||||
items { "$ref": "#/definitions/task_metric" }
|
||||
description: "List metrics for which the envents will be retreived"
|
||||
}
|
||||
iters {
|
||||
type: integer
|
||||
description: "Max number of latest iterations for which to return debug images"
|
||||
}
|
||||
navigate_earlier {
|
||||
type: boolean
|
||||
description: "If set then events are retreived from latest iterations to earliest ones. Otherwise from earliest iterations to the latest. The default is True"
|
||||
}
|
||||
refresh {
|
||||
type: boolean
|
||||
description: "If set then scroll will be moved to the latest iterations. The default is False"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID of previous call (used for getting more results)"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
metrics {
|
||||
type: array
|
||||
items: { type: object }
|
||||
description: "Debug image events grouped by task metrics and iterations"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID for getting more results"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_metrics{
|
||||
"2.7": {
|
||||
description: "For each task, get a list of metrics for which the requested event type was reported"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
tasks
|
||||
]
|
||||
properties {
|
||||
tasks {
|
||||
type: array
|
||||
items { type: string }
|
||||
description: "Task IDs"
|
||||
}
|
||||
event_type {
|
||||
"description": "Event type"
|
||||
"$ref": "#/definitions/event_type_enum"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
metrics {
|
||||
type: array
|
||||
items { type: object }
|
||||
description: "List of task with their metrics"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_log {
|
||||
"1.5" {
|
||||
@@ -427,6 +530,59 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
// "2.7" {
|
||||
// description: "Get 'log' events for this task"
|
||||
// request {
|
||||
// type: object
|
||||
// required: [
|
||||
// task
|
||||
// ]
|
||||
// properties {
|
||||
// task {
|
||||
// type: string
|
||||
// description: "Task ID"
|
||||
// }
|
||||
// batch_size {
|
||||
// type: integer
|
||||
// description: "The amount of log events to return"
|
||||
// }
|
||||
// navigate_earlier {
|
||||
// type: boolean
|
||||
// description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
|
||||
// }
|
||||
// refresh {
|
||||
// type: boolean
|
||||
// description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
|
||||
// }
|
||||
// scroll_id {
|
||||
// type: string
|
||||
// description: "Scroll ID of previous call (used for getting more results)"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// response {
|
||||
// type: object
|
||||
// properties {
|
||||
// events {
|
||||
// type: array
|
||||
// items { type: object }
|
||||
// description: "Log items list"
|
||||
// }
|
||||
// returned {
|
||||
// type: integer
|
||||
// description: "Number of log events returned"
|
||||
// }
|
||||
// total {
|
||||
// type: number
|
||||
// description: "Total number of log events available for this query"
|
||||
// }
|
||||
// scroll_id {
|
||||
// type: string
|
||||
// description: "Scroll ID for getting more results"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
get_task_events {
|
||||
"2.1" {
|
||||
|
||||
@@ -159,6 +159,11 @@
|
||||
description: "Get only models whose name matches this pattern (python regular expression syntax)"
|
||||
type: string
|
||||
}
|
||||
user {
|
||||
description: "List of user IDs used to filter results by the model's creating user"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
ready {
|
||||
description: "Indication whether to retrieve only models that are marked ready If not supplied returns both ready and not-ready projects."
|
||||
type: boolean
|
||||
@@ -261,7 +266,7 @@
|
||||
type: string
|
||||
}
|
||||
uri {
|
||||
description: "URI for the model"
|
||||
description: "URI for the model. Exactly one of uri or override_model_id is a required."
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
@@ -283,7 +288,7 @@
|
||||
items {type: string}
|
||||
}
|
||||
override_model_id {
|
||||
description: "Override model ID. If provided, this model is updated in the task."
|
||||
description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required."
|
||||
type: string
|
||||
}
|
||||
iteration {
|
||||
|
||||
43
server/schema/services/organization.conf
Normal file
43
server/schema/services/organization.conf
Normal file
@@ -0,0 +1,43 @@
|
||||
_description: "This service provides organization level operations"
|
||||
|
||||
get_tags {
|
||||
"2.8" {
|
||||
description: "Get all the user and system tags used for the company tasks and models"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
include_system {
|
||||
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
filter {
|
||||
description: "Filter on entities to collect tags from"
|
||||
type: object
|
||||
properties {
|
||||
system_tags {
|
||||
description: "The list of system tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of unique tag values"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -69,6 +69,17 @@ info {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.8": ${info."2.1"} {
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
uid {
|
||||
description: "Server UID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
endpoints {
|
||||
"2.1" {
|
||||
|
||||
@@ -254,6 +254,15 @@ _definitions {
|
||||
enum: [
|
||||
training
|
||||
testing
|
||||
inference
|
||||
data_processing
|
||||
application
|
||||
monitor
|
||||
controller
|
||||
optimizer
|
||||
service
|
||||
qc
|
||||
custom
|
||||
]
|
||||
}
|
||||
last_metrics_event {
|
||||
@@ -475,7 +484,11 @@ get_all {
|
||||
minimum: 1
|
||||
}
|
||||
order_by {
|
||||
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
|
||||
description: """List of field names to order by. When search_text is used,
|
||||
'@text_score' can be used as a field representing the text score of returned documents.
|
||||
Use '-' prefix to specify descending order. Optional, recommended when using page.
|
||||
If the first order field is a hyper parameter or metric then string values are ordered
|
||||
according to numeric ordering rules where applicable"""
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
@@ -550,6 +563,31 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
get_types {
|
||||
"2.8" {
|
||||
description: "Get the list of task types used in the specified projects"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
projects {
|
||||
description: "The list of projects which tasks will be analyzed. If not passed or empty then all the company and public tasks will be analyzed"
|
||||
type: array
|
||||
items: {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
types {
|
||||
description: "Unique list of the task types used in the requested projects"
|
||||
type: array
|
||||
items: {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
clone {
|
||||
"2.5" {
|
||||
description: "Clone an existing task"
|
||||
@@ -591,6 +629,10 @@ clone {
|
||||
description: "The execution params for the cloned task. The params not specified are taken from the original task"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
validate_references {
|
||||
description: "If set to 'false' then the task fields that are copied from the original task are not validated. The default is false."
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -901,6 +943,11 @@ reset {
|
||||
properties.force = ${_references.force_arg} {
|
||||
description: "If not true, call fails if the task status is 'completed'"
|
||||
}
|
||||
properties.clear_all {
|
||||
description: "Clear script and execution sections completely"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
type: object
|
||||
|
||||
@@ -145,6 +145,19 @@ get_all_ex {
|
||||
internal: true
|
||||
"2.1": ${get_all."2.1"} {
|
||||
}
|
||||
"2.8": ${get_all."2.1"} {
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
active_in_projects {
|
||||
description: "List of project IDs. If provided, return only users that were active in these projects. If empty list is provided, return users that were active in all projects"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
get_all {
|
||||
|
||||
@@ -10,7 +10,8 @@ import database
|
||||
from apierrors.base import BaseError
|
||||
from bll.statistics.stats_reporter import StatisticsReporter
|
||||
from config import config
|
||||
from init_data import init_es_data, init_mongo_data
|
||||
from elastic.initialize import init_es_data
|
||||
from mongo.initialize import init_mongo_data
|
||||
from service_repo import ServiceRepo, APICall
|
||||
from service_repo.auth import AuthType
|
||||
from service_repo.errors import PathParsingError
|
||||
|
||||
@@ -9,6 +9,7 @@ import jsonmodels.models
|
||||
import timing_context
|
||||
from apierrors import APIError
|
||||
from apierrors.errors.bad_request import RequestPathHasInvalidVersion
|
||||
from api_version import __version__ as _api_version_
|
||||
from config import config
|
||||
from service_repo.base import PartialVersion
|
||||
from .apicall import APICall
|
||||
@@ -34,7 +35,7 @@ class ServiceRepo(object):
|
||||
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
|
||||
maximum """
|
||||
|
||||
_max_version = PartialVersion("2.6")
|
||||
_max_version = PartialVersion(".".join(_api_version_.split(".")[:2]))
|
||||
""" Maximum version number (the highest min_version value across all endpoints) """
|
||||
|
||||
_endpoint_exp = (
|
||||
@@ -166,7 +167,7 @@ class ServiceRepo(object):
|
||||
return
|
||||
|
||||
assert isinstance(endpoint, Endpoint)
|
||||
call.actual_endpoint_version: PartialVersion = endpoint.min_version
|
||||
call.actual_endpoint_version = endpoint.min_version
|
||||
call.requires_authorization = endpoint.authorize
|
||||
return endpoint
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ def validate_all(call: APICall, endpoint: Endpoint):
|
||||
|
||||
def validate_role(endpoint, call):
|
||||
try:
|
||||
if not endpoint.allows(call.identity.role):
|
||||
if endpoint.authorize and not endpoint.allows(call.identity.role):
|
||||
raise errors.forbidden.RoleNotAllowed(role=call.identity.role, allowed=endpoint.allow_roles)
|
||||
except MissingIdentity:
|
||||
pass
|
||||
|
||||
@@ -2,12 +2,16 @@ import itertools
|
||||
from collections import defaultdict
|
||||
from operator import itemgetter
|
||||
|
||||
import six
|
||||
|
||||
from apierrors import errors
|
||||
from apimodels.events import (
|
||||
MultiTaskScalarMetricsIterHistogramRequest,
|
||||
ScalarMetricsIterHistogramRequest,
|
||||
DebugImagesRequest,
|
||||
DebugImageResponse,
|
||||
MetricEvents,
|
||||
IterationEvents,
|
||||
TaskMetricsRequest,
|
||||
LogEventsRequest,
|
||||
)
|
||||
from bll.event import EventBLL
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
@@ -23,10 +27,10 @@ event_bll = EventBLL()
|
||||
def add(call: APICall, company_id, req_model):
|
||||
data = call.data.copy()
|
||||
allow_locked = data.pop("allow_locked", False)
|
||||
added, batch_errors = event_bll.add_events(
|
||||
added, err_count, err_info = event_bll.add_events(
|
||||
company_id, [data], call.worker, allow_locked_tasks=allow_locked
|
||||
)
|
||||
call.result.data = dict(added=added, errors=len(batch_errors))
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
call.kpis["events"] = 1
|
||||
|
||||
|
||||
@@ -36,13 +40,13 @@ def add_batch(call: APICall, company_id, req_model):
|
||||
if events is None or len(events) == 0:
|
||||
raise errors.bad_request.BatchContainsNoItems()
|
||||
|
||||
added, batch_errors = event_bll.add_events(company_id, events, call.worker)
|
||||
call.result.data = dict(added=added, errors=len(batch_errors))
|
||||
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
call.kpis["events"] = len(events)
|
||||
|
||||
|
||||
@endpoint("events.get_task_log", required_fields=["task"])
|
||||
def get_task_log(call, company_id, req_model):
|
||||
def get_task_log_v1_5(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
order = call.data.get("order") or "desc"
|
||||
@@ -90,6 +94,29 @@ def get_task_log_v1_7(call, company_id, req_model):
|
||||
)
|
||||
|
||||
|
||||
# uncomment this once the front end is ready
|
||||
# @endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
|
||||
# def get_task_log(call, company_id, req_model: LogEventsRequest):
|
||||
# task_id = req_model.task
|
||||
# task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
#
|
||||
# res = event_bll.log_events_iterator.get_task_events(
|
||||
# company_id=company_id,
|
||||
# task_id=task_id,
|
||||
# batch_size=req_model.batch_size,
|
||||
# navigate_earlier=req_model.navigate_earlier,
|
||||
# refresh=req_model.refresh,
|
||||
# state_id=req_model.scroll_id,
|
||||
# )
|
||||
#
|
||||
# call.result.data = dict(
|
||||
# events=res.events,
|
||||
# returned=len(res.events),
|
||||
# total=res.total_events,
|
||||
# scroll_id=res.next_scroll_id,
|
||||
# )
|
||||
|
||||
|
||||
@endpoint("events.download_task_log", required_fields=["task"])
|
||||
def download_task_log(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
@@ -299,7 +326,7 @@ def multi_task_scalar_metrics_iter_histogram(
|
||||
call, company_id, req_model: MultiTaskScalarMetricsIterHistogramRequest
|
||||
):
|
||||
task_ids = req_model.tasks
|
||||
if isinstance(task_ids, six.string_types):
|
||||
if isinstance(task_ids, str):
|
||||
task_ids = [s.strip() for s in task_ids.split(",")]
|
||||
# Note, bll already validates task ids as it needs their names
|
||||
call.result.data = dict(
|
||||
@@ -481,7 +508,7 @@ def get_debug_images_v1_7(call, company_id, req_model):
|
||||
|
||||
|
||||
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
|
||||
def get_debug_images(call, company_id, req_model):
|
||||
def get_debug_images_v1_8(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters") or 1
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
@@ -507,6 +534,53 @@ def get_debug_images(call, company_id, req_model):
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"events.debug_images",
|
||||
min_version="2.7",
|
||||
request_data_model=DebugImagesRequest,
|
||||
response_data_model=DebugImageResponse,
|
||||
)
|
||||
def get_debug_images(call, company_id, req_model: DebugImagesRequest):
|
||||
tasks = set(m.task for m in req_model.metrics)
|
||||
task_bll.assert_exists(call.identity.company, task_ids=tasks, allow_public=True)
|
||||
result = event_bll.debug_images_iterator.get_task_events(
|
||||
company_id=company_id,
|
||||
metrics=[(m.task, m.metric) for m in req_model.metrics],
|
||||
iter_count=req_model.iters,
|
||||
navigate_earlier=req_model.navigate_earlier,
|
||||
refresh=req_model.refresh,
|
||||
state_id=req_model.scroll_id,
|
||||
)
|
||||
|
||||
call.result.data_model = DebugImageResponse(
|
||||
scroll_id=result.next_scroll_id,
|
||||
metrics=[
|
||||
MetricEvents(
|
||||
task=task,
|
||||
metric=metric,
|
||||
iterations=[
|
||||
IterationEvents(iter=iteration["iter"], events=iteration["events"])
|
||||
for iteration in iterations
|
||||
],
|
||||
)
|
||||
for (task, metric, iterations) in result.metric_events
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
|
||||
def get_tasks_metrics(call: APICall, company_id, req_model: TaskMetricsRequest):
|
||||
task_bll.assert_exists(
|
||||
call.identity.company, task_ids=req_model.tasks, allow_public=True
|
||||
)
|
||||
res = event_bll.metrics.get_tasks_metrics(
|
||||
company_id, task_ids=req_model.tasks, event_type=req_model.event_type
|
||||
)
|
||||
call.result.data = {
|
||||
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
|
||||
}
|
||||
|
||||
|
||||
@endpoint("events.delete_for_task", required_fields=["task"])
|
||||
def delete_for_task(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
|
||||
@@ -12,6 +12,7 @@ from apimodels.models import (
|
||||
PublishModelResponse,
|
||||
ModelTaskPublishResponse,
|
||||
)
|
||||
from bll.organization import OrgBLL
|
||||
from bll.task import TaskBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
@@ -29,51 +30,34 @@ from services.utils import conform_tag_fields, conform_output_tags
|
||||
from timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
get_all_query_options = Model.QueryParameterOptions(
|
||||
pattern_fields=("name", "comment"),
|
||||
fields=("ready",),
|
||||
list_fields=(
|
||||
"tags",
|
||||
"system_tags",
|
||||
"framework",
|
||||
"uri",
|
||||
"id",
|
||||
"project",
|
||||
"task",
|
||||
"parent",
|
||||
),
|
||||
)
|
||||
org_bll = OrgBLL()
|
||||
|
||||
|
||||
@endpoint("models.get_by_id", required_fields=["model"])
|
||||
def get_by_id(call):
|
||||
assert isinstance(call, APICall)
|
||||
def get_by_id(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
models = Model.get_many(
|
||||
company=call.identity.company,
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
query=Q(id=model_id),
|
||||
allow_public=True,
|
||||
)
|
||||
if not models:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model",
|
||||
id=model_id,
|
||||
company=call.identity.company,
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
conform_output_tags(call, models[0])
|
||||
call.result.data = {"model": models[0]}
|
||||
|
||||
|
||||
@endpoint("models.get_by_task_id", required_fields=["task"])
|
||||
def get_by_task_id(call):
|
||||
assert isinstance(call, APICall)
|
||||
def get_by_task_id(call: APICall, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=call.identity.company)
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get(_only=["output"], **query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
@@ -84,13 +68,11 @@ def get_by_task_id(call):
|
||||
|
||||
model_id = task.output.model
|
||||
model = Model.objects(
|
||||
Q(id=model_id) & get_company_or_none_constraint(call.identity.company)
|
||||
Q(id=model_id) & get_company_or_none_constraint(company_id)
|
||||
).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model",
|
||||
id=model_id,
|
||||
company=call.identity.company,
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
model_dict = model.to_proper_dict()
|
||||
conform_output_tags(call, model_dict)
|
||||
@@ -98,31 +80,27 @@ def get_by_task_id(call):
|
||||
|
||||
|
||||
@endpoint("models.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call: APICall):
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "models_get_all_ex"):
|
||||
models = Model.get_many_with_join(
|
||||
company=call.identity.company,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
query_options=get_all_query_options,
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models}
|
||||
|
||||
|
||||
@endpoint("models.get_all", required_fields=[])
|
||||
def get_all(call: APICall):
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "models_get_all"):
|
||||
models = Model.get_many(
|
||||
company=call.identity.company,
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
query_options=get_all_query_options,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models}
|
||||
@@ -146,13 +124,18 @@ create_fields = {
|
||||
|
||||
def parse_model_fields(call, valid_fields):
|
||||
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
|
||||
conform_tag_fields(call, fields)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
return fields
|
||||
|
||||
|
||||
def _update_org_tags(company, fields: dict):
|
||||
org_bll.update_org_tags(
|
||||
company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
|
||||
)
|
||||
|
||||
|
||||
@endpoint("models.update_for_task", required_fields=["task"])
|
||||
def update_for_task(call, company_id, _):
|
||||
assert isinstance(call, APICall)
|
||||
def update_for_task(call: APICall, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
uri = call.data.get("uri")
|
||||
iteration = call.data.get("iteration")
|
||||
@@ -195,7 +178,9 @@ def update_for_task(call, company_id, _):
|
||||
|
||||
if task.output and task.output.model:
|
||||
# model exists, update
|
||||
res = _update_model(call, model_id=task.output.model).to_struct()
|
||||
res = _update_model(
|
||||
call, company_id, model_id=task.output.model
|
||||
).to_struct()
|
||||
res.update({"id": task.output.model, "created": False})
|
||||
call.result.data = res
|
||||
return
|
||||
@@ -218,6 +203,7 @@ def update_for_task(call, company_id, _):
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_org_tags(company_id, fields)
|
||||
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id,
|
||||
@@ -234,48 +220,46 @@ def update_for_task(call, company_id, _):
|
||||
request_data_model=CreateModelRequest,
|
||||
response_data_model=CreateModelResponse,
|
||||
)
|
||||
def create(call, company, req_model):
|
||||
assert isinstance(call, APICall)
|
||||
assert isinstance(req_model, CreateModelRequest)
|
||||
identity = call.identity
|
||||
def create(call: APICall, company_id, req_model: CreateModelRequest):
|
||||
|
||||
if req_model.public:
|
||||
company = ""
|
||||
company_id = ""
|
||||
|
||||
with translate_errors_context():
|
||||
|
||||
project = req_model.project
|
||||
if project:
|
||||
validate_id(Project, company=company, project=project)
|
||||
validate_id(Project, company=company_id, project=project)
|
||||
|
||||
task = req_model.task
|
||||
req_data = req_model.to_struct()
|
||||
if task:
|
||||
validate_task(call, req_data)
|
||||
validate_task(company_id, req_data)
|
||||
|
||||
fields = filter_fields(Model, req_data)
|
||||
conform_tag_fields(call, fields)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
|
||||
# create and save model
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
user=identity.user,
|
||||
company=company,
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
created=datetime.utcnow(),
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_org_tags(company_id, fields)
|
||||
|
||||
call.result.data_model = CreateModelResponse(id=model.id, created=True)
|
||||
|
||||
|
||||
def prepare_update_fields(call, fields):
|
||||
def prepare_update_fields(call, company_id, fields: dict):
|
||||
fields = fields.copy()
|
||||
if "uri" in fields:
|
||||
# clear UI cache if URI is provided (model updated)
|
||||
fields["ui_cache"] = fields.pop("ui_cache", {})
|
||||
if "task" in fields:
|
||||
validate_task(call, fields)
|
||||
validate_task(company_id, fields)
|
||||
|
||||
if "labels" in fields:
|
||||
labels = fields["labels"]
|
||||
@@ -290,33 +274,36 @@ def prepare_update_fields(call, fields):
|
||||
|
||||
invalid_keys = find_other_types(labels.keys(), str)
|
||||
if invalid_keys:
|
||||
raise errors.bad_request.ValidationError("labels keys must be strings", keys=invalid_keys)
|
||||
raise errors.bad_request.ValidationError(
|
||||
"labels keys must be strings", keys=invalid_keys
|
||||
)
|
||||
|
||||
invalid_values = find_other_types(labels.values(), int)
|
||||
if invalid_values:
|
||||
raise errors.bad_request.ValidationError("labels values must be integers", values=invalid_values)
|
||||
raise errors.bad_request.ValidationError(
|
||||
"labels values must be integers", values=invalid_values
|
||||
)
|
||||
|
||||
conform_tag_fields(call, fields)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
return fields
|
||||
|
||||
|
||||
def validate_task(call, fields):
|
||||
Task.get_for_writing(company=call.identity.company, id=fields["task"], _only=["id"])
|
||||
def validate_task(company_id, fields: dict):
|
||||
Task.get_for_writing(company=company_id, id=fields["task"], _only=["id"])
|
||||
|
||||
|
||||
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
|
||||
def edit(call: APICall):
|
||||
identity = call.identity
|
||||
def edit(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=model_id, company=identity.company)
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
fields = prepare_update_fields(call, fields)
|
||||
fields = prepare_update_fields(call, company_id, fields)
|
||||
|
||||
for key in fields:
|
||||
field = getattr(model, key, None)
|
||||
@@ -331,47 +318,44 @@ def edit(call: APICall):
|
||||
fields[key] = d
|
||||
|
||||
iteration = call.data.get("iteration")
|
||||
task_id = model.task or fields.get('task')
|
||||
task_id = model.task or fields.get("task")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id,
|
||||
company_id=identity.company,
|
||||
last_iteration_max=iteration,
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
if fields:
|
||||
updated = model.update(upsert=False, **fields)
|
||||
if updated:
|
||||
_update_org_tags(company_id, fields)
|
||||
conform_output_tags(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
else:
|
||||
call.result.data_model = UpdateResponse(updated=0)
|
||||
|
||||
|
||||
def _update_model(call: APICall, model_id=None):
|
||||
identity = call.identity
|
||||
def _update_model(call: APICall, company_id, model_id=None):
|
||||
model_id = model_id or call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
# get model by id
|
||||
query = dict(id=model_id, company=identity.company)
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
|
||||
data = prepare_update_fields(call, call.data)
|
||||
data = prepare_update_fields(call, company_id, call.data)
|
||||
|
||||
task_id = data.get("task")
|
||||
iteration = data.get("iteration")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id,
|
||||
company_id=identity.company,
|
||||
last_iteration_max=iteration,
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
updated_count, updated_fields = Model.safe_update(
|
||||
call.identity.company, model.id, data
|
||||
)
|
||||
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
||||
if updated_count:
|
||||
_update_org_tags(company_id, updated_fields)
|
||||
conform_output_tags(call, updated_fields)
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
|
||||
@@ -379,8 +363,8 @@ def _update_model(call: APICall, model_id=None):
|
||||
@endpoint(
|
||||
"models.update", required_fields=["model"], response_data_model=UpdateResponse
|
||||
)
|
||||
def update(call):
|
||||
call.result.data_model = _update_model(call)
|
||||
def update(call, company_id, _):
|
||||
call.result.data_model = _update_model(call, company_id)
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -388,31 +372,29 @@ def update(call):
|
||||
request_data_model=PublishModelRequest,
|
||||
response_data_model=PublishModelResponse,
|
||||
)
|
||||
def set_ready(call: APICall, company, req_model: PublishModelRequest):
|
||||
def set_ready(call: APICall, company_id, req_model: PublishModelRequest):
|
||||
updated, published_task_data = TaskBLL.model_set_ready(
|
||||
model_id=req_model.model,
|
||||
company_id=company,
|
||||
company_id=company_id,
|
||||
publish_task=req_model.publish_task,
|
||||
force_publish_task=req_model.force_publish_task
|
||||
force_publish_task=req_model.force_publish_task,
|
||||
)
|
||||
|
||||
call.result.data_model = PublishModelResponse(
|
||||
updated=updated,
|
||||
published_task=ModelTaskPublishResponse(
|
||||
**published_task_data
|
||||
) if published_task_data else None
|
||||
published_task=ModelTaskPublishResponse(**published_task_data)
|
||||
if published_task_data
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
@endpoint("models.delete", required_fields=["model"])
|
||||
def update(call):
|
||||
assert isinstance(call, APICall)
|
||||
identity = call.identity
|
||||
def update(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
force = call.data.get("force", False)
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=model_id, company=identity.company)
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).only("id", "task").first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
@@ -445,4 +427,6 @@ def update(call):
|
||||
)
|
||||
|
||||
del_count = Model.objects(**query).delete()
|
||||
if del_count:
|
||||
org_bll.update_org_tags(company_id, reset=True)
|
||||
call.result.data = dict(deleted=del_count > 0)
|
||||
|
||||
13
server/services/organization.py
Normal file
13
server/services/organization.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from apimodels.organization import TagsRequest
|
||||
from bll.organization import OrgBLL
|
||||
from service_repo import endpoint, APICall
|
||||
|
||||
org_bll = OrgBLL()
|
||||
|
||||
|
||||
@endpoint("organization.get_tags", request_data_model=TagsRequest)
|
||||
def get_tags(call: APICall, company, request: TagsRequest):
|
||||
filter_ = request.filter.system_tags if request.filter else None
|
||||
call.result.data = org_bll.get_tags(
|
||||
company, include_system=request.include_system, filter_=filter_
|
||||
)
|
||||
@@ -33,8 +33,7 @@ create_fields = {
|
||||
}
|
||||
|
||||
get_all_query_options = Project.QueryParameterOptions(
|
||||
pattern_fields=("name", "description"),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
pattern_fields=("name", "description"), list_fields=("tags", "system_tags", "id"),
|
||||
)
|
||||
|
||||
|
||||
@@ -58,10 +57,10 @@ def get_by_id(call):
|
||||
call.result.data = {"project": project_dict}
|
||||
|
||||
|
||||
def make_projects_get_all_pipelines(project_ids, specific_state=None):
|
||||
def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None):
|
||||
archived = EntityVisibility.archived.value
|
||||
|
||||
def ensure_system_tags():
|
||||
def ensure_valid_fields():
|
||||
"""
|
||||
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
|
||||
"""
|
||||
@@ -73,14 +72,20 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None):
|
||||
"then": [],
|
||||
"else": "$system_tags",
|
||||
}
|
||||
}
|
||||
},
|
||||
"status": {"$ifNull": ["$status", "unknown"]},
|
||||
}
|
||||
}
|
||||
|
||||
status_count_pipeline = [
|
||||
# count tasks per project per status
|
||||
{"$match": {"project": {"$in": project_ids}}},
|
||||
ensure_system_tags(),
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
@@ -149,11 +154,12 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None):
|
||||
# only count run time for these types of tasks
|
||||
{
|
||||
"$match": {
|
||||
"type": {"$in": ["training", "testing", "annotation"]},
|
||||
"type": {"$in": ["training", "testing"]},
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
},
|
||||
ensure_system_tags(),
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
# for each project
|
||||
"$group": group_step
|
||||
@@ -192,7 +198,7 @@ def get_all_ex(call: APICall):
|
||||
|
||||
ids = [project["id"] for project in projects]
|
||||
status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
|
||||
ids, specific_state=specific_state
|
||||
call.identity.company, ids, specific_state=specific_state
|
||||
)
|
||||
|
||||
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
|
||||
@@ -202,7 +208,7 @@ def get_all_ex(call: APICall):
|
||||
|
||||
status_count = defaultdict(lambda: {})
|
||||
key = itemgetter(EntityVisibility.archived.value)
|
||||
for result in Task.aggregate(*status_count_pipeline):
|
||||
for result in Task.aggregate(status_count_pipeline):
|
||||
for k, group in groupby(sorted(result["counts"], key=key), key):
|
||||
section = (
|
||||
EntityVisibility.archived if k else EntityVisibility.active
|
||||
@@ -216,7 +222,7 @@ def get_all_ex(call: APICall):
|
||||
|
||||
runtime = {
|
||||
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
|
||||
for result in Task.aggregate(*runtime_pipeline)
|
||||
for result in Task.aggregate(runtime_pipeline)
|
||||
}
|
||||
|
||||
def safe_get(obj, path, default=None):
|
||||
@@ -268,7 +274,7 @@ def create(call):
|
||||
|
||||
with translate_errors_context():
|
||||
fields = parse_from_call(call.data, create_fields, Project.get_fields())
|
||||
conform_tag_fields(call, fields)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
now = datetime.utcnow()
|
||||
project = Project(
|
||||
id=database.utils.id(),
|
||||
@@ -305,7 +311,7 @@ def update(call: APICall):
|
||||
fields = parse_from_call(
|
||||
call.data, create_fields, Project.get_fields(), discard_none_values=False
|
||||
)
|
||||
conform_tag_fields(call, fields)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
fields["last_update"] = datetime.utcnow()
|
||||
with TimingContext("mongo", "projects_update"):
|
||||
updated = project.update(upsert=False, **fields)
|
||||
|
||||
@@ -58,7 +58,9 @@ def get_all(call: APICall):
|
||||
|
||||
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)
|
||||
def create(call: APICall, company_id, request: CreateRequest):
|
||||
tags, system_tags = conform_tags(call, request.tags, request.system_tags)
|
||||
tags, system_tags = conform_tags(
|
||||
call, request.tags, request.system_tags, validate=True
|
||||
)
|
||||
queue = queue_bll.create(
|
||||
company_id=company_id, name=request.name, tags=tags, system_tags=system_tags
|
||||
)
|
||||
@@ -73,7 +75,7 @@ def create(call: APICall, company_id, request: CreateRequest):
|
||||
)
|
||||
def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
data = call.data_model_for_partial_update
|
||||
conform_tag_fields(call, data)
|
||||
conform_tag_fields(call, data, validate=True)
|
||||
updated, fields = queue_bll.update(
|
||||
company_id=company_id, queue_id=req_model.queue, **data
|
||||
)
|
||||
@@ -212,7 +214,9 @@ def get_queue_metrics(
|
||||
dates=data["date"],
|
||||
avg_waiting_times=data["avg_waiting_time"],
|
||||
queue_lengths=data["queue_length"],
|
||||
) if data else QueueMetrics(queue=queue)
|
||||
)
|
||||
if data
|
||||
else QueueMetrics(queue=queue)
|
||||
for queue, data in queue_dicts.items()
|
||||
]
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from config.info import get_version, get_build_number, get_commit_number
|
||||
from database.errors import translate_errors_context
|
||||
from database.model import Company
|
||||
from database.model.company import ReportStatsOption
|
||||
from database.model.settings import Settings, SettingKeys
|
||||
from service_repo import ServiceRepo, APICall, endpoint
|
||||
|
||||
|
||||
@@ -60,6 +61,12 @@ def info(call: APICall):
|
||||
}
|
||||
|
||||
|
||||
@endpoint("server.info", min_version="2.8")
|
||||
def info_2_8(call: APICall):
|
||||
info(call)
|
||||
call.result.data["uid"] = Settings.get_by_key(SettingKeys.server__uuid)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"server.report_stats_option",
|
||||
request_data_model=ReportStatsOptionRequest,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Callable, Type, TypeVar, Union
|
||||
from typing import Sequence, Callable, Type, TypeVar, Union, Tuple
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
@@ -29,8 +29,11 @@ from apimodels.tasks import (
|
||||
CloneRequest,
|
||||
AddOrUpdateArtifactsRequest,
|
||||
AddOrUpdateArtifactsResponse,
|
||||
GetTypesRequest,
|
||||
ResetRequest,
|
||||
)
|
||||
from bll.event import EventBLL
|
||||
from bll.organization import OrgBLL
|
||||
from bll.queue import QueueBLL
|
||||
from bll.task import (
|
||||
TaskBLL,
|
||||
@@ -39,6 +42,7 @@ from bll.task import (
|
||||
split_by,
|
||||
ParameterKeyEscaper,
|
||||
)
|
||||
from bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
|
||||
from bll.util import SetFieldsResolver
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.model import Model
|
||||
@@ -58,19 +62,13 @@ from utilities import safe_get
|
||||
|
||||
task_fields = set(Task.get_fields())
|
||||
task_script_fields = set(get_fields(Script))
|
||||
get_all_query_options = Task.QueryParameterOptions(
|
||||
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
|
||||
datetime_fields=("status_changed",),
|
||||
pattern_fields=("name", "comment"),
|
||||
fields=("parent",),
|
||||
)
|
||||
|
||||
task_bll = TaskBLL()
|
||||
event_bll = EventBLL()
|
||||
queue_bll = QueueBLL()
|
||||
org_bll = OrgBLL()
|
||||
|
||||
|
||||
TaskBLL.start_non_responsive_tasks_watchdog()
|
||||
NonResponsiveTasksWatchdog.start()
|
||||
|
||||
|
||||
def set_task_status_from_call(
|
||||
@@ -110,12 +108,18 @@ def escape_execution_parameters(call: APICall):
|
||||
default_prefix = "execution.parameters."
|
||||
|
||||
def escape_paths(paths, prefix=default_prefix):
|
||||
return [
|
||||
prefix + ParameterKeyEscaper.escape(path[len(prefix) :])
|
||||
if path.startswith(prefix)
|
||||
else path
|
||||
for path in paths
|
||||
]
|
||||
escaped_paths = []
|
||||
for path in paths:
|
||||
if path == prefix:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"invalid task field", path=path
|
||||
)
|
||||
escaped_paths.append(
|
||||
prefix + ParameterKeyEscaper.escape(path[len(prefix) :])
|
||||
if path.startswith(prefix)
|
||||
else path
|
||||
)
|
||||
return escaped_paths
|
||||
|
||||
projection = Task.get_projection(call.data)
|
||||
if projection:
|
||||
@@ -128,7 +132,7 @@ def escape_execution_parameters(call: APICall):
|
||||
|
||||
|
||||
@endpoint("tasks.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call: APICall):
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
|
||||
escape_execution_parameters(call)
|
||||
@@ -136,9 +140,8 @@ def get_all_ex(call: APICall):
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_all_ex"):
|
||||
tasks = Task.get_many_with_join(
|
||||
company=call.identity.company,
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
query_options=get_all_query_options,
|
||||
allow_public=True, # required in case projection is requested for public dataset/versions
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
@@ -146,7 +149,7 @@ def get_all_ex(call: APICall):
|
||||
|
||||
|
||||
@endpoint("tasks.get_all", required_fields=[])
|
||||
def get_all(call: APICall):
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
|
||||
escape_execution_parameters(call)
|
||||
@@ -154,16 +157,22 @@ def get_all(call: APICall):
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_all"):
|
||||
tasks = Task.get_many(
|
||||
company=call.identity.company,
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
query_options=get_all_query_options,
|
||||
allow_public=True, # required in case projection is requested for public dataset/versions
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
|
||||
|
||||
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
|
||||
def get_types(call: APICall, company_id, request: GetTypesRequest):
|
||||
call.result.data = {
|
||||
"types": list(task_bll.get_types(company_id, project_ids=request.projects))
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.stop", request_data_model=UpdateRequest, response_data_model=UpdateResponse
|
||||
)
|
||||
@@ -256,7 +265,7 @@ create_fields = {
|
||||
|
||||
|
||||
def prepare_for_save(call: APICall, fields: dict):
|
||||
conform_tag_fields(call, fields)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
|
||||
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
|
||||
for field in task_script_fields:
|
||||
@@ -316,7 +325,7 @@ def prepare_create_fields(
|
||||
return prepare_for_save(call, fields)
|
||||
|
||||
|
||||
def _validate_and_get_task_from_call(call: APICall, **kwargs):
|
||||
def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
|
||||
with translate_errors_context(
|
||||
field_does_not_exist_cls=errors.bad_request.ValidationError
|
||||
), TimingContext("code", "parse_call"):
|
||||
@@ -326,7 +335,7 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs):
|
||||
with TimingContext("code", "validate"):
|
||||
task_bll.validate(task)
|
||||
|
||||
return task
|
||||
return task, fields
|
||||
|
||||
|
||||
@endpoint("tasks.validate", request_data_model=CreateRequest)
|
||||
@@ -334,14 +343,21 @@ def validate(call: APICall, company_id, req_model: CreateRequest):
|
||||
_validate_and_get_task_from_call(call)
|
||||
|
||||
|
||||
def _update_org_tags(company, fields: dict):
|
||||
org_bll.update_org_tags(
|
||||
company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
|
||||
)
|
||||
def create(call: APICall, company_id, req_model: CreateRequest):
|
||||
task = _validate_and_get_task_from_call(call)
|
||||
task, fields = _validate_and_get_task_from_call(call)
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "save_task"):
|
||||
task.save()
|
||||
_update_org_tags(company_id, fields)
|
||||
update_project_time(task.project)
|
||||
|
||||
call.result.data_model = IdResponse(id=task.id)
|
||||
@@ -362,6 +378,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
|
||||
tags=request.new_task_tags,
|
||||
system_tags=request.new_task_system_tags,
|
||||
execution_overrides=request.execution_overrides,
|
||||
validate_references=request.validate_references,
|
||||
)
|
||||
call.result.data_model = IdResponse(id=task.id)
|
||||
|
||||
@@ -398,8 +415,9 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
partial_update_dict=partial_update_dict,
|
||||
injected_update=dict(last_update=datetime.utcnow()),
|
||||
)
|
||||
|
||||
update_project_time(updated_fields.get("project"))
|
||||
if updated_count:
|
||||
_update_org_tags(company_id, updated_fields)
|
||||
update_project_time(updated_fields.get("project"))
|
||||
unprepare_from_saved(call, updated_fields)
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
|
||||
@@ -431,9 +449,7 @@ def set_requirements(call: APICall, company_id, req_model: SetRequirementsReques
|
||||
|
||||
|
||||
@endpoint("tasks.update_batch")
|
||||
def update_batch(call: APICall):
|
||||
identity = call.identity
|
||||
|
||||
def update_batch(call: APICall, company_id, _):
|
||||
items = call.batched_data
|
||||
if items is None:
|
||||
raise errors.bad_request.BatchContainsNoItems()
|
||||
@@ -443,7 +459,7 @@ def update_batch(call: APICall):
|
||||
tasks = {
|
||||
t.id: t
|
||||
for t in Task.get_many_for_writing(
|
||||
company=identity.company, query=Q(id__in=list(items))
|
||||
company=company_id, query=Q(id__in=list(items))
|
||||
)
|
||||
}
|
||||
|
||||
@@ -461,7 +477,7 @@ def update_batch(call: APICall):
|
||||
continue
|
||||
partial_update_dict.update(last_update=now)
|
||||
update_op = UpdateOne(
|
||||
{"_id": id, "company": identity.company}, {"$set": partial_update_dict}
|
||||
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
|
||||
)
|
||||
bulk_ops.append(update_op)
|
||||
|
||||
@@ -469,7 +485,8 @@ def update_batch(call: APICall):
|
||||
if bulk_ops:
|
||||
res = Task._get_collection().bulk_write(bulk_ops)
|
||||
updated = res.modified_count
|
||||
|
||||
if updated:
|
||||
org_bll.update_org_tags(company_id, reset=True)
|
||||
call.result.data = {"updated": updated}
|
||||
|
||||
|
||||
@@ -524,7 +541,9 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
fields.update(last_update=now)
|
||||
fixed_fields.update(last_update=now)
|
||||
updated = task.update(upsert=False, **fixed_fields)
|
||||
update_project_time(fields.get("project"))
|
||||
if updated:
|
||||
_update_org_tags(company_id, fixed_fields)
|
||||
update_project_time(fields.get("project"))
|
||||
unprepare_from_saved(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
else:
|
||||
@@ -651,14 +670,14 @@ def _dequeue(task: Task, company_id: str, silent_fail=False):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.reset", request_data_model=UpdateRequest, response_data_model=ResetResponse
|
||||
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
|
||||
)
|
||||
def reset(call: APICall, company_id, req_model: UpdateRequest):
|
||||
def reset(call: APICall, company_id, request: ResetRequest):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
req_model.task, company_id=company_id, requires_write_access=True
|
||||
request.task, company_id=company_id, requires_write_access=True
|
||||
)
|
||||
|
||||
force = req_model.force
|
||||
force = request.force
|
||||
|
||||
if not force and task.status == TaskStatus.published:
|
||||
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
|
||||
@@ -674,7 +693,6 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
|
||||
else:
|
||||
if dequeued:
|
||||
api_results.update(dequeued=dequeued)
|
||||
updates.update(unset__execution__queue=1)
|
||||
|
||||
cleaned_up = cleanup_task(task, force)
|
||||
api_results.update(attr.asdict(cleaned_up))
|
||||
@@ -682,11 +700,25 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
|
||||
updates.update(
|
||||
set__last_iteration=DEFAULT_LAST_ITERATION,
|
||||
set__last_metrics={},
|
||||
set__metric_stats={},
|
||||
unset__output__result=1,
|
||||
unset__output__model=1,
|
||||
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
|
||||
unset__output__error=1,
|
||||
unset__last_worker=1,
|
||||
unset__last_worker_report=1,
|
||||
)
|
||||
|
||||
if request.clear_all:
|
||||
updates.update(
|
||||
set__execution=Execution(),
|
||||
unset__script=1,
|
||||
)
|
||||
else:
|
||||
updates.update(unset__execution__queue=1)
|
||||
updates.update(
|
||||
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
|
||||
)
|
||||
|
||||
res = ResetResponse(
|
||||
**ChangeStatusRequest(
|
||||
task=task,
|
||||
@@ -750,8 +782,7 @@ class CleanupResult(object):
|
||||
deleted_models = attr.ib(type=int)
|
||||
|
||||
|
||||
def cleanup_task(task, force=False):
|
||||
# type: (Task, bool) -> CleanupResult
|
||||
def cleanup_task(task: Task, force: bool = False):
|
||||
"""
|
||||
Validate task deletion and delete/modify all its output.
|
||||
:param task: task object
|
||||
@@ -809,6 +840,15 @@ def get_outputs_for_deletion(task, force=False):
|
||||
else:
|
||||
models.draft.append(output_model)
|
||||
|
||||
if models.draft:
|
||||
with TimingContext("mongo", "get_execution_models"):
|
||||
model_ids = [m.id for m in models.draft]
|
||||
dependent_tasks = Task.objects(execution__model__in=model_ids).only(
|
||||
"id", "execution.model"
|
||||
)
|
||||
busy_models = [t.execution.model for t in dependent_tasks]
|
||||
models.draft[:] = [m for m in models.draft if m.id not in busy_models]
|
||||
|
||||
with TimingContext("mongo", "get_task_children"):
|
||||
tasks = Task.objects(parent=task.id).only("id", "parent", "status")
|
||||
published_tasks = [
|
||||
@@ -869,7 +909,7 @@ def delete(call: APICall, company_id, req_model: DeleteRequest):
|
||||
task.switch_collection(collection_name)
|
||||
|
||||
task.delete()
|
||||
|
||||
org_bll.update_org_tags(company_id, reset=True)
|
||||
call.result.data = dict(deleted=True, **attr.asdict(result))
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import dpath
|
||||
from boltons.iterutils import remap
|
||||
@@ -8,6 +8,7 @@ from mongoengine import Q
|
||||
from apierrors import errors
|
||||
from apimodels.base import UpdateResponse
|
||||
from apimodels.users import CreateRequest, SetPreferencesRequest
|
||||
from bll.project import ProjectBLL
|
||||
from bll.user import UserBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
@@ -19,10 +20,10 @@ from service_repo import APICall, endpoint
|
||||
from utilities.json import loads, dumps
|
||||
|
||||
log = config.logger(__file__)
|
||||
get_all_query_options = User.QueryParameterOptions(list_fields=("id",))
|
||||
project_bll = ProjectBLL()
|
||||
|
||||
|
||||
def get_user(call, user_id, only=None):
|
||||
def get_user(call, company_id, user_id, only=None):
|
||||
"""
|
||||
Get user object by the user's ID
|
||||
:param call: API call
|
||||
@@ -34,7 +35,7 @@ def get_user(call, user_id, only=None):
|
||||
# allow system users to get info for all users
|
||||
query = dict(id=user_id)
|
||||
else:
|
||||
query = dict(id=user_id, company=call.identity.company)
|
||||
query = dict(id=user_id, company=company_id)
|
||||
|
||||
with translate_errors_context("retrieving user"):
|
||||
user = User.objects(**query)
|
||||
@@ -48,47 +49,53 @@ def get_user(call, user_id, only=None):
|
||||
|
||||
|
||||
@endpoint("users.get_by_id", required_fields=["user"])
|
||||
def get_by_id(call):
|
||||
assert isinstance(call, APICall)
|
||||
def get_by_id(call: APICall, company_id, _):
|
||||
user_id = call.data["user"]
|
||||
call.result.data = {"user": get_user(call, user_id)}
|
||||
call.result.data = {"user": get_user(call, company_id, user_id)}
|
||||
|
||||
|
||||
@endpoint("users.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call):
|
||||
assert isinstance(call, APICall)
|
||||
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
with translate_errors_context("retrieving users"):
|
||||
res = User.get_many_with_join(
|
||||
company=call.identity.company,
|
||||
query_dict=call.data,
|
||||
query_options=get_all_query_options,
|
||||
)
|
||||
res = User.get_many_with_join(company=company_id, query_dict=call.data)
|
||||
|
||||
call.result.data = {"users": res}
|
||||
|
||||
|
||||
@endpoint("users.get_all_ex", min_version="2.8", required_fields=[])
|
||||
def get_all_ex2_8(call: APICall, company_id, _):
|
||||
with translate_errors_context("retrieving users"):
|
||||
data = call.data
|
||||
active_in_projects = call.data.get("active_in_projects", None)
|
||||
if active_in_projects is not None:
|
||||
active_users = project_bll.get_active_users(
|
||||
company_id, active_in_projects, call.data.get("id")
|
||||
)
|
||||
active_users.discard(None)
|
||||
if not active_users:
|
||||
call.result.data = {"users": []}
|
||||
return
|
||||
data = data.copy()
|
||||
data["id"] = list(active_users)
|
||||
|
||||
res = User.get_many_with_join(company=company_id, query_dict=data)
|
||||
|
||||
call.result.data = {"users": res}
|
||||
|
||||
|
||||
@endpoint("users.get_all", required_fields=[])
|
||||
def get_all(call):
|
||||
assert isinstance(call, APICall)
|
||||
|
||||
def get_all(call: APICall, company_id, _):
|
||||
with translate_errors_context("retrieving users"):
|
||||
res = User.get_many(
|
||||
company=call.identity.company,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
query_options=get_all_query_options,
|
||||
company=company_id, parameters=call.data, query_dict=call.data
|
||||
)
|
||||
|
||||
call.result.data = {"users": res}
|
||||
|
||||
|
||||
@endpoint("users.get_current_user")
|
||||
def get_current_user(call):
|
||||
assert isinstance(call, APICall)
|
||||
|
||||
def get_current_user(call: APICall, company_id, _):
|
||||
with translate_errors_context("retrieving users"):
|
||||
|
||||
projection = (
|
||||
{"company.name"}
|
||||
.union(User.get_fields())
|
||||
@@ -96,7 +103,7 @@ def get_current_user(call):
|
||||
)
|
||||
res = User.get_many_with_join(
|
||||
query=Q(id=call.identity.user),
|
||||
company=call.identity.company,
|
||||
company=company_id,
|
||||
override_projection=projection,
|
||||
)
|
||||
|
||||
@@ -126,13 +133,11 @@ def create(call: APICall):
|
||||
|
||||
|
||||
@endpoint("users.delete", required_fields=["user"])
|
||||
def delete(call):
|
||||
assert isinstance(call, APICall)
|
||||
def delete(call: APICall):
|
||||
UserBLL.delete(call.data["user"])
|
||||
|
||||
|
||||
def update_user(user_id, company_id, data):
|
||||
# type: (str, str, Dict) -> Tuple[int, Dict]
|
||||
def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
|
||||
"""
|
||||
Update user.
|
||||
:param user_id: user ID to update
|
||||
@@ -150,31 +155,29 @@ def update_user(user_id, company_id, data):
|
||||
|
||||
@endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse)
|
||||
def update(call, company_id, _):
|
||||
assert isinstance(call, APICall)
|
||||
user_id = call.data["user"]
|
||||
update_count, updated_fields = update_user(user_id, company_id, call.data)
|
||||
call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields)
|
||||
|
||||
|
||||
def get_user_preferences(call):
|
||||
def get_user_preferences(call: APICall, company_id):
|
||||
user_id = call.identity.user
|
||||
preferences = get_user(call, user_id, ["preferences"]).get("preferences")
|
||||
preferences = get_user(call, company_id, user_id, only=["preferences"]).get(
|
||||
"preferences"
|
||||
)
|
||||
if preferences and isinstance(preferences, str):
|
||||
preferences = loads(preferences)
|
||||
return preferences or {}
|
||||
|
||||
|
||||
@endpoint("users.get_preferences")
|
||||
def get_preferences(call):
|
||||
assert isinstance(call, APICall)
|
||||
return {"preferences": get_user_preferences(call)}
|
||||
def get_preferences(call: APICall, company_id, _):
|
||||
return {"preferences": get_user_preferences(call, company_id)}
|
||||
|
||||
|
||||
@endpoint("users.set_preferences", request_data_model=SetPreferencesRequest)
|
||||
def set_preferences(call, company_id, req_model):
|
||||
# type: (APICall, str, SetPreferencesRequest) -> Dict
|
||||
assert isinstance(call, APICall)
|
||||
changes = req_model.preferences
|
||||
def set_preferences(call: APICall, company_id, request: SetPreferencesRequest):
|
||||
changes = request.preferences
|
||||
|
||||
def invalid_key(_, key, __):
|
||||
if not isinstance(key, str):
|
||||
@@ -187,7 +190,7 @@ def set_preferences(call, company_id, req_model):
|
||||
|
||||
remap(changes, visit=invalid_key)
|
||||
|
||||
base_preferences = get_user_preferences(call)
|
||||
base_preferences = get_user_preferences(call, company_id)
|
||||
new_preferences = deepcopy(base_preferences)
|
||||
for key, value in changes.items():
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from typing import Union, Sequence, Tuple
|
||||
|
||||
from apierrors import errors
|
||||
from database.model.base import GetMixin
|
||||
from database.utils import partition_tags
|
||||
from service_repo import APICall
|
||||
from service_repo.base import PartialVersion
|
||||
@@ -19,13 +21,13 @@ def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
|
||||
doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
|
||||
|
||||
|
||||
def conform_tag_fields(call: APICall, document: dict):
|
||||
def conform_tag_fields(call: APICall, document: dict, validate=False):
|
||||
"""
|
||||
Upgrade old client tags in place
|
||||
"""
|
||||
if "tags" in document:
|
||||
tags, system_tags = conform_tags(
|
||||
call, document["tags"], document.get("system_tags")
|
||||
call, document["tags"], document.get("system_tags"), validate
|
||||
)
|
||||
if tags != document.get("tags"):
|
||||
document["tags"] = tags
|
||||
@@ -34,16 +36,18 @@ def conform_tag_fields(call: APICall, document: dict):
|
||||
|
||||
|
||||
def conform_tags(
|
||||
call: APICall, tags: Sequence, system_tags: Sequence
|
||||
call: APICall, tags: Sequence, system_tags: Sequence, validate=False
|
||||
) -> Tuple[Sequence, Sequence]:
|
||||
"""
|
||||
Make sure that 'tags' from the old SDK clients
|
||||
are correctly split into 'tags' and 'system_tags'
|
||||
Make sure that there are no duplicate tags
|
||||
"""
|
||||
if validate:
|
||||
validate_tags(tags, system_tags)
|
||||
if call.requested_endpoint_version < PartialVersion("2.3"):
|
||||
tags, system_tags = _upgrade_tags(call, tags, system_tags)
|
||||
return _get_unique_values(tags), _get_unique_values(system_tags)
|
||||
return tags, system_tags
|
||||
|
||||
|
||||
def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
|
||||
@@ -55,9 +59,12 @@ def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
|
||||
return tags, system_tags
|
||||
|
||||
|
||||
def _get_unique_values(values: Sequence) -> Sequence:
|
||||
"""Get unique values from the given sequence"""
|
||||
if not values:
|
||||
return values
|
||||
|
||||
return list(set(values))
|
||||
def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
|
||||
for values in filter(None, (tags, system_tags)):
|
||||
unsupported = [
|
||||
t for t in values if t.startswith(GetMixin.ListFieldBucketHelper.op_prefix)
|
||||
]
|
||||
if unsupported:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"unsupported tag prefix", values=unsupported
|
||||
)
|
||||
|
||||
@@ -54,6 +54,10 @@ class TestService(TestCase, TestServiceInterface):
|
||||
)
|
||||
return object_id
|
||||
|
||||
@staticmethod
|
||||
def update_missing(target: dict, **update):
|
||||
target.update({k: v for k, v in update.items() if k not in target})
|
||||
|
||||
def create_temp(self, service, *, client=None, delete_params=None, **kwargs) -> str:
|
||||
return self._create_temp_helper(
|
||||
service=service,
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import operator
|
||||
from time import sleep
|
||||
|
||||
from typing import Sequence
|
||||
from typing import Sequence, Mapping
|
||||
|
||||
from tests.automated import TestService
|
||||
|
||||
|
||||
class TestEntityOrdering(TestService):
|
||||
test_comment = "Entity ordering test"
|
||||
only_fields = ["id", "started", "comment"]
|
||||
only_fields = ["id", "started", "comment", "execution.parameters"]
|
||||
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(**kwargs)
|
||||
@@ -27,6 +27,9 @@ class TestEntityOrdering(TestService):
|
||||
# sort by the same field that we use for the search
|
||||
self._assertGetTasksWithOrdering(order_by="comment")
|
||||
|
||||
# sort by parameter which type is not part of db schema
|
||||
self._assertGetTasksWithOrdering(order_by="execution.parameters.test")
|
||||
|
||||
def test_order_with_paging(self):
|
||||
order_field = "started"
|
||||
# all results in one page
|
||||
@@ -52,23 +55,33 @@ class TestEntityOrdering(TestService):
|
||||
def _get_page_tasks(self, order_by, page: int, page_size: int) -> Sequence:
|
||||
return self.api.tasks.get_all_ex(
|
||||
only_fields=self.only_fields,
|
||||
order_by=[order_by] if order_by else None,
|
||||
order_by=[order_by] if isinstance(order_by, str) else order_by,
|
||||
comment=self.test_comment,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
).tasks
|
||||
|
||||
def _assertSorted(self, vals: Sequence, ascending=True):
|
||||
def _assertSorted(self, vals: Sequence, ascending=True, is_numeric=False):
|
||||
"""
|
||||
Assert that vals are sorted in the ascending or descending order
|
||||
with None values are always coming from the end
|
||||
"""
|
||||
if None in vals:
|
||||
first_null_idx = vals.index(None)
|
||||
none_tail = vals[first_null_idx:]
|
||||
vals = vals[:first_null_idx]
|
||||
self.assertTrue(all(val is None for val in none_tail))
|
||||
self.assertTrue(all(val is not None for val in vals))
|
||||
empty = [None, "", [], {}]
|
||||
empty_value = None
|
||||
idx = 0
|
||||
for idx, val in enumerate(vals):
|
||||
if val in empty:
|
||||
empty_value = val
|
||||
break
|
||||
|
||||
if idx < len(vals) - 1:
|
||||
none_tail = vals[idx:]
|
||||
vals = vals[:idx]
|
||||
self.assertTrue(all(val == empty_value for val in none_tail))
|
||||
self.assertTrue(all(val != empty_value for val in vals))
|
||||
|
||||
if is_numeric:
|
||||
vals = list(map(int, vals))
|
||||
|
||||
if ascending:
|
||||
cmp = operator.le
|
||||
@@ -76,10 +89,18 @@ class TestEntityOrdering(TestService):
|
||||
cmp = operator.ge
|
||||
self.assertTrue(all(cmp(i, j) for i, j in zip(vals, vals[1:])))
|
||||
|
||||
def _get_value_for_path(self, data: Mapping, field_path: Sequence[str]):
|
||||
val = None
|
||||
for name in field_path:
|
||||
val = data.get(name)
|
||||
data = val if isinstance(val, dict) else {}
|
||||
|
||||
return val
|
||||
|
||||
def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs):
|
||||
tasks = self.api.tasks.get_all_ex(
|
||||
only_fields=self.only_fields,
|
||||
order_by=[order_by] if order_by else None,
|
||||
order_by=[order_by] if isinstance(order_by, str) else order_by,
|
||||
comment=self.test_comment,
|
||||
**kwargs,
|
||||
).tasks
|
||||
@@ -87,12 +108,21 @@ class TestEntityOrdering(TestService):
|
||||
if order_by:
|
||||
# test that the output is correctly ordered
|
||||
field_name = order_by if not order_by.startswith("-") else order_by[1:]
|
||||
field_vals = [t.get(field_name) for t in tasks]
|
||||
self._assertSorted(field_vals, ascending=not order_by.startswith("-"))
|
||||
field_vals = [self._get_value_for_path(t, field_name.split(".")) for t in tasks]
|
||||
self._assertSorted(
|
||||
field_vals,
|
||||
ascending=not order_by.startswith("-"),
|
||||
is_numeric=field_name.startswith("execution.parameters.")
|
||||
)
|
||||
|
||||
def _create_tasks(self):
|
||||
tasks = [self._temp_task() for _ in range(10)]
|
||||
for _, task in zip(range(5), tasks):
|
||||
tasks = [
|
||||
self._temp_task(
|
||||
**(dict(execution={"parameters": {"test": f"{i}"} if i >= 5 else {}}))
|
||||
)
|
||||
for i in range(20)
|
||||
]
|
||||
for idx, task in zip(range(5), tasks):
|
||||
self.api.tasks.started(task=task)
|
||||
sleep(0.1)
|
||||
return tasks
|
||||
|
||||
36
server/tests/automated/test_organization.py
Normal file
36
server/tests/automated/test_organization.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from tests.automated import TestService
|
||||
|
||||
|
||||
class TestOrganization(TestService):
|
||||
def setUp(self, version="2.8"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def test_tags(self):
|
||||
tag1 = "Orgtest tag1"
|
||||
tag2 = "Orgtest tag2"
|
||||
system_tag = "Orgtest system tag"
|
||||
|
||||
model = self.create_temp(
|
||||
"models", name="test_org", uri="file:///a", tags=[tag1]
|
||||
)
|
||||
task = self.create_temp(
|
||||
"tasks", name="test org", type="training", input=dict(view={}), tags=[tag1]
|
||||
)
|
||||
data = self.api.organization.get_tags()
|
||||
self.assertTrue(tag1 in data.tags)
|
||||
|
||||
self.api.tasks.edit(task=task, tags=[tag2], system_tags=[system_tag])
|
||||
data = self.api.organization.get_tags(include_system=True)
|
||||
self.assertTrue({tag1, tag2}.issubset(set(data.tags)))
|
||||
self.assertTrue(system_tag in data.system_tags)
|
||||
|
||||
data = self.api.organization.get_tags(
|
||||
filter={"system_tags": ["__$not", system_tag]}
|
||||
)
|
||||
self.assertTrue(tag1 in data.tags)
|
||||
self.assertFalse(tag2 in data.tags)
|
||||
|
||||
self.api.models.delete(model=model)
|
||||
data = self.api.organization.get_tags()
|
||||
self.assertFalse(tag1 in data.tags)
|
||||
self.assertTrue(tag2 in data.tags)
|
||||
@@ -208,25 +208,21 @@ class TestTags(TestService):
|
||||
self.api.tasks.stopped(task=task_id)
|
||||
|
||||
def _temp_queue(self, **kwargs):
|
||||
self._update_missing(kwargs, name="Test tags")
|
||||
self.update_missing(kwargs, name="Test tags")
|
||||
return self.create_temp("queues", **kwargs)
|
||||
|
||||
def _temp_project(self, **kwargs):
|
||||
self._update_missing(kwargs, name="Test tags", description="test")
|
||||
self.update_missing(kwargs, name="Test tags", description="test")
|
||||
return self.create_temp("projects", **kwargs)
|
||||
|
||||
def _temp_model(self, **kwargs):
|
||||
self._update_missing(kwargs, name="Test tags", uri="file:///a/b", labels={})
|
||||
self.update_missing(kwargs, name="Test tags", uri="file:///a/b", labels={})
|
||||
return self.create_temp("models", **kwargs)
|
||||
|
||||
def _temp_task(self, **kwargs):
|
||||
self._update_missing(kwargs, name="Test tags", type="testing", input=dict(view=dict()))
|
||||
self.update_missing(kwargs, name="Test tags", type="testing", input=dict(view=dict()))
|
||||
return self.create_temp("tasks", **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _update_missing(target: dict, **update):
|
||||
target.update({k: v for k, v in update.items() if k not in target})
|
||||
|
||||
def _send(self, service, action, **kwargs):
|
||||
api = kwargs.pop("api", self.api)
|
||||
return AttrDict(
|
||||
|
||||
@@ -2,83 +2,261 @@
|
||||
Comprehensive test of all(?) use cases of datasets and frames
|
||||
"""
|
||||
import json
|
||||
import operator
|
||||
import unittest
|
||||
from functools import partial
|
||||
from statistics import mean
|
||||
from typing import Sequence
|
||||
|
||||
import es_factory
|
||||
from config import config
|
||||
from apierrors.errors.bad_request import EventsNotAdded
|
||||
from tests.automated import TestService
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class TestTaskEvents(TestService):
|
||||
def setUp(self, version="1.7"):
|
||||
def setUp(self, version="2.7"):
|
||||
super().setUp(version=version)
|
||||
|
||||
self.created_tasks = []
|
||||
|
||||
self.task = dict(
|
||||
name="test task events",
|
||||
type="training",
|
||||
input=dict(mapping={}, view=dict(entries=[])),
|
||||
def _temp_task(self, name="test task events"):
|
||||
task_input = dict(
|
||||
name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
|
||||
)
|
||||
res, self.task_id = self.api.send("tasks.create", self.task, extract="id")
|
||||
assert res.meta.result_code == 200
|
||||
self.created_tasks.append(self.task_id)
|
||||
return self.create_temp("tasks", **task_input)
|
||||
|
||||
def tearDown(self):
|
||||
log.info("Cleanup...")
|
||||
for task_id in self.created_tasks:
|
||||
try:
|
||||
self.api.send("tasks.delete", dict(task=task_id, force=True))
|
||||
except Exception as ex:
|
||||
log.exception(ex)
|
||||
|
||||
def create_task_event(self, type, iteration):
|
||||
@staticmethod
|
||||
def _create_task_event(type_, task, iteration, **kwargs):
|
||||
return {
|
||||
"worker": "test",
|
||||
"type": type,
|
||||
"task": self.task_id,
|
||||
"type": type_,
|
||||
"task": task,
|
||||
"iter": iteration,
|
||||
"timestamp": es_factory.get_timestamp_millis()
|
||||
"timestamp": kwargs.get("timestamp") or es_factory.get_timestamp_millis(),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
def copy_and_update(self, src_obj, new_data):
|
||||
obj = src_obj.copy()
|
||||
obj.update(new_data)
|
||||
return obj
|
||||
|
||||
def test_task_logs(self):
|
||||
events = []
|
||||
for iter in range(10):
|
||||
log_event = self.create_task_event("log", iteration=iter)
|
||||
events.append(
|
||||
self.copy_and_update(
|
||||
log_event,
|
||||
{"msg": "This is a log message from test task iter " + str(iter)},
|
||||
)
|
||||
def test_task_metrics(self):
|
||||
tasks = {
|
||||
self._temp_task(): {
|
||||
"Metric1": ["training_debug_image"],
|
||||
"Metric2": ["training_debug_image", "log"],
|
||||
},
|
||||
self._temp_task(): {"Metric3": ["training_debug_image"]},
|
||||
}
|
||||
events = [
|
||||
self._create_task_event(
|
||||
event_type,
|
||||
task=task,
|
||||
iteration=1,
|
||||
metric=metric,
|
||||
variant="Test variant",
|
||||
)
|
||||
# sleep so timestamp is not the same
|
||||
import time
|
||||
for task, metrics in tasks.items()
|
||||
for metric, event_types in metrics.items()
|
||||
for event_type in event_types
|
||||
]
|
||||
self.send_batch(events)
|
||||
self._assert_task_metrics(tasks, "training_debug_image")
|
||||
self._assert_task_metrics(tasks, "log")
|
||||
self._assert_task_metrics(tasks, "training_stats_scalar")
|
||||
|
||||
time.sleep(0.01)
|
||||
def _assert_task_metrics(self, tasks: dict, event_type: str):
|
||||
res = self.api.events.get_task_metrics(tasks=list(tasks), event_type=event_type)
|
||||
for task, metrics in tasks.items():
|
||||
res_metrics = next(
|
||||
(tm.metrics for tm in res.metrics if tm.task == task), ()
|
||||
)
|
||||
self.assertEqual(
|
||||
set(res_metrics),
|
||||
set(
|
||||
metric for metric, events in metrics.items() if event_type in events
|
||||
),
|
||||
)
|
||||
|
||||
def test_task_debug_images(self):
|
||||
task = self._temp_task()
|
||||
metric = "Metric1"
|
||||
variants = [("Variant1", 7), ("Variant2", 4)]
|
||||
iterations = 10
|
||||
|
||||
# test empty
|
||||
res = self.api.events.debug_images(
|
||||
metrics=[{"task": task, "metric": metric}], iters=5,
|
||||
)
|
||||
self.assertFalse(res.metrics)
|
||||
|
||||
# create events
|
||||
events = [
|
||||
self._create_task_event(
|
||||
"training_debug_image",
|
||||
task=task,
|
||||
iteration=n,
|
||||
metric=metric,
|
||||
variant=variant,
|
||||
url=f"{metric}_{variant}_{n % unique_images}",
|
||||
)
|
||||
for n in range(iterations)
|
||||
for (variant, unique_images) in variants
|
||||
]
|
||||
self.send_batch(events)
|
||||
|
||||
data = self.api.events.get_task_log(task=self.task_id)
|
||||
assert len(data["events"]) == 10
|
||||
# init testing
|
||||
unique_images = [unique for (_, unique) in variants]
|
||||
scroll_id = None
|
||||
assert_debug_images = partial(
|
||||
self._assertDebugImages,
|
||||
task=task,
|
||||
metric=metric,
|
||||
max_iter=iterations - 1,
|
||||
unique_images=unique_images,
|
||||
)
|
||||
|
||||
self.api.tasks.reset(task=self.task_id)
|
||||
data = self.api.events.get_task_log(task=self.task_id)
|
||||
assert len(data["events"]) == 0
|
||||
# test forward navigation
|
||||
for page in range(3):
|
||||
scroll_id = assert_debug_images(scroll_id=scroll_id, expected_page=page)
|
||||
|
||||
# test backwards navigation
|
||||
scroll_id = assert_debug_images(
|
||||
scroll_id=scroll_id, expected_page=0, navigate_earlier=False
|
||||
)
|
||||
|
||||
# beyond the latest iteration and back
|
||||
res = self.api.events.debug_images(
|
||||
metrics=[{"task": task, "metric": metric}],
|
||||
iters=5,
|
||||
scroll_id=scroll_id,
|
||||
navigate_earlier=False,
|
||||
)
|
||||
self.assertEqual(len(res["metrics"][0]["iterations"]), 0)
|
||||
assert_debug_images(scroll_id=scroll_id, expected_page=1)
|
||||
|
||||
# refresh
|
||||
assert_debug_images(scroll_id=scroll_id, expected_page=0, refresh=True)
|
||||
|
||||
def _assertDebugImages(
|
||||
self,
|
||||
task,
|
||||
metric,
|
||||
max_iter: int,
|
||||
unique_images: Sequence[int],
|
||||
scroll_id,
|
||||
expected_page: int,
|
||||
iters: int = 5,
|
||||
**extra_params,
|
||||
):
|
||||
res = self.api.events.debug_images(
|
||||
metrics=[{"task": task, "metric": metric}],
|
||||
iters=iters,
|
||||
scroll_id=scroll_id,
|
||||
**extra_params,
|
||||
)
|
||||
data = res["metrics"][0]
|
||||
self.assertEqual(data["task"], task)
|
||||
self.assertEqual(data["metric"], metric)
|
||||
left_iterations = max(0, max(unique_images) - expected_page * iters)
|
||||
self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
|
||||
for it in data["iterations"]:
|
||||
events_per_iter = sum(
|
||||
1 for unique in unique_images if unique > max_iter - it["iter"]
|
||||
)
|
||||
self.assertEqual(len(it["events"]), events_per_iter)
|
||||
return res.scroll_id
|
||||
|
||||
def test_error_events(self):
|
||||
task = self._temp_task()
|
||||
events = [
|
||||
self._create_task_event("unknown type", task, iteration=1),
|
||||
self._create_task_event("training_debug_image", task=None, iteration=1),
|
||||
self._create_task_event(
|
||||
"training_debug_image", task="Invalid task", iteration=1
|
||||
),
|
||||
]
|
||||
# failure if no events added
|
||||
with self.api.raises(EventsNotAdded):
|
||||
self.send_batch(events)
|
||||
|
||||
events.append(
|
||||
self._create_task_event("training_debug_image", task=task, iteration=1)
|
||||
)
|
||||
# success if at least one event added
|
||||
res = self.send_batch(events)
|
||||
self.assertEqual(res["added"], 1)
|
||||
self.assertEqual(res["errors"], 3)
|
||||
self.assertEqual(len(res["errors_info"]), 3)
|
||||
res = self.api.events.get_task_events(task=task)
|
||||
self.assertEqual(len(res.events), 1)
|
||||
|
||||
def test_task_logs(self):
|
||||
# this test will fail until the new api is uncommented
|
||||
task = self._temp_task()
|
||||
timestamp = es_factory.get_timestamp_millis()
|
||||
events = [
|
||||
self._create_task_event(
|
||||
"log",
|
||||
task=task,
|
||||
iteration=iter_,
|
||||
timestamp=timestamp + iter_ * 1000,
|
||||
msg=f"This is a log message from test task iter {iter_}",
|
||||
)
|
||||
for iter_ in range(10)
|
||||
]
|
||||
self.send_batch(events)
|
||||
|
||||
# test forward navigation
|
||||
scroll_id = None
|
||||
for page in range(3):
|
||||
scroll_id = self._assert_log_events(
|
||||
task=task, scroll_id=scroll_id, expected_page=page
|
||||
)
|
||||
|
||||
# test backwards navigation
|
||||
scroll_id = self._assert_log_events(
|
||||
task=task, scroll_id=scroll_id, navigate_earlier=False
|
||||
)
|
||||
|
||||
# refresh
|
||||
self._assert_log_events(task=task, scroll_id=scroll_id)
|
||||
self._assert_log_events(task=task, scroll_id=scroll_id, refresh=True)
|
||||
|
||||
def _assert_log_events(
|
||||
self,
|
||||
task,
|
||||
scroll_id,
|
||||
batch_size: int = 5,
|
||||
expected_total: int = 10,
|
||||
expected_page: int = 0,
|
||||
**extra_params,
|
||||
):
|
||||
res = self.api.events.get_task_log(
|
||||
task=task, batch_size=batch_size, scroll_id=scroll_id, **extra_params,
|
||||
)
|
||||
self.assertEqual(res.total, expected_total)
|
||||
expected_events = max(
|
||||
0, batch_size - max(0, (expected_page + 1) * batch_size - expected_total)
|
||||
)
|
||||
self.assertEqual(res.returned, expected_events)
|
||||
self.assertEqual(len(res.events), expected_events)
|
||||
unique_events = len({ev.iter for ev in res.events})
|
||||
self.assertEqual(len(res.events), unique_events)
|
||||
if res.events:
|
||||
cmp_operator = operator.ge
|
||||
if not extra_params.get("navigate_earlier", True):
|
||||
cmp_operator = operator.le
|
||||
self.assertTrue(
|
||||
all(
|
||||
cmp_operator(first.timestamp, second.timestamp)
|
||||
for first, second in zip(res.events, res.events[1:])
|
||||
)
|
||||
)
|
||||
return res.scroll_id
|
||||
|
||||
def test_task_metric_value_intervals_keys(self):
|
||||
metric = "Metric1"
|
||||
variant = "Variant1"
|
||||
iter_count = 100
|
||||
task = self._temp_task()
|
||||
events = [
|
||||
{
|
||||
**self.create_task_event("training_stats_scalar", iteration),
|
||||
**self._create_task_event("training_stats_scalar", task, iteration),
|
||||
"metric": metric,
|
||||
"variant": variant,
|
||||
"value": iteration,
|
||||
@@ -88,19 +266,65 @@ class TestTaskEvents(TestService):
|
||||
self.send_batch(events)
|
||||
for key in None, "iter", "timestamp", "iso_time":
|
||||
with self.subTest(key=key):
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, key=key)
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=task, key=key)
|
||||
self.assertIn(metric, data)
|
||||
self.assertIn(variant, data[metric])
|
||||
self.assertIn("x", data[metric][variant])
|
||||
self.assertIn("y", data[metric][variant])
|
||||
|
||||
def test_multitask_events_many_metrics(self):
|
||||
tasks = [
|
||||
self._temp_task(name="test events1"),
|
||||
self._temp_task(name="test events2"),
|
||||
]
|
||||
iter_count = 10
|
||||
metrics_count = 10
|
||||
variants_count = 10
|
||||
events = [
|
||||
{
|
||||
**self._create_task_event("training_stats_scalar", task, iteration),
|
||||
"metric": f"Metric{metric_idx}",
|
||||
"variant": f"Variant{variant_idx}",
|
||||
"value": iteration,
|
||||
}
|
||||
for iteration in range(iter_count)
|
||||
for task in tasks
|
||||
for metric_idx in range(metrics_count)
|
||||
for variant_idx in range(variants_count)
|
||||
]
|
||||
self.send_batch(events)
|
||||
data = self.api.events.multi_task_scalar_metrics_iter_histogram(tasks=tasks)
|
||||
self._assert_metrics_and_variants(
|
||||
data.metrics,
|
||||
metrics=metrics_count,
|
||||
variants=variants_count,
|
||||
tasks=tasks,
|
||||
iterations=iter_count,
|
||||
)
|
||||
|
||||
def _assert_metrics_and_variants(
|
||||
self, data: dict, metrics: int, variants: int, tasks: Sequence, iterations: int
|
||||
):
|
||||
self.assertEqual(len(data), metrics)
|
||||
for m in range(metrics):
|
||||
metric_data = data[f"Metric{m}"]
|
||||
self.assertEqual(len(metric_data), variants)
|
||||
for v in range(variants):
|
||||
variant_data = metric_data[f"Variant{v}"]
|
||||
self.assertEqual(len(variant_data), len(tasks))
|
||||
for t in tasks:
|
||||
task_data = variant_data[t]
|
||||
self.assertEqual(len(task_data["x"]), iterations)
|
||||
self.assertEqual(len(task_data["y"]), iterations)
|
||||
|
||||
def test_task_metric_value_intervals(self):
|
||||
metric = "Metric1"
|
||||
variant = "Variant1"
|
||||
iter_count = 100
|
||||
task = self._temp_task()
|
||||
events = [
|
||||
{
|
||||
**self.create_task_event("training_stats_scalar", iteration),
|
||||
**self._create_task_event("training_stats_scalar", task, iteration),
|
||||
"metric": metric,
|
||||
"variant": variant,
|
||||
"value": iteration,
|
||||
@@ -109,13 +333,13 @@ class TestTaskEvents(TestService):
|
||||
]
|
||||
self.send_batch(events)
|
||||
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id)
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=task)
|
||||
self._assert_metrics_histogram(data[metric][variant], iter_count, 100)
|
||||
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, samples=100)
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=task, samples=100)
|
||||
self._assert_metrics_histogram(data[metric][variant], iter_count, 100)
|
||||
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, samples=10)
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=task, samples=10)
|
||||
self._assert_metrics_histogram(data[metric][variant], iter_count, 10)
|
||||
|
||||
def _assert_metrics_histogram(self, data, iters, samples):
|
||||
@@ -130,7 +354,8 @@ class TestTaskEvents(TestService):
|
||||
)
|
||||
|
||||
def test_task_plots(self):
|
||||
event = self.create_task_event("plot", 0)
|
||||
task = self._temp_task()
|
||||
event = self._create_task_event("plot", task, 0)
|
||||
event["metric"] = "roc"
|
||||
event.update(
|
||||
{
|
||||
@@ -179,7 +404,7 @@ class TestTaskEvents(TestService):
|
||||
)
|
||||
self.send(event)
|
||||
|
||||
event = self.create_task_event("plot", 100)
|
||||
event = self._create_task_event("plot", task, 100)
|
||||
event["metric"] = "confusion"
|
||||
event.update(
|
||||
{
|
||||
@@ -222,15 +447,16 @@ class TestTaskEvents(TestService):
|
||||
)
|
||||
self.send(event)
|
||||
|
||||
data = self.api.events.get_task_plots(task=self.task_id)
|
||||
data = self.api.events.get_task_plots(task=task)
|
||||
assert len(data["plots"]) == 2
|
||||
|
||||
self.api.tasks.reset(task=self.task_id)
|
||||
data = self.api.events.get_task_plots(task=self.task_id)
|
||||
self.api.tasks.reset(task=task)
|
||||
data = self.api.events.get_task_plots(task=task)
|
||||
assert len(data["plots"]) == 0
|
||||
|
||||
def send_batch(self, events):
|
||||
self.api.send_batch("events.add_batch", events)
|
||||
_, data = self.api.send_batch("events.add_batch", events)
|
||||
return data
|
||||
|
||||
def send(self, event):
|
||||
self.api.send("events.add", event)
|
||||
|
||||
@@ -14,9 +14,13 @@ class TestTasksDiff(TestService):
|
||||
"tasks", name="test", type="testing", input=dict(view=dict()), **kwargs
|
||||
)
|
||||
|
||||
def _compare_script(self, task, script):
|
||||
for key, value in script.items():
|
||||
self.assertEqual(task.script[key], value)
|
||||
def _compare_script(self, task_id, script):
|
||||
task = self.api.tasks.get_by_id(task=task_id).task
|
||||
if not script:
|
||||
self.assertFalse(task.get("script", None))
|
||||
else:
|
||||
for key, value in script.items():
|
||||
self.assertEqual(task.script[key], value)
|
||||
|
||||
def test_not_deleted(self):
|
||||
task_id = self.new_task()
|
||||
@@ -28,11 +32,14 @@ class TestTasksDiff(TestService):
|
||||
)
|
||||
self.api.tasks.edit(task=task_id, script=script)
|
||||
self.api.tasks.started(task=task_id)
|
||||
|
||||
self.api.tasks.reset(task=task_id)
|
||||
task = self.api.tasks.get_by_id(task=task_id).task
|
||||
self._compare_script(task, script)
|
||||
self._compare_script(task_id, script)
|
||||
|
||||
new_reqs = dict()
|
||||
self.api.tasks.set_requirements(task=task_id, requirements=new_reqs)
|
||||
script["requirements"] = new_reqs
|
||||
task = self.api.tasks.get_by_id(task=task_id).task
|
||||
self._compare_script(task, script)
|
||||
self._compare_script(task_id, script)
|
||||
|
||||
self.api.tasks.reset(task=task_id, clear_all=True)
|
||||
self._compare_script(task_id, {})
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from apierrors.errors.bad_request import InvalidModelId, ValidationError
|
||||
from config import config
|
||||
from tests.automated import TestService
|
||||
|
||||
@@ -10,12 +11,37 @@ class TestTasksEdit(TestService):
|
||||
super().setUp(version=2.5)
|
||||
|
||||
def new_task(self, **kwargs):
|
||||
return self.create_temp(
|
||||
"tasks", type="testing", name="test", input=dict(view=dict()), **kwargs
|
||||
self.update_missing(
|
||||
kwargs, type="testing", name="test", input=dict(view=dict())
|
||||
)
|
||||
return self.create_temp("tasks", **kwargs)
|
||||
|
||||
def new_model(self):
|
||||
return self.create_temp("models", name="test", uri="file:///a/b", labels={})
|
||||
def new_model(self, **kwargs):
|
||||
self.update_missing(kwargs, name="test", uri="file:///a/b", labels={})
|
||||
return self.create_temp("models", **kwargs)
|
||||
|
||||
def test_task_types(self):
|
||||
with self.api.raises(ValidationError):
|
||||
task = self.new_task(type="Unsupported")
|
||||
|
||||
types = ["controller", "optimizer"]
|
||||
p1 = self.create_temp("projects", name="Test tasks1", description="test")
|
||||
task1 = self.new_task(project=p1, type=types[0])
|
||||
p2 = self.create_temp("projects", name="Test tasks2", description="test")
|
||||
task2 = self.new_task(project=p2, type=types[1])
|
||||
|
||||
# all company types
|
||||
res = self.api.tasks.get_types()
|
||||
self.assertTrue(set(types).issubset(set(res["types"])))
|
||||
|
||||
# projects array
|
||||
res = self.api.tasks.get_types(projects=[p1, p2])
|
||||
self.assertEqual(set(types), set(res["types"]))
|
||||
|
||||
# single project
|
||||
for p, t in zip((p1, p2), types):
|
||||
res = self.api.tasks.get_types(projects=[p])
|
||||
self.assertEqual([t], res["types"])
|
||||
|
||||
def test_edit_model_ready(self):
|
||||
task = self.new_task()
|
||||
@@ -38,6 +64,23 @@ class TestTasksEdit(TestService):
|
||||
self.assertFalse(self.api.models.get_by_id(model=not_ready_model).model.ready)
|
||||
self.api.tasks.edit(task=task, execution=dict(model=not_ready_model))
|
||||
|
||||
def test_task_with_model_reset(self):
|
||||
# on task reset output model deleted
|
||||
task = self.new_task()
|
||||
self.api.tasks.started(task=task)
|
||||
model_id = self.api.models.update_for_task(task=task, uri="file:///b")["id"]
|
||||
self.api.tasks.reset(task=task)
|
||||
with self.api.raises(InvalidModelId):
|
||||
self.api.models.get_by_id(model=model_id)
|
||||
|
||||
# unless it is input of some task
|
||||
task = self.new_task()
|
||||
self.api.tasks.started(task=task)
|
||||
model_id = self.api.models.update_for_task(task=task, uri="file:///b")["id"]
|
||||
task_2 = self.new_task(execution=dict(model=model_id))
|
||||
self.api.tasks.reset(task=task)
|
||||
self.api.models.get_by_id(model=model_id)
|
||||
|
||||
def test_clone_task(self):
|
||||
script = dict(
|
||||
binary="python",
|
||||
@@ -56,13 +99,13 @@ class TestTasksEdit(TestService):
|
||||
new_name = "new test"
|
||||
new_tags = ["by"]
|
||||
execution_overrides = dict(framework="Caffe")
|
||||
new_task_id = self.api.tasks.clone(
|
||||
new_task_id = self._clone_task(
|
||||
task=task,
|
||||
new_task_name=new_name,
|
||||
new_task_tags=new_tags,
|
||||
execution_overrides=execution_overrides,
|
||||
new_task_parent=task,
|
||||
).id
|
||||
)
|
||||
new_task = self.api.tasks.get_by_id(task=new_task_id).task
|
||||
self.assertEqual(new_task.name, new_name)
|
||||
self.assertEqual(new_task.type, "testing")
|
||||
@@ -73,3 +116,32 @@ class TestTasksEdit(TestService):
|
||||
self.assertEqual(new_task.execution.parameters, execution["parameters"])
|
||||
self.assertEqual(new_task.execution.framework, execution_overrides["framework"])
|
||||
self.assertEqual(new_task.system_tags, [])
|
||||
|
||||
def test_model_check_in_clone(self):
|
||||
model = self.new_model()
|
||||
task = self.new_task(execution=dict(model=model))
|
||||
|
||||
# task with deleted model still can be copied
|
||||
self.api.models.delete(model=model, force=True)
|
||||
self._clone_task(task=task, new_task_name="clone test")
|
||||
|
||||
# unless check for refs is done
|
||||
with self.api.raises(InvalidModelId):
|
||||
self._clone_task(
|
||||
task=task, new_task_name="clone test2", validate_references=True
|
||||
)
|
||||
|
||||
# if the model is overriden then it is always checked
|
||||
with self.api.raises(InvalidModelId):
|
||||
self._clone_task(
|
||||
task=task,
|
||||
new_task_name="clone test3",
|
||||
execution_overrides=dict(model="not existing"),
|
||||
)
|
||||
|
||||
def _clone_task(self, task, **kwargs):
|
||||
new_task = self.api.tasks.clone(task=task, **kwargs).id
|
||||
self.defer(
|
||||
self.api.tasks.delete, task=new_task, move_to_trash=False, force=True
|
||||
)
|
||||
return new_task
|
||||
|
||||
89
server/tests/automated/test_users.py
Normal file
89
server/tests/automated/test_users.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from typing import Sequence
|
||||
from uuid import uuid4
|
||||
|
||||
from apierrors import errors
|
||||
from config import config
|
||||
from tests.automated import TestService
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class TestUsersService(TestService):
|
||||
def setUp(self, version="2.8"):
|
||||
super(TestUsersService, self).setUp(version=version)
|
||||
self.company = self.api.users.get_current_user().user.company.id
|
||||
|
||||
def new_user(self):
|
||||
user_name = uuid4().hex
|
||||
user_id = self.api.auth.create_user(
|
||||
company=self.company, name=user_name, email="{0}@{0}.com".format(user_name)
|
||||
).id
|
||||
self.defer(self.api.users.delete, user=user_id)
|
||||
return user_id
|
||||
|
||||
def test_active_users(self):
|
||||
user_1 = self.new_user()
|
||||
user_2 = self.new_user()
|
||||
user_3 = self.new_user()
|
||||
|
||||
model = (
|
||||
self.api.impersonate(user_2)
|
||||
.models.create(name="test", uri="file:///a", labels={})
|
||||
.id
|
||||
)
|
||||
self.defer(self.api.models.delete, model=model)
|
||||
project = self.create_temp("projects", name="users test", description="")
|
||||
task = (
|
||||
self.api.impersonate(user_3)
|
||||
.tasks.create(
|
||||
name="test", type="testing", input=dict(view={}), project=project
|
||||
)
|
||||
.id
|
||||
)
|
||||
self.defer(self.api.tasks.delete, task=task, move_to_trash=False)
|
||||
|
||||
user_ids = [user_1, user_2, user_3]
|
||||
# no projects filtering
|
||||
users = self.api.users.get_all_ex(id=user_ids).users
|
||||
self._assertUsers((user_1, user_2, user_3), users)
|
||||
|
||||
# all projects
|
||||
users = self.api.users.get_all_ex(id=user_ids, active_in_projects=[]).users
|
||||
self._assertUsers((user_2, user_3), users)
|
||||
|
||||
# specific project
|
||||
users = self.api.users.get_all_ex(active_in_projects=[project]).users
|
||||
self._assertUsers((user_3,), users)
|
||||
|
||||
def _assertUsers(self, expected: Sequence, users: Sequence):
|
||||
self.assertEqual(set(expected), set(u.id for u in users))
|
||||
|
||||
def test_no_preferences(self):
|
||||
user = self.new_user()
|
||||
assert self.api.impersonate(user).users.get_preferences().preferences == {}
|
||||
|
||||
def _test_update(self, user, tests):
|
||||
"""
|
||||
Check that all for each (updates, expected_result) pair, ``updates`` yield ``result``.
|
||||
"""
|
||||
new_user_client = self.api.impersonate(user)
|
||||
for update, expected in tests:
|
||||
new_user_client.users.set_preferences(user=user, preferences=update)
|
||||
preferences = new_user_client.users.get_preferences(user=user).preferences
|
||||
self.assertEqual(preferences, expected)
|
||||
|
||||
def test_nested_update(self):
|
||||
tests = [
|
||||
({"a": 0}, {"a": 0}),
|
||||
({"b": 1}, {"a": 0, "b": 1}),
|
||||
({"section": {"a": 2}}, {"a": 0, "b": 1, "section": {"a": 2}}),
|
||||
]
|
||||
self._test_update(self.new_user(), tests)
|
||||
|
||||
def test_delete(self):
|
||||
tests = [
|
||||
({"section": {"a": 0, "b": 1}},) * 2,
|
||||
({"section": {"a": None}}, {"section": {"a": None}}),
|
||||
({"section": None}, {"section": None}),
|
||||
]
|
||||
self._test_update(self.new_user(), tests)
|
||||
10
server/utilities/stringenum.py
Normal file
10
server/utilities/stringenum.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class StringEnum(Enum):
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
# noinspection PyMethodParameters
|
||||
def _generate_next_value_(name, start, count, last_values):
|
||||
return name
|
||||
@@ -10,7 +10,7 @@ class ThreadsManager:
|
||||
|
||||
def __init__(self, name=None, **threads):
|
||||
super(ThreadsManager, self).__init__()
|
||||
self.name = name or self.__class__.name
|
||||
self.name = name or self.__class__.__name__
|
||||
self.objects = {}
|
||||
self.lock = Lock()
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "0.13.0"
|
||||
__version__ = "0.15.0"
|
||||
|
||||
Reference in New Issue
Block a user