mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
119 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a936a210e8 | ||
|
|
be0cf0caa8 | ||
|
|
a8d90887e2 | ||
|
|
6f3257fed3 | ||
|
|
4bb8834551 | ||
|
|
286b8c3df5 | ||
|
|
16430a6636 | ||
|
|
d7ddfde26e | ||
|
|
e6c0f1b6d8 | ||
|
|
641ed1b510 | ||
|
|
e29ad4c9b2 | ||
|
|
3473d2bb02 | ||
|
|
ba03924cb4 | ||
|
|
6870d8aba9 | ||
|
|
64c63d2560 | ||
|
|
88836fae66 | ||
|
|
436883148b | ||
|
|
f9f2f0ccf0 | ||
|
|
f879f6924f | ||
|
|
b9cb587580 | ||
|
|
370e92c3dd | ||
|
|
03094076c8 | ||
|
|
bdf6c353bd | ||
|
|
23736efbc3 | ||
|
|
3c8e27dc94 | ||
|
|
ca890c7ae8 | ||
|
|
30909df73f | ||
|
|
b97a6084ce | ||
|
|
50438bd931 | ||
|
|
28daf49c91 | ||
|
|
4707647c92 | ||
|
|
6974aa3a99 | ||
|
|
e2deff4eef | ||
|
|
59994ccf9c | ||
|
|
29c792d459 | ||
|
|
df334d083e | ||
|
|
b548958c80 | ||
|
|
7bdf8fe30d | ||
|
|
c71c65be87 | ||
|
|
1cc6a8f787 | ||
|
|
e5b92f4a80 | ||
|
|
3272d0f31f | ||
|
|
618a0b9473 | ||
|
|
bca3a6e556 | ||
|
|
8b0afd47a6 | ||
|
|
0303c3525f | ||
|
|
563c451ac9 | ||
|
|
91b1b34a6b | ||
|
|
0ad0495733 | ||
|
|
03ae90c4a6 | ||
|
|
be788965e0 | ||
|
|
d198138c5b | ||
|
|
cf441987af | ||
|
|
b89de43373 | ||
|
|
0ef018c931 | ||
|
|
323b5db07c | ||
|
|
f084f6b9e7 | ||
|
|
eb4c9f0b13 | ||
|
|
018582ff8a | ||
|
|
7dcc0f6df2 | ||
|
|
5e0893dd80 | ||
|
|
ca81922651 | ||
|
|
07cc2fb08b | ||
|
|
842654d3fe | ||
|
|
00e5e2a0b1 | ||
|
|
37e5d8a7e0 | ||
|
|
5b1f468957 | ||
|
|
9103bf7984 | ||
|
|
e848d05677 | ||
|
|
1c7de3a86e | ||
|
|
e12fd8f3df | ||
|
|
29ef134b79 | ||
|
|
e24389fda9 | ||
|
|
f4ead86449 | ||
|
|
171969c5ea | ||
|
|
89f81bfe5a | ||
|
|
b8e62f27e2 | ||
|
|
c7bbac73d0 | ||
|
|
f832ea565a | ||
|
|
22e9c2b7eb | ||
|
|
c67a56eb8d | ||
|
|
df65e1c7ad | ||
|
|
01115c1223 | ||
|
|
6de88c3b93 | ||
|
|
9d77827252 | ||
|
|
76fb97624d | ||
|
|
20d6582f51 | ||
|
|
7ebda33793 | ||
|
|
953124aa37 | ||
|
|
ba3451ce5a | ||
|
|
b93591ec32 | ||
|
|
0abfd8da0d | ||
|
|
a9cc4e36c6 | ||
|
|
fe1c963eec | ||
|
|
111d80e88d | ||
|
|
6718862dbe | ||
|
|
0fe1bf8a61 | ||
|
|
10f326eda9 | ||
|
|
cd0d6c1a3d | ||
|
|
3205f2df97 | ||
|
|
5bdbcfcd8d | ||
|
|
a2e2052b30 | ||
|
|
0146ded4f4 | ||
|
|
dccf9dd8f8 | ||
|
|
7816b402bb | ||
|
|
cd4ce30f7c | ||
|
|
8c7e230898 | ||
|
|
42ba696518 | ||
|
|
3f84e60a1f | ||
|
|
baba8b5b73 | ||
|
|
77397c4f21 | ||
|
|
8678091d8f | ||
|
|
aa22170ab4 | ||
|
|
901ec37290 | ||
|
|
21f2ea8b17 | ||
|
|
8219e3d4e2 | ||
|
|
3ed71a61d5 | ||
|
|
18a88a8e8f | ||
|
|
318a72987c |
136
README.md
136
README.md
@@ -1,39 +1,53 @@
|
||||
# Trains Server
|
||||
<div align="center">
|
||||
|
||||
## Auto-Magical Experiment Manager & Version Control for AI - ε Devops Included!
|
||||
<img src="docs/clearml_server_logo.png" width="250px">
|
||||
|
||||
**ClearML - Auto-Magical Suite of tools to streamline your ML workflow
|
||||
</br>Experiment Manager, ML-Ops and Data-Management**
|
||||
|
||||
[](https://img.shields.io/badge/license-SSPL-green.svg)
|
||||
[](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
|
||||
[](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
|
||||
[](https://img.shields.io/badge/status-beta-yellow.svg)
|
||||
|
||||
### Help improve Trains by filling our 2-min [user survey](https://allegro.ai/lp/trains-user-survey/)
|
||||
</div>
|
||||
|
||||
## :rocket: Trains-Agent Services is now included, for more information see [services](https://github.com/allegroai/trains-server#services)
|
||||
---
|
||||
<div align="center">
|
||||
|
||||
## Introduction
|
||||
**v0.16 Upgrade Notice**
|
||||
|
||||
The **trains-server** is the backend service infrastructure for [Trains](https://github.com/allegroai/trains).
|
||||
</div>
|
||||
|
||||
In v0.16, the Elasticsearch subsystem of ClearML Server has been upgraded from version 5.6 to version 7.6. This change necessitates the migration of the database contents to accommodate the change in index structure across the different versions.
|
||||
|
||||
Follow [this procedure](https://allegro.ai/docs/deploying_trains/trains_server_es7_migration/) to migrate existing data.
|
||||
|
||||
---
|
||||
|
||||
### ClearML Server
|
||||
#### *Formerly known as Trains Server*
|
||||
|
||||
The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/allegroai/clearml).
|
||||
It allows multiple users to collaborate and manage their experiments.
|
||||
By default, **Trains** is set up to work with the **Trains** demo server, which is open to anyone and resets periodically.
|
||||
In order to host your own server, you will need to launch **trains-server** and point **Trains** to it.
|
||||
By default, **ClearML** is set up to work with the **ClearML** demo server, which is open to anyone and resets periodically.
|
||||
In order to host your own server, you will need to launch the **ClearML Server** and point **ClearML** to it.
|
||||
|
||||
**trains-server** contains the following components:
|
||||
The **ClearML Server** contains the following components:
|
||||
|
||||
* The **Trains** Web-App, a single-page UI for experiment management and browsing
|
||||
* The **ClearML** Web-App, a single-page UI for experiment management and browsing
|
||||
* RESTful API for:
|
||||
* Documenting and logging experiment information, statistics and results
|
||||
* Querying experiments history, logs and results
|
||||
* Locally-hosted file server for storing images and models making them easily accessible using the Web-App
|
||||
|
||||
You can quickly [deploy](#launching-trains-server) your **trains-server** using Docker, AWS EC2 AMI, or Kubernetes.
|
||||
You can quickly [deploy](#launching-the-clearml-server) your **ClearML Server** using Docker, AWS EC2 AMI, or Kubernetes.
|
||||
|
||||
## System design
|
||||
|
||||
|
||||

|
||||

|
||||
|
||||
**trains-server** has two supported configurations:
|
||||
The **ClearML Server** has two supported configurations:
|
||||
- Single IP (domain) with the following open ports
|
||||
- Web application on port 8080
|
||||
- API service on port 8008
|
||||
@@ -44,11 +58,11 @@ You can quickly [deploy](#launching-trains-server) your **trains-server** using
|
||||
- API service on sub-domain: api.\*.\*
|
||||
- File storage service on sub-domain: files.\*.\*
|
||||
|
||||
## Launching trains-server
|
||||
## Launching The ClearML Server
|
||||
|
||||
### Prerequisites
|
||||
|
||||
The ports 8080/8081/8008 must be available for the **trains-server** services.
|
||||
The ports 8080/8081/8008 must be available for the **ClearML Server** services.
|
||||
|
||||
For example, to see if port `8080` is in use:
|
||||
|
||||
@@ -62,24 +76,24 @@ For example, to see if port `8080` is in use:
|
||||
|
||||
### Launching
|
||||
|
||||
Launch **trains-server** in any of the following formats:
|
||||
Launch The **ClearML Server** in any of the following formats:
|
||||
|
||||
- Pre-built [AWS EC2 AMI](https://github.com/allegroai/trains-server/blob/master/docs/install_aws.md)
|
||||
- Pre-built [GCP Custom Image](https://github.com/allegroai/trains-server/blob/master/docs/install_gcp.md)
|
||||
- Pre-built [AWS EC2 AMI](https://allegro.ai/docs/deploying_trains/trains_server_aws_ec2_ami/)
|
||||
- Pre-built [GCP Custom Image](https://allegro.ai/docs/deploying_trains/trains_server_gcp/)
|
||||
- Pre-built Docker Image
|
||||
- [Linux](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
|
||||
- [macOS](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
|
||||
- [Windows 10](https://github.com/allegroai/trains-server/blob/master/docs/install_win.md)
|
||||
- [Linux](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
|
||||
- [macOS](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
|
||||
- [Windows 10](https://allegro.ai/docs/deploying_trains/trains_server_win/)
|
||||
- Kubernetes
|
||||
- [Kubernetes Helm](https://github.com/allegroai/trains-server-helm#prerequisites)
|
||||
- Manual [Kubernetes installation](https://github.com/allegroai/trains-server-k8s#prerequisites)
|
||||
- [Kubernetes Helm](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes_helm/)
|
||||
- Manual [Kubernetes installation](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes/)
|
||||
|
||||
## Connecting Trains to your trains-server
|
||||
## Connecting ClearML to your ClearML Server
|
||||
|
||||
By default, the **Trains** client is set up to work with the [**Trains** demo server](https://demoapp.trains.allegro.ai/).
|
||||
To have the **Trains** client use your **trains-server** instead:
|
||||
- Run the `trains-init` command for an interactive setup.
|
||||
- Or manually edit `~/trains.conf` file, making sure the server settings (`api_server`, `web_server`, `file_server`) are configured correctly, for example:
|
||||
By default, the **ClearML** client is set up to work with the [**ClearML** demo server](https://demoapp.demo.clear.ml/).
|
||||
To have the **ClearML** client use your **ClearML Server** instead:
|
||||
- Run the `clearml-init` command for an interactive setup.
|
||||
- Or manually edit `~/clearml.conf` file, making sure the server settings (`api_server`, `web_server`, `file_server`) are configured correctly, for example:
|
||||
|
||||
api {
|
||||
# API server on port 8008
|
||||
@@ -92,44 +106,44 @@ To have the **Trains** client use your **trains-server** instead:
|
||||
files_server: "http://localhost:8081"
|
||||
}
|
||||
|
||||
**Note**: If you have set up **trains-server** in a sub-domain configuration, then there is no need to specify a port number,
|
||||
**Note**: If you have set up your **ClearML Server** in a sub-domain configuration, then there is no need to specify a port number,
|
||||
it will be inferred from the http/s scheme.
|
||||
|
||||
After launching the **trains-server** and configuring the **Trains** client to use the **trains-server**,
|
||||
you can [use](https://github.com/allegroai/trains#using-trains) **Trains** in your experiments and view them in your **trains-server** web server,
|
||||
After launching the **ClearML Server** and configuring the **ClearML** client to use the **ClearML Server**,
|
||||
you can [use](https://github.com/allegroai/clearml) **ClearML** in your experiments and view them in your **ClearML Server** web server,
|
||||
for example http://localhost:8080.
|
||||
For more information about the Trains client, see [**Trains**](https://github.com/allegroai/trains).
|
||||
For more information about the ClearML client, see [**ClearML**](https://github.com/allegroai/clearml).
|
||||
|
||||
## Trains-Agent Services <a name="services"></a>
|
||||
## ClearML-Agent Services <a name="services"></a>
|
||||
|
||||
As of version 0.15 of **trains-server**, dockerized deployment includes a **Trains-Agent Services** container running as
|
||||
As of version 0.15 of **ClearML Server**, dockerized deployment includes a **ClearML-Agent Services** container running as
|
||||
part of the docker container collection.
|
||||
|
||||
Trains-Agent Services is an extension of Trains-Agent that provides the ability to launch long-lasting jobs
|
||||
ClearML-Agent Services is an extension of ClearML-Agent that provides the ability to launch long-lasting jobs
|
||||
that previously had to be executed on local / dedicated machines. It allows a single agent to
|
||||
launch multiple dockers (Tasks) for different use cases. To name a few use cases, auto-scaler service (spinning instances
|
||||
when the need arises and the budget allows), Controllers (Implementing pipelines and more sophisticated DevOps logic),
|
||||
Optimizer (such as Hyper-parameter Optimization or sweeping), and Application (such as interactive Bokeh apps for
|
||||
increased data transparency)
|
||||
|
||||
Trains-Agent Services container will spin **any** task enqueued into the dedicated `services` queue.
|
||||
Every task launched by Trains-Agent Services will be registered as a new node in the system,
|
||||
ClearML-Agent Services container will spin **any** task enqueued into the dedicated `services` queue.
|
||||
Every task launched by ClearML-Agent Services will be registered as a new node in the system,
|
||||
providing tracking and transparency capabilities.
|
||||
You can also run the Trains-Agent Services manually, see details in [trains-agent services mode](https://github.com/allegroai/trains-agent#trains-agent-services-mode-)
|
||||
You can also run the ClearML-Agent Services manually, see details in [ClearML-agent services mode](https://github.com/allegroai/clearml-agent#clearml-agent-services-mode-)
|
||||
|
||||
**Note**: It is the user's responsibility to make sure the proper tasks are pushed into the `services` queue.
|
||||
Do not enqueue training / inference tasks into the `services` queue, as it will put unnecessary load on the server.
|
||||
|
||||
## Advanced Functionality
|
||||
|
||||
**trains-server** provides a few additional useful features, which can be manually enabled:
|
||||
The **ClearML Server** provides a few additional useful features, which can be manually enabled:
|
||||
|
||||
* [Web login authentication](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#web-auth)
|
||||
* [Non-responsive experiments watchdog](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#watchdog-the-non-responsive-task-watchdog-settings)
|
||||
* [Web login authentication](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#web-login-authentication)
|
||||
* [Non-responsive experiments watchdog](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#task_watchdog)
|
||||
|
||||
## Restarting trains-server
|
||||
## Restarting ClearML Server
|
||||
|
||||
To restart the **trains-server**, you must first stop the containers, and then restart them.
|
||||
To restart the **ClearML Server**, you must first stop the containers, and then restart them.
|
||||
|
||||
```bash
|
||||
docker-compose down
|
||||
@@ -138,12 +152,12 @@ To restart the **trains-server**, you must first stop the containers, and then r
|
||||
|
||||
## Upgrading <a name="upgrade"></a>
|
||||
|
||||
**trains-server** releases are also reflected in the [docker compose configuration file](https://github.com/allegroai/trains-server/blob/master/docker-compose.yml).
|
||||
We strongly encourage you to keep your **trains-server** up to date, by keeping up with the current release.
|
||||
**ClearML Server** releases are also reflected in the [docker compose configuration file](https://github.com/allegroai/trains-server/blob/master/docker/docker-compose.yml).
|
||||
We strongly encourage you to keep your **ClearML Server** up to date, by keeping up with the current release.
|
||||
|
||||
**Note**: The following upgrade instructions use the Linux OS as an example.
|
||||
|
||||
To upgrade your existing **trains-server** deployment:
|
||||
To upgrade your existing **ClearML Server** deployment:
|
||||
|
||||
1. Shut down the docker containers
|
||||
```bash
|
||||
@@ -152,10 +166,10 @@ To upgrade your existing **trains-server** deployment:
|
||||
|
||||
1. We highly recommend backing up your data directory before upgrading.
|
||||
|
||||
Assuming your data directory is `/opt/trains`, to archive all data into `~/trains_backup.tgz` execute:
|
||||
Assuming your data directory is `/opt/clearml`, to archive all data into `~/clearml_backup.tgz` execute:
|
||||
|
||||
```bash
|
||||
sudo tar czvf ~/trains_backup.tgz /opt/trains/data
|
||||
sudo tar czvf ~/clearml_backup.tgz /opt/clearml/data
|
||||
```
|
||||
|
||||
<details>
|
||||
@@ -163,21 +177,21 @@ To upgrade your existing **trains-server** deployment:
|
||||
|
||||
To restore this example backup, execute:
|
||||
```bash
|
||||
sudo rm -R /opt/trains/data
|
||||
sudo tar -xzf ~/trains_backup.tgz -C /opt/trains/data
|
||||
sudo rm -R /opt/clearml/data
|
||||
sudo tar -xzf ~/clearml_backup.tgz -C /opt/clearml/data
|
||||
```
|
||||
</details>
|
||||
|
||||
1. Download the latest `docker-compose.yml` file.
|
||||
|
||||
```bash
|
||||
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
|
||||
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker/docker-compose.yml -o docker-compose.yml
|
||||
```
|
||||
|
||||
1. Configure the Trains-Agent Services (not supported on Windows installation).
|
||||
If `TRAINS_HOST_IP` is not provided, Trains-Agent Services will use the external
|
||||
public address of the **trains-server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
|
||||
the Trains-Agent Services will not be able to access any private repositories for running service tasks.
|
||||
1. Configure the ClearML-Agent Services (not supported on Windows installation).
|
||||
If `TRAINS_HOST_IP` is not provided, ClearML-Agent Services will use the external
|
||||
public address of the **ClearML Server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
|
||||
the ClearML-Agent Services will not be able to access any private repositories for running service tasks.
|
||||
|
||||
```bash
|
||||
export TRAINS_HOST_IP=server_host_ip_here
|
||||
@@ -185,29 +199,29 @@ To upgrade your existing **trains-server** deployment:
|
||||
export TRAINS_AGENT_GIT_PASS=git_password_here
|
||||
```
|
||||
|
||||
1. Spin up the docker containers, it will automatically pull the latest **trains-server** build
|
||||
1. Spin up the docker containers, it will automatically pull the latest **ClearML Server** build
|
||||
```bash
|
||||
docker-compose -f docker-compose.yml pull
|
||||
docker-compose -f docker-compose.yml up
|
||||
```
|
||||
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#common-docker-upgrade-errors).**
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/clearml/docs/docs/faq/faq.html).**
|
||||
|
||||
|
||||
## Community & Support
|
||||
|
||||
If you have any questions, look to the Trains server [FAQ](https://github.com/allegroai/trains-server/blob/master/docs/faq.md), or
|
||||
If you have any questions, look to the ClearML [FAQ](https://allegro.ai/clearml/docs/docs/faq/faq.html), or
|
||||
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/trains) with '**trains**' tag.
|
||||
|
||||
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/trains-server/issues).
|
||||
|
||||
Additionally, you can always find us at *trains@allegro.ai*
|
||||
Additionally, you can always find us at *clearml@allegro.ai*
|
||||
|
||||
## License
|
||||
|
||||
[Server Side Public License v1.0](https://github.com/mongodb/mongo/blob/master/LICENSE-Community.txt)
|
||||
|
||||
**trains-server** relies on both [MongoDB](https://github.com/mongodb/mongo) and [ElasticSearch](https://github.com/elastic/elasticsearch).
|
||||
The **ClearML Server** relies on both [MongoDB](https://github.com/mongodb/mongo) and [ElasticSearch](https://github.com/elastic/elasticsearch).
|
||||
With the recent changes in both MongoDB's and ElasticSearch's OSS license, we feel it is our responsibility as a
|
||||
member of the community to support the projects we love and cherish.
|
||||
We believe the cause for the license change in both cases is more than just,
|
||||
|
||||
6
apiserver/apierrors/__init__.py
Normal file
6
apiserver/apierrors/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .apierror import APIError
|
||||
from .base import BaseError
|
||||
|
||||
from apiserver.apierrors_generator import ErrorsGenerator
|
||||
|
||||
ErrorsGenerator.generate_python_files()
|
||||
@@ -1,9 +1,10 @@
|
||||
class APIError(Exception):
|
||||
def __init__(self, msg, code=500, subcode=0, **_):
|
||||
def __init__(self, msg, code=500, subcode=0, error_data=None, **_):
|
||||
super(APIError, self).__init__()
|
||||
self._msg = msg
|
||||
self._code = code
|
||||
self._subcode = subcode
|
||||
self._error_data = error_data or {}
|
||||
|
||||
@property
|
||||
def msg(self):
|
||||
@@ -17,5 +18,9 @@ class APIError(Exception):
|
||||
def subcode(self):
|
||||
return self._subcode
|
||||
|
||||
@property
|
||||
def error_data(self):
|
||||
return self._error_data
|
||||
|
||||
def __str__(self):
|
||||
return self.msg
|
||||
@@ -1,9 +1,13 @@
|
||||
import six
|
||||
from boltons.typeutils import classproperty
|
||||
from typing import Tuple
|
||||
|
||||
import six
|
||||
from boltons.iterutils import is_collection, remap
|
||||
from boltons.typeutils import classproperty
|
||||
|
||||
from .apierror import APIError
|
||||
|
||||
jsonable_types = (dict, list, tuple, str, int, float, bool, type(None))
|
||||
|
||||
|
||||
class BaseError(APIError):
|
||||
_default_code = 500
|
||||
@@ -19,15 +23,26 @@ class BaseError(APIError):
|
||||
f"{k}={self._format_kwarg(v)}" for k, v in kwargs.items()
|
||||
)
|
||||
message += f": {kwargs_msg}"
|
||||
params = kwargs.copy()
|
||||
params.update(
|
||||
code=self._default_code, subcode=self._default_subcode, msg=message
|
||||
|
||||
super(BaseError, self).__init__(
|
||||
code=self._default_code,
|
||||
subcode=self._default_subcode,
|
||||
msg=message,
|
||||
error_data=self._to_safe_json_types(kwargs),
|
||||
)
|
||||
super(BaseError, self).__init__(**params)
|
||||
|
||||
@staticmethod
|
||||
def _to_safe_json_types(data):
|
||||
def visit(_, k, v):
|
||||
if not isinstance(v, jsonable_types):
|
||||
v = str(v)
|
||||
return k, v
|
||||
|
||||
return remap(data, visit=visit)
|
||||
|
||||
@staticmethod
|
||||
def _format_kwarg(value):
|
||||
if isinstance(value, (tuple, list)):
|
||||
if is_collection(value):
|
||||
return f'({", ".join(str(v) for v in value)})'
|
||||
elif isinstance(value, six.string_types):
|
||||
return value
|
||||
129
apiserver/apierrors/errors.conf
Normal file
129
apiserver/apierrors/errors.conf
Normal file
@@ -0,0 +1,129 @@
|
||||
400 {
|
||||
_: "bad_request"
|
||||
1: ["not_supported", "endpoint is not supported"]
|
||||
2: ["request_path_has_invalid_version", "request path has invalid version"]
|
||||
5: ["invalid_headers", "invalid headers"]
|
||||
6: ["impersonation_error", "impersonation error"]
|
||||
|
||||
10: ["invalid_id", "invalid object id"]
|
||||
11: ["missing_required_fields", "missing required fields"]
|
||||
12: ["validation_error", "validation error"]
|
||||
13: ["fields_not_allowed_for_role", "fields not allowed for role"]
|
||||
14: ["invalid fields", "fields not defined for object"]
|
||||
15: ["fields_conflict", "conflicting fields"]
|
||||
16: ["fields_value_error", "invalid value for fields"]
|
||||
17: ["batch_contains_no_items", "batch request contains no items"]
|
||||
18: ["batch_validation_error", "batch request validation error"]
|
||||
19: ["invalid_lucene_syntax", "malformed lucene query"]
|
||||
20: ["fields_type_error", "invalid type for fields"]
|
||||
21: ["invalid_regex_error", "malformed regular expression"]
|
||||
22: ["invalid_email_address", "malformed email address"]
|
||||
23: ["invalid_domain_name", "malformed domain name"]
|
||||
24: ["not_public_object", "object is not public"]
|
||||
|
||||
# Tasks
|
||||
100: ["task_error", "general task error"]
|
||||
101: ["invalid_task_id", "invalid task id"]
|
||||
102: ["task_validation_error", "task validation error"]
|
||||
110: ["invalid_task_status", "invalid task status"]
|
||||
111: ["task_not_started", "task not started (invalid task status)"]
|
||||
112: ["task_in_progress", "task in progress (invalid task status)"]
|
||||
113: ["task_published", "task published (invalid task status)"]
|
||||
114: ["task_status_unknown", "task unknown (invalid task status)"]
|
||||
120: ["invalid_task_execution_progress", "invalid task execution progress"]
|
||||
121: ["failed_changing_task_status", "failed changing task status. probably someone changed it before you"]
|
||||
122: ["missing_task_fields", "task is missing expected fields"]
|
||||
123: ["task_cannot_be_deleted", "task cannot be deleted"]
|
||||
125: ["task_has_jobs_running", "task has jobs that haven't completed yet"]
|
||||
126: ["invalid_task_type", "invalid task type for this operations"]
|
||||
127: ["invalid_task_input", "invalid task output"]
|
||||
128: ["invalid_task_output", "invalid task output"]
|
||||
129: ["task_publish_in_progress", "Task publish in progress"]
|
||||
130: ["task_not_found", "task not found"]
|
||||
131: ["events_not_added", "events not added"]
|
||||
|
||||
# Models
|
||||
200: ["model_error", "general task error"]
|
||||
201: ["invalid_model_id", "invalid model id"]
|
||||
202: ["model_not_ready", "model is not ready"]
|
||||
203: ["model_is_ready", "model is ready"]
|
||||
204: ["invalid_model_uri", "invalid model URI"]
|
||||
205: ["model_in_use", "model is used by tasks"]
|
||||
206: ["model_creating_task_exists", "task that created this model exists"]
|
||||
|
||||
# Users
|
||||
300: ["invalid_user", "invalid user"]
|
||||
301: ["invalid_user_id", "invalid user id"]
|
||||
302: ["user_id_exists", "user id already exists"]
|
||||
305: ["invalid_preferences_update", "Malformed key and/or value"]
|
||||
|
||||
# Projects
|
||||
401: ["invalid_project_id", "invalid project id"]
|
||||
402: ["project_has_tasks", "project has associated tasks"]
|
||||
403: ["project_not_found", "project not found"]
|
||||
405: ["project_has_models", "project has associated models"]
|
||||
|
||||
# Queues
|
||||
701: ["invalid_queue_id", "invalid queue id"]
|
||||
702: ["queue_not_empty", "queue is not empty"]
|
||||
703: ["invalid_queue_or_task_not_queued", "invalid queue id or task not in queue"]
|
||||
704: ["removed_during_reposition", "task was removed by another party during reposition"]
|
||||
705: ["failed_adding_during_reposition", "failed adding task back to queue during reposition"]
|
||||
706: ["task_already_queued", "failed adding task to queue since task is already queued"]
|
||||
707: ["no_default_queue", "no queue is tagged as the default queue for this company"]
|
||||
708: ["multiple_default_queues", "more than one queue is tagged as the default queue for this company"]
|
||||
|
||||
# Database
|
||||
800: ["data_validation_error", "data validation error"]
|
||||
801: ["expected_unique_data", "value combination already exists"]
|
||||
|
||||
# Workers
|
||||
1001: ["invalid_worker_id", "invalid worker id"]
|
||||
1002: ["worker_registration_failed", "worker registration failed"]
|
||||
1003: ["worker_registered", "worker is already registered"]
|
||||
1004: ["worker_not_registered", "worker is not registered"]
|
||||
1005: ["worker_stats_not_found", "worker stats not found"]
|
||||
|
||||
1104: ["invalid_scroll_id", "Invalid scroll id"]
|
||||
}
|
||||
|
||||
401 {
|
||||
_: "unauthorized"
|
||||
1: ["not_authorized", "unauthorized (not authorized for endpoint)"]
|
||||
2: ["entity_not_allowed", "unauthorized (entity not allowed)"]
|
||||
10: ["bad_auth_type", "unauthorized (bad authentication header type)"]
|
||||
20: ["no_credentials", "unauthorized (missing credentials)"]
|
||||
21: ["bad_credentials", "unauthorized (malformed credentials)"]
|
||||
22: ["invalid_credentials", "unauthorized (invalid credentials)"]
|
||||
30: ["invalid_token", "invalid token"]
|
||||
31: ["blocked_token", "token is blocked"]
|
||||
40: ["invalid_fixed_user", "fixed user ID was not found"]
|
||||
}
|
||||
|
||||
403: {
|
||||
_: "forbidden"
|
||||
10: ["routing_error", "forbidden (routing error)"]
|
||||
12: ["blocked_internal_endpoint", "forbidden (blocked internal endpoint)"]
|
||||
20: ["role_not_allowed", "forbidden (not allowed for role)"]
|
||||
21: ["no_write_permission", "forbidden (modification not allowed)"]
|
||||
}
|
||||
|
||||
500 {
|
||||
_: "server_error"
|
||||
0: ["general_error", "general server error"]
|
||||
1: ["internal_error", "internal server error"]
|
||||
2: ["config_error", "configuration error"]
|
||||
3: ["build_info_error", "build info unavailable or corrupted"]
|
||||
4: ["low_disk_space", "Critical server error! Server reports low or insufficient disk space. Please resolve immediately by allocating additional disk space or freeing up storage space."]
|
||||
10: ["transaction_error", "a transaction call has returned with an error"]
|
||||
# Database-related issues
|
||||
100: ["data_error", "general data error"]
|
||||
101: ["inconsistent_data", "inconsistent data encountered in document"]
|
||||
102: ["database_unavailable", "database is temporarily unavailable"]
|
||||
110: ["update_failed", "update failed"]
|
||||
|
||||
# Index-related issues
|
||||
201: ["missing_index", "missing internal index"]
|
||||
|
||||
9999: ["not_implemented", "action is not yet implemented"]
|
||||
}
|
||||
1
apiserver/apierrors_generator/__init__.py
Normal file
1
apiserver/apierrors_generator/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .errors_generator import ErrorsGenerator
|
||||
4
apiserver/apierrors_generator/__main__.py
Normal file
4
apiserver/apierrors_generator/__main__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .errors_generator import ErrorsGenerator
|
||||
|
||||
if __name__ == '__main__':
|
||||
ErrorsGenerator.generate_python_files()
|
||||
31
apiserver/apierrors_generator/errors_generator.py
Normal file
31
apiserver/apierrors_generator/errors_generator.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from functools import reduce
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from pyhocon import ConfigFactory, ConfigTree
|
||||
|
||||
from .generator import Generator
|
||||
|
||||
|
||||
class ErrorsGenerator:
|
||||
_apierrors_path = Path(__file__).parents[1] / "apierrors"
|
||||
_files = [_apierrors_path / "errors.conf"]
|
||||
|
||||
@classmethod
|
||||
def _get_codes(cls):
|
||||
return {
|
||||
(k, v.pop("_")): v
|
||||
for k, v in reduce(
|
||||
ConfigTree.merge_configs, map(ConfigFactory.parse_file, cls._files),
|
||||
).items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def add_errors_file(cls, path: Union[Path, str]):
|
||||
cls._files.append(path)
|
||||
|
||||
@classmethod
|
||||
def generate_python_files(cls):
|
||||
Generator(cls._apierrors_path / "errors", format_pep8=False).make_errors(
|
||||
cls._get_codes()
|
||||
)
|
||||
@@ -8,9 +8,12 @@ from pathlib import Path
|
||||
|
||||
env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(str(Path(__file__).parent)),
|
||||
autoescape=jinja2.select_autoescape(disabled_extensions=('py',), default_for_string=False),
|
||||
autoescape=jinja2.select_autoescape(
|
||||
disabled_extensions=("py",), default_for_string=False
|
||||
),
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True)
|
||||
lstrip_blocks=True,
|
||||
)
|
||||
|
||||
|
||||
def env_filter(name=None):
|
||||
@@ -19,14 +22,14 @@ def env_filter(name=None):
|
||||
|
||||
@env_filter()
|
||||
def cls_name(name):
|
||||
delims = list(map(re.escape, (' ', '_')))
|
||||
parts = re.split('|'.join(delims), name)
|
||||
return ''.join(x.capitalize() for x in parts)
|
||||
delims = list(map(re.escape, (" ", "_")))
|
||||
parts = re.split("|".join(delims), name)
|
||||
return "".join(x.capitalize() for x in parts)
|
||||
|
||||
|
||||
class Generator(object):
|
||||
_base_class_name = 'BaseError'
|
||||
_base_class_module = 'apierrors.base'
|
||||
_base_class_name = "BaseError"
|
||||
_base_class_module = "apiserver.apierrors.base"
|
||||
|
||||
def __init__(self, path, format_pep8=True, use_md5=True):
|
||||
self._use_md5 = use_md5
|
||||
@@ -35,29 +38,37 @@ class Generator(object):
|
||||
self._path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _make_init_file(self, path):
|
||||
(self._path / path / '__init__.py').write_bytes('')
|
||||
(self._path / path / "__init__.py").write_bytes(b"")
|
||||
|
||||
def _do_render(self, file, template, context):
|
||||
with file.open('w') as f:
|
||||
with file.open("w") as f:
|
||||
result = template.render(
|
||||
base_class_name=self._base_class_name,
|
||||
base_class_module=self._base_class_module,
|
||||
**context)
|
||||
**context
|
||||
)
|
||||
if self._format_pep8:
|
||||
result = autopep8.fix_code(result, options={'aggressive': 1, 'verbose': 0, 'max_line_length': 120})
|
||||
import autopep8
|
||||
|
||||
result = autopep8.fix_code(
|
||||
result,
|
||||
options={"aggressive": 1, "verbose": 0, "max_line_length": 120},
|
||||
)
|
||||
f.write(result)
|
||||
|
||||
def _make_section(self, name, code, subcodes):
|
||||
self._do_render(
|
||||
file=(self._path / name).with_suffix('.py'),
|
||||
template=env.get_template('templates/section.jinja2'),
|
||||
context=dict(code=code, subcodes=list(subcodes.items()),))
|
||||
file=(self._path / name).with_suffix(".py"),
|
||||
template=env.get_template("templates/section.jinja2"),
|
||||
context=dict(code=code, subcodes=list(subcodes.items()),),
|
||||
)
|
||||
|
||||
def _make_init(self, sections):
|
||||
self._do_render(
|
||||
file=(self._path / '__init__.py'),
|
||||
template=env.get_template('templates/init.jinja2'),
|
||||
context=dict(sections=sections,))
|
||||
file=(self._path / "__init__.py"),
|
||||
template=env.get_template("templates/init.jinja2"),
|
||||
context=dict(sections=sections,),
|
||||
)
|
||||
|
||||
def _key_to_str(self, data):
|
||||
if isinstance(data, dict):
|
||||
@@ -66,11 +77,11 @@ class Generator(object):
|
||||
|
||||
def _calc_digest(self, data):
|
||||
data = json.dumps(self._key_to_str(data), sort_keys=True)
|
||||
return hashlib.md5(data.encode('utf8')).hexdigest()
|
||||
return hashlib.md5(data.encode("utf8")).hexdigest()
|
||||
|
||||
def make_errors(self, errors):
|
||||
digest = None
|
||||
digest_file = self._path / 'digest.md5'
|
||||
digest_file = self._path / "digest.md5"
|
||||
if self._use_md5:
|
||||
digest = self._calc_digest(errors)
|
||||
if digest_file.is_file():
|
||||
@@ -79,7 +90,7 @@ class Generator(object):
|
||||
|
||||
self._make_init(errors)
|
||||
for (code, section_name), subcodes in errors.items():
|
||||
self._make_section(section_name, code, subcodes)
|
||||
self._make_section(section_name, int(code), subcodes)
|
||||
|
||||
if self._use_md5:
|
||||
digest_file.write_text(digest)
|
||||
@@ -5,5 +5,5 @@ from {{ base_class_module }} import {{ base_class_name }}
|
||||
{% for subcode, (name, msg) in subcodes %}
|
||||
|
||||
|
||||
{{ error_class(name|cls_name, msg, code, subcode) -}}
|
||||
{{ error_class(name|cls_name, msg, code, subcode|int) -}}
|
||||
{% endfor %}
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from enum import Enum
|
||||
from typing import Union, Type, Iterable
|
||||
|
||||
@@ -9,11 +7,29 @@ from jsonmodels import fields
|
||||
from jsonmodels.fields import _LazyType, NotSet
|
||||
from jsonmodels.models import Base as ModelBase
|
||||
from jsonmodels.validators import Enum as EnumValidator
|
||||
from luqum.parser import parser, ParseError
|
||||
from mongoengine.base import BaseDocument
|
||||
from validators import email as email_validator, domain as domain_validator
|
||||
|
||||
from apierrors import errors
|
||||
from utilities.json import loads, dumps
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.utilities.json import loads, dumps
|
||||
|
||||
|
||||
class EmailField(fields.StringField):
|
||||
def validate(self, value):
|
||||
super().validate(value)
|
||||
if value is None:
|
||||
return
|
||||
if email_validator(value) is not True:
|
||||
raise errors.bad_request.InvalidEmailAddress()
|
||||
|
||||
|
||||
class DomainField(fields.StringField):
|
||||
def validate(self, value):
|
||||
super().validate(value)
|
||||
if value is None:
|
||||
return
|
||||
if domain_validator(value) is not True:
|
||||
raise errors.bad_request.InvalidDomainName()
|
||||
|
||||
|
||||
def make_default(field_cls, default_value):
|
||||
@@ -35,6 +51,8 @@ class ListField(fields.ListField):
|
||||
try:
|
||||
return super(ListField, self)._cast_value(value)
|
||||
except TypeError:
|
||||
if len(self.items_types) == 1 and issubclass(self.items_types[0], Enum):
|
||||
return self.items_types[0](value)
|
||||
return value
|
||||
|
||||
def validate_single_value(self, item):
|
||||
@@ -43,6 +61,12 @@ class ListField(fields.ListField):
|
||||
item.validate()
|
||||
|
||||
|
||||
# since there is no distinction between None and empty DictField
|
||||
# this value can be used as sentinel in order to distinguish
|
||||
# between not set and empty DictField
|
||||
DictFieldNotSet = {}
|
||||
|
||||
|
||||
class DictField(fields.BaseField):
|
||||
types = (dict,)
|
||||
|
||||
@@ -71,6 +95,31 @@ class DictField(fields.BaseField):
|
||||
for type_ in value_types
|
||||
)
|
||||
|
||||
def parse_value(self, values):
|
||||
"""Cast value to proper collection."""
|
||||
result = self.get_default_value()
|
||||
|
||||
if values is None:
|
||||
return result
|
||||
|
||||
if not self.value_types or not isinstance(values, dict):
|
||||
return values
|
||||
|
||||
return {key: self._cast_value(value) for key, value in values.items()}
|
||||
|
||||
def _cast_value(self, value):
|
||||
if isinstance(value, self.value_types):
|
||||
return value
|
||||
else:
|
||||
if len(self.value_types) != 1:
|
||||
tpl = 'Cannot decide which type to choose from "{types}".'
|
||||
raise jsonmodels.errors.ValidationError(
|
||||
tpl.format(
|
||||
types=', '.join([t.__name__ for t in self.value_types])
|
||||
)
|
||||
)
|
||||
return self.value_types[0](**value)
|
||||
|
||||
def validate(self, value):
|
||||
super(DictField, self).validate(value)
|
||||
|
||||
@@ -96,6 +145,15 @@ class DictField(fields.BaseField):
|
||||
)
|
||||
)
|
||||
|
||||
def _elem_to_struct(self, value):
|
||||
try:
|
||||
return value.to_struct()
|
||||
except AttributeError:
|
||||
return value
|
||||
|
||||
def to_struct(self, values):
|
||||
return {k: self._elem_to_struct(v) for k, v in values.items()}
|
||||
|
||||
|
||||
class IntField(fields.IntField):
|
||||
def parse_value(self, value):
|
||||
@@ -105,23 +163,6 @@ class IntField(fields.IntField):
|
||||
return value
|
||||
|
||||
|
||||
def validate_lucene_query(value):
|
||||
if value == "":
|
||||
return
|
||||
try:
|
||||
parser.parse(value)
|
||||
except ParseError as e:
|
||||
raise errors.bad_request.InvalidLuceneSyntax(error=e)
|
||||
|
||||
|
||||
class LuceneQueryField(fields.StringField):
|
||||
def validate(self, value):
|
||||
super(LuceneQueryField, self).validate(value)
|
||||
if value is None:
|
||||
return
|
||||
validate_lucene_query(value)
|
||||
|
||||
|
||||
class NullableEnumValidator(EnumValidator):
|
||||
"""Validator for enums that allows a None value."""
|
||||
|
||||
@@ -189,24 +230,6 @@ class ActualEnumField(fields.StringField):
|
||||
return super().to_struct(value.value)
|
||||
|
||||
|
||||
class EmailField(fields.StringField):
|
||||
def validate(self, value):
|
||||
super().validate(value)
|
||||
if value is None:
|
||||
return
|
||||
if email_validator(value) is not True:
|
||||
raise errors.bad_request.InvalidEmailAddress()
|
||||
|
||||
|
||||
class DomainField(fields.StringField):
|
||||
def validate(self, value):
|
||||
super().validate(value)
|
||||
if value is None:
|
||||
return
|
||||
if domain_validator(value) is not True:
|
||||
raise errors.bad_request.InvalidDomainName()
|
||||
|
||||
|
||||
class JsonSerializableMixin:
|
||||
def to_json(self: ModelBase):
|
||||
return dumps(self.to_struct())
|
||||
@@ -214,3 +237,67 @@ class JsonSerializableMixin:
|
||||
@classmethod
|
||||
def from_json(cls: Type[ModelBase], s):
|
||||
return cls(**loads(s))
|
||||
|
||||
|
||||
def callable_default(cls: Type[fields.BaseField]) -> Type[fields.BaseField]:
|
||||
class _Wrapped(cls):
|
||||
_callable_default = None
|
||||
|
||||
def get_default_value(self):
|
||||
if self._callable_default:
|
||||
return self._callable_default()
|
||||
return super(_Wrapped, self).get_default_value()
|
||||
|
||||
def __init__(self, *args, default=None, **kwargs):
|
||||
if default and callable(default):
|
||||
self._callable_default = default
|
||||
default = default()
|
||||
super(_Wrapped, self).__init__(*args, default=default, **kwargs)
|
||||
|
||||
return _Wrapped
|
||||
|
||||
|
||||
class MongoengineFieldsDict(DictField):
|
||||
"""
|
||||
DictField representing mongoengine field names/value mapping.
|
||||
Used to convert mongoengine-style field/subfield notation to user-presentable syntax, including handling update
|
||||
operators.
|
||||
"""
|
||||
|
||||
mongoengine_update_operators = (
|
||||
"inc",
|
||||
"dec",
|
||||
"push",
|
||||
"push_all",
|
||||
"pop",
|
||||
"pull",
|
||||
"pull_all",
|
||||
"add_to_set",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_mongo_value(value):
|
||||
if isinstance(value, BaseDocument):
|
||||
return value.to_mongo()
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _normalize_mongo_field_path(cls, path, value):
|
||||
parts = path.split("__")
|
||||
if len(parts) > 1:
|
||||
if parts[0] == "set":
|
||||
parts = parts[1:]
|
||||
elif parts[0] == "unset":
|
||||
parts = parts[1:]
|
||||
value = None
|
||||
elif parts[0] in cls.mongoengine_update_operators:
|
||||
return None, None
|
||||
return ".".join(parts), cls._normalize_mongo_value(value)
|
||||
|
||||
def parse_value(self, value):
|
||||
value = super(MongoengineFieldsDict, self).parse_value(value)
|
||||
return {
|
||||
k: v
|
||||
for k, v in (self._normalize_mongo_field_path(*p) for p in value.items())
|
||||
if k is not None
|
||||
}
|
||||
@@ -2,10 +2,10 @@ from jsonmodels.fields import IntField, StringField, BoolField, EmbeddedField, D
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Max, Enum
|
||||
|
||||
from apimodels import ListField, EnumField
|
||||
from config import config
|
||||
from database.model.auth import Role
|
||||
from database.utils import get_options
|
||||
from apiserver.apimodels import ListField, EnumField
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.auth import Role
|
||||
from apiserver.database.utils import get_options
|
||||
|
||||
|
||||
class GetTokenRequest(Base):
|
||||
28
apiserver/apimodels/base.py
Normal file
28
apiserver/apimodels/base.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from jsonmodels import models, fields
|
||||
from jsonmodels.validators import Length
|
||||
|
||||
from apiserver.apimodels import MongoengineFieldsDict, ListField
|
||||
|
||||
|
||||
class UpdateResponse(models.Base):
|
||||
updated = fields.IntField(required=True)
|
||||
fields = MongoengineFieldsDict()
|
||||
|
||||
|
||||
class PagedRequest(models.Base):
|
||||
page = fields.IntField()
|
||||
page_size = fields.IntField()
|
||||
|
||||
|
||||
class IdResponse(models.Base):
|
||||
id = fields.StringField(required=True)
|
||||
|
||||
|
||||
class MakePublicRequest(models.Base):
|
||||
ids = ListField(items_types=str, validators=[Length(minimum_value=1)])
|
||||
|
||||
|
||||
class MoveRequest(models.Base):
|
||||
ids = ListField([str], validators=Length(minimum_value=1))
|
||||
project = fields.StringField()
|
||||
project_name = fields.StringField()
|
||||
34
apiserver/apimodels/custom_validators/__init__.py
Normal file
34
apiserver/apimodels/custom_validators/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import validators
|
||||
from jsonmodels.errors import ValidationError
|
||||
|
||||
|
||||
class ForEach(object):
|
||||
def __init__(self, validator):
|
||||
self.validator = validator
|
||||
|
||||
def validate(self, values):
|
||||
for value in values:
|
||||
self.validator.validate(value)
|
||||
|
||||
def modify_schema(self, field_schema):
|
||||
return self.validator.modify_schema(field_schema)
|
||||
|
||||
|
||||
class Hostname(object):
|
||||
|
||||
def validate(self, value):
|
||||
if validators.domain(value) is not True:
|
||||
raise ValidationError(f"Value '{value}' is not a valid hostname")
|
||||
|
||||
def modify_schema(self, field_schema):
|
||||
field_schema["format"] = "hostname"
|
||||
|
||||
|
||||
class Email(object):
|
||||
|
||||
def validate(self, value):
|
||||
if validators.email(value) is not True:
|
||||
raise ValidationError(f"Value '{value}' is not a valid email address")
|
||||
|
||||
def modify_schema(self, field_schema):
|
||||
field_schema["format"] = "email"
|
||||
@@ -1,17 +1,20 @@
|
||||
from typing import Sequence
|
||||
from enum import auto
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length
|
||||
from jsonmodels.validators import Length, Min, Max
|
||||
|
||||
from apimodels import ListField, IntField, ActualEnumField
|
||||
from bll.event.event_metrics import EventType
|
||||
from bll.event.scalar_key import ScalarKeyEnum
|
||||
from apiserver.apimodels import ListField, IntField, ActualEnumField
|
||||
from apiserver.bll.event.event_common import EventType
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class HistogramRequestBase(Base):
|
||||
samples: int = IntField(default=10000)
|
||||
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
|
||||
|
||||
@@ -21,7 +24,15 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
items_types=str,
|
||||
validators=[
|
||||
Length(
|
||||
minimum_value=1,
|
||||
maximum_value=config.get(
|
||||
"services.tasks.multi_task_histogram_limit", 10
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -40,12 +51,35 @@ class DebugImagesRequest(Base):
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class TaskMetricVariant(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
variant: str = StringField(required=True)
|
||||
|
||||
|
||||
class GetDebugImageSampleRequest(TaskMetricVariant):
|
||||
iteration: Optional[int] = IntField()
|
||||
scroll_id: Optional[str] = StringField()
|
||||
refresh: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class NextDebugImageSampleRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
|
||||
|
||||
class LogOrderEnum(StringEnum):
|
||||
asc = auto()
|
||||
desc = auto()
|
||||
|
||||
|
||||
class LogEventsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
batch_size: int = IntField(default=500)
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
from_timestamp: Optional[int] = IntField()
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum)
|
||||
|
||||
|
||||
class IterationEvents(Base):
|
||||
33
apiserver/apimodels/login.py
Normal file
33
apiserver/apimodels/login.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from jsonmodels.fields import StringField, BoolField, EmbeddedField, ListField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import DictField, callable_default
|
||||
|
||||
|
||||
class GetSupportedModesRequest(Base):
|
||||
state = StringField(help_text="ASCII base64 encoded application state")
|
||||
callback_url_prefix = StringField()
|
||||
|
||||
|
||||
class BasicGuestMode(Base):
|
||||
enabled = BoolField(default=False)
|
||||
name = StringField()
|
||||
username = StringField()
|
||||
password = StringField()
|
||||
|
||||
|
||||
class BasicMode(Base):
|
||||
enabled = BoolField(default=False)
|
||||
guest = callable_default(EmbeddedField)(BasicGuestMode, default=BasicGuestMode)
|
||||
|
||||
|
||||
class ServerErrors(Base):
|
||||
missed_es_upgrade = BoolField(default=False)
|
||||
es_connection_error = BoolField(default=False)
|
||||
|
||||
|
||||
class GetSupportedModesResponse(Base):
|
||||
basic = EmbeddedField(BasicMode)
|
||||
server_errors = EmbeddedField(ServerErrors)
|
||||
sso = DictField([str, type(None)])
|
||||
sso_providers = ListField([dict])
|
||||
@@ -1,9 +1,13 @@
|
||||
from jsonmodels import models, fields
|
||||
from six import string_types
|
||||
|
||||
from apimodels import ListField, DictField
|
||||
from apimodels.base import UpdateResponse
|
||||
from apimodels.tasks import PublishResponse as TaskPublishResponse
|
||||
from apiserver.apimodels import ListField, DictField
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.apimodels.tasks import PublishResponse as TaskPublishResponse
|
||||
|
||||
|
||||
class GetFrameworksRequest(models.Base):
|
||||
projects = fields.ListField(items_types=[str])
|
||||
|
||||
|
||||
class CreateModelRequest(models.Base):
|
||||
@@ -1,7 +1,8 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apimodels import ListField
|
||||
from apimodels.organization import TagsRequest
|
||||
from apiserver.apimodels import ListField, ActualEnumField
|
||||
from apiserver.apimodels.organization import TagsRequest
|
||||
from apiserver.database.model import EntityVisibility
|
||||
|
||||
|
||||
class ProjectReq(models.Base):
|
||||
@@ -13,11 +14,11 @@ class GetHyperParamReq(ProjectReq):
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
|
||||
class GetHyperParamResp(models.Base):
|
||||
parameters = fields.ListField(str)
|
||||
remaining = fields.IntField()
|
||||
total = fields.IntField()
|
||||
|
||||
|
||||
class ProjectTagsRequest(TagsRequest):
|
||||
projects = ListField(str)
|
||||
|
||||
|
||||
class ProjectTaskParentsRequest(ProjectReq):
|
||||
projects = ListField(str)
|
||||
tasks_state = ActualEnumField(EntityVisibility)
|
||||
|
||||
@@ -2,7 +2,7 @@ from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apimodels import ListField
|
||||
from apiserver.apimodels import ListField
|
||||
|
||||
|
||||
class GetDefaultResp(Base):
|
||||
221
apiserver/apimodels/tasks.py
Normal file
221
apiserver/apimodels/tasks.py
Normal file
@@ -0,0 +1,221 @@
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from jsonmodels import models
|
||||
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
|
||||
from jsonmodels.validators import Enum, Length
|
||||
|
||||
from apiserver.apimodels import DictField, ListField
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.database.model.task.task import (
|
||||
TaskType,
|
||||
ArtifactModes,
|
||||
DEFAULT_ARTIFACT_MODE,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
|
||||
|
||||
class ArtifactTypeData(models.Base):
|
||||
preview = StringField()
|
||||
content_type = StringField()
|
||||
data_hash = StringField()
|
||||
|
||||
|
||||
class Artifact(models.Base):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
mode = StringField(
|
||||
validators=Enum(*get_options(ArtifactModes)), default=DEFAULT_ARTIFACT_MODE
|
||||
)
|
||||
uri = StringField()
|
||||
hash = StringField()
|
||||
content_size = IntField()
|
||||
timestamp = IntField()
|
||||
type_data = EmbeddedField(ArtifactTypeData)
|
||||
display_data = ListField([list])
|
||||
|
||||
|
||||
class StartedResponse(UpdateResponse):
|
||||
started = IntField()
|
||||
|
||||
|
||||
class EnqueueResponse(UpdateResponse):
|
||||
queued = IntField()
|
||||
|
||||
|
||||
class DequeueResponse(UpdateResponse):
|
||||
dequeued = IntField()
|
||||
|
||||
|
||||
class ResetResponse(UpdateResponse):
|
||||
deleted_indices = ListField(items_types=six.string_types)
|
||||
dequeued = DictField()
|
||||
frames = DictField()
|
||||
events = DictField()
|
||||
model_deleted = IntField()
|
||||
|
||||
|
||||
class TaskRequest(models.Base):
|
||||
task = StringField(required=True)
|
||||
|
||||
|
||||
class UpdateRequest(TaskRequest):
|
||||
status_reason = StringField(default="")
|
||||
status_message = StringField(default="")
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class EnqueueRequest(UpdateRequest):
|
||||
queue = StringField()
|
||||
|
||||
|
||||
class DeleteRequest(UpdateRequest):
|
||||
move_to_trash = BoolField(default=True)
|
||||
|
||||
|
||||
class SetRequirementsRequest(TaskRequest):
|
||||
requirements = DictField(required=True)
|
||||
|
||||
|
||||
class PublishRequest(UpdateRequest):
|
||||
publish_model = BoolField(default=True)
|
||||
|
||||
|
||||
class PublishResponse(UpdateResponse):
|
||||
pass
|
||||
|
||||
|
||||
class TaskData(models.Base):
|
||||
"""
|
||||
This is a partial description of task can be updated incrementally
|
||||
"""
|
||||
|
||||
|
||||
class CreateRequest(TaskData):
|
||||
name = StringField(required=True)
|
||||
type = StringField(required=True, validators=Enum(*get_options(TaskType)))
|
||||
|
||||
|
||||
class PingRequest(TaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class GetTypesRequest(models.Base):
|
||||
projects = ListField(items_types=[str])
|
||||
|
||||
|
||||
class CloneRequest(TaskRequest):
|
||||
new_task_name = StringField()
|
||||
new_task_comment = StringField()
|
||||
new_task_tags = ListField([str])
|
||||
new_task_system_tags = ListField([str])
|
||||
new_task_parent = StringField()
|
||||
new_task_project = StringField()
|
||||
new_task_hyperparams = DictField()
|
||||
new_task_configuration = DictField()
|
||||
execution_overrides = DictField()
|
||||
validate_references = BoolField(default=False)
|
||||
new_project_name = StringField()
|
||||
|
||||
|
||||
class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||
artifacts = ListField([Artifact], validators=Length(minimum_value=1))
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class ArtifactId(models.Base):
|
||||
key = StringField(required=True)
|
||||
mode = StringField(
|
||||
validators=Enum(*get_options(ArtifactModes)), default=DEFAULT_ARTIFACT_MODE
|
||||
)
|
||||
|
||||
|
||||
class DeleteArtifactsRequest(TaskRequest):
|
||||
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class ResetRequest(UpdateRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
|
||||
|
||||
class MultiTaskRequest(models.Base):
|
||||
tasks = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class GetHyperParamsRequest(MultiTaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class HyperParamItem(models.Base):
|
||||
section = StringField(required=True, validators=Length(minimum_value=1))
|
||||
name = StringField(required=True, validators=Length(minimum_value=1))
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class ReplaceHyperparams(object):
|
||||
none = "none"
|
||||
section = "section"
|
||||
all = "all"
|
||||
|
||||
|
||||
class EditHyperParamsRequest(TaskRequest):
|
||||
hyperparams: Sequence[HyperParamItem] = ListField(
|
||||
[HyperParamItem], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_hyperparams = StringField(
|
||||
validators=Enum(*get_options(ReplaceHyperparams)),
|
||||
default=ReplaceHyperparams.none,
|
||||
)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class HyperParamKey(models.Base):
|
||||
section = StringField(required=True, validators=Length(minimum_value=1))
|
||||
name = StringField(nullable=True)
|
||||
|
||||
|
||||
class DeleteHyperParamsRequest(TaskRequest):
|
||||
hyperparams: Sequence[HyperParamKey] = ListField(
|
||||
[HyperParamKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class GetConfigurationsRequest(MultiTaskRequest):
|
||||
names = ListField([str])
|
||||
|
||||
|
||||
class GetConfigurationNamesRequest(MultiTaskRequest):
|
||||
pass
|
||||
|
||||
|
||||
class Configuration(models.Base):
|
||||
name = StringField(required=True, validators=Length(minimum_value=1))
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class EditConfigurationRequest(TaskRequest):
|
||||
configuration: Sequence[Configuration] = ListField(
|
||||
[Configuration], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_configuration = BoolField(default=False)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteConfigurationRequest(TaskRequest):
|
||||
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class ArchiveRequest(MultiTaskRequest):
|
||||
status_reason = StringField(default="")
|
||||
status_message = StringField(default="")
|
||||
|
||||
|
||||
class ArchiveResponse(models.Base):
|
||||
archived = IntField()
|
||||
@@ -1,7 +1,7 @@
|
||||
from jsonmodels.fields import StringField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apimodels import DictField
|
||||
from apiserver.apimodels import DictField
|
||||
|
||||
|
||||
class CreateRequest(Base):
|
||||
@@ -12,13 +12,14 @@ from jsonmodels.fields import (
|
||||
)
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apimodels import make_default, ListField, EnumField, JsonSerializableMixin
|
||||
from apiserver.apimodels import make_default, ListField, EnumField, JsonSerializableMixin
|
||||
|
||||
DEFAULT_TIMEOUT = 10 * 60
|
||||
|
||||
|
||||
class WorkerRequest(Base):
|
||||
worker = StringField(required=True)
|
||||
tags = ListField(str)
|
||||
|
||||
|
||||
class RegisterRequest(WorkerRequest):
|
||||
@@ -67,12 +68,14 @@ class WorkerEntry(Base, JsonSerializableMixin):
|
||||
company = EmbeddedField(IdNameEntry)
|
||||
ip = StringField()
|
||||
task = EmbeddedField(IdNameEntry)
|
||||
project = EmbeddedField(IdNameEntry)
|
||||
queue = StringField() # queue from which current task was taken
|
||||
queues = ListField(str) # list of queues this worker listens to
|
||||
register_time = DateTimeField(required=True)
|
||||
register_timeout = IntField(required=True)
|
||||
last_activity_time = DateTimeField(required=True)
|
||||
last_report_time = DateTimeField()
|
||||
tags = ListField(str)
|
||||
|
||||
|
||||
class CurrentTaskEntry(IdNameEntry):
|
||||
@@ -1,17 +1,17 @@
|
||||
from datetime import datetime
|
||||
|
||||
import database
|
||||
from apierrors import errors
|
||||
from apimodels.auth import GetTokenResponse, CreateUserRequest, Credentials as CredModel
|
||||
from apimodels.users import CreateRequest as Users_CreateRequest
|
||||
from bll.user import UserBLL
|
||||
from config import config
|
||||
from config.info import get_version, get_build_number
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.auth import User, Role, Credentials
|
||||
from database.model.company import Company
|
||||
from service_repo import APICall, ServiceRepo
|
||||
from service_repo.auth import Identity, Token, get_client_id, get_secret_key
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.auth import GetTokenResponse, CreateUserRequest, Credentials as CredModel
|
||||
from apiserver.apimodels.users import CreateRequest as Users_CreateRequest
|
||||
from apiserver.bll.user import UserBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.config.info import get_version, get_build_number
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.auth import User, Role, Credentials
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.service_repo import APICall, ServiceRepo
|
||||
from apiserver.service_repo.auth import Identity, Token, get_client_id, get_secret_key
|
||||
|
||||
log = config.logger("AuthBLL")
|
||||
|
||||
@@ -13,15 +13,19 @@ from jsonmodels.fields import StringField, ListField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apierrors import errors
|
||||
from apimodels import JsonSerializableMixin
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
from bll.redis_cache_manager import RedisCacheManager
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.metrics import MetricEventStats
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.metrics import MetricEventStats
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
class VariantScrollState(Base):
|
||||
@@ -46,6 +50,7 @@ class MetricScrollState(Base):
|
||||
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
|
||||
warning: str = StringField()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
@@ -55,24 +60,14 @@ class DebugImagesResult(object):
|
||||
|
||||
|
||||
class DebugImagesIterator:
|
||||
EVENT_TYPE = "training_debug_image"
|
||||
|
||||
@property
|
||||
def state_expiration_sec(self):
|
||||
return config.get(
|
||||
f"services.events.events_retrieval.state_expiration_sec", 3600
|
||||
)
|
||||
|
||||
@property
|
||||
def _max_workers(self):
|
||||
return config.get("services.events.max_metrics_concurrency", 4)
|
||||
EVENT_TYPE = EventType.metrics_image
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugImageEventsScrollState,
|
||||
redis=redis,
|
||||
expiration_interval=self.state_expiration_sec,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_task_events(
|
||||
@@ -84,13 +79,12 @@ class DebugImagesIterator:
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> DebugImagesResult:
|
||||
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
|
||||
if not self.es.indices.exists(es_index):
|
||||
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
|
||||
return DebugImagesResult()
|
||||
|
||||
def init_state(state_: DebugImageEventsScrollState):
|
||||
unique_metrics = set(metrics)
|
||||
state_.metrics = self._init_metric_states(es_index, list(unique_metrics))
|
||||
state_.metrics = self._init_metric_states(company_id, list(unique_metrics))
|
||||
|
||||
def validate_state(state_: DebugImageEventsScrollState):
|
||||
"""
|
||||
@@ -105,7 +99,7 @@ class DebugImagesIterator:
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reinit_outdated_metric_states(company_id, es_index, state_)
|
||||
self._reinit_outdated_metric_states(company_id, state_)
|
||||
for metric_state in state_.metrics:
|
||||
metric_state.reset()
|
||||
|
||||
@@ -113,12 +107,12 @@ class DebugImagesIterator:
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
) as state:
|
||||
res = DebugImagesResult(next_scroll_id=state.id)
|
||||
with ThreadPoolExecutor(self._max_workers) as pool:
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
res.metric_events = list(
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_task_metric_events,
|
||||
es_index=es_index,
|
||||
company_id=company_id,
|
||||
iter_count=iter_count,
|
||||
navigate_earlier=navigate_earlier,
|
||||
),
|
||||
@@ -129,7 +123,7 @@ class DebugImagesIterator:
|
||||
return res
|
||||
|
||||
def _reinit_outdated_metric_states(
|
||||
self, company_id, es_index, state: DebugImageEventsScrollState
|
||||
self, company_id, state: DebugImageEventsScrollState
|
||||
):
|
||||
"""
|
||||
Determines the metrics for which new debug image events were added
|
||||
@@ -149,10 +143,10 @@ class DebugImagesIterator:
|
||||
return [
|
||||
(
|
||||
(task.id, stats.metric),
|
||||
stats.event_stats_by_type[self.EVENT_TYPE].last_update,
|
||||
stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
|
||||
)
|
||||
for stats in metric_stats.values()
|
||||
if self.EVENT_TYPE in stats.event_stats_by_type
|
||||
if self.EVENT_TYPE.value in stats.event_stats_by_type
|
||||
]
|
||||
|
||||
update_times = dict(
|
||||
@@ -170,14 +164,14 @@ class DebugImagesIterator:
|
||||
*(metric for metric in state.metrics if metric not in outdated_metrics),
|
||||
*(
|
||||
self._init_metric_states(
|
||||
es_index,
|
||||
company_id,
|
||||
[(metric.task, metric.name) for metric in outdated_metrics],
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
def _init_metric_states(
|
||||
self, es_index, metrics: Sequence[Tuple[str, str]]
|
||||
self, company_id: str, metrics: Sequence[Tuple[str, str]]
|
||||
) -> Sequence[MetricScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
@@ -186,18 +180,20 @@ class DebugImagesIterator:
|
||||
for (task, metric) in metrics:
|
||||
tasks[task].append(metric)
|
||||
|
||||
with ThreadPoolExecutor(self._max_workers) as pool:
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
pool.map(
|
||||
partial(self._init_metric_states_for_task, es_index=es_index),
|
||||
partial(
|
||||
self._init_metric_states_for_task, company_id=company_id
|
||||
),
|
||||
tasks.items(),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, Sequence[str]], es_index
|
||||
self, task_metrics: Tuple[str, Sequence[str]], company_id: str
|
||||
) -> Sequence[MetricScrollState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
@@ -208,21 +204,27 @@ class DebugImagesIterator:
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}]
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"terms": {"metric": metrics}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_event_timestamp": {"max": {"field": "timestamp"}},
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"urls": {
|
||||
@@ -251,7 +253,12 @@ class DebugImagesIterator:
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task)
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@@ -283,7 +290,7 @@ class DebugImagesIterator:
|
||||
def _get_task_metric_events(
|
||||
self,
|
||||
metric: MetricScrollState,
|
||||
es_index: str,
|
||||
company_id: str,
|
||||
iter_count: int,
|
||||
navigate_earlier: bool,
|
||||
) -> Tuple:
|
||||
@@ -298,6 +305,7 @@ class DebugImagesIterator:
|
||||
must_conditions = [
|
||||
{"term": {"task": metric.task}},
|
||||
{"term": {"metric": metric.name}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
must_not_conditions = []
|
||||
|
||||
@@ -368,13 +376,14 @@ class DebugImagesIterator:
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iter_count,
|
||||
"order": {"_term": "desc" if navigate_earlier else "asc"},
|
||||
"order": {"_key": "desc" if navigate_earlier else "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
@@ -387,7 +396,12 @@ class DebugImagesIterator:
|
||||
},
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=metric.task)
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return metric.task, metric.name, []
|
||||
|
||||
375
apiserver/bll/event/debug_sample_history.py
Normal file
375
apiserver/bll/event/debug_sample_history.py
Normal file
@@ -0,0 +1,375 @@
|
||||
import operator
|
||||
from typing import Sequence, Tuple, Optional
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
EventType,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
|
||||
class VariantState(Base):
|
||||
name: str = StringField(required=True)
|
||||
min_iteration: int = IntField()
|
||||
max_iteration: int = IntField()
|
||||
|
||||
|
||||
class DebugSampleHistoryState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
iteration: int = IntField()
|
||||
variant: str = StringField()
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
reached_first: bool = BoolField()
|
||||
reached_last: bool = BoolField()
|
||||
variant_states: Sequence[VariantState] = ListField([VariantState])
|
||||
warning: str = StringField()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class DebugSampleHistoryResult(object):
|
||||
scroll_id: str = None
|
||||
event: dict = None
|
||||
min_iteration: int = None
|
||||
max_iteration: int = None
|
||||
|
||||
|
||||
class DebugSampleHistory:
|
||||
EVENT_TYPE = EventType.metrics_image
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugSampleHistoryState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_next_debug_image(
|
||||
self, company_id: str, task: str, state_id: str, navigate_earlier: bool
|
||||
) -> DebugSampleHistoryResult:
|
||||
"""
|
||||
Get the debug image for next/prev variant on the current iteration
|
||||
If does not exist then try getting image for the first/last variant from next/prev iteration
|
||||
"""
|
||||
res = DebugSampleHistoryResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
|
||||
return res
|
||||
|
||||
image = self._get_next_for_current_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
) or self._get_next_for_another_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
)
|
||||
if not image:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(image=image, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
def _fill_res_and_update_state(
|
||||
self, image: dict, res: DebugSampleHistoryResult, state: DebugSampleHistoryState
|
||||
):
|
||||
state.variant = image["variant"]
|
||||
state.iteration = image["iter"]
|
||||
res.event = image
|
||||
var_state = first(s for s in state.variant_states if s.name == state.variant)
|
||||
if var_state:
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
def _get_next_for_current_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the image for next (if navigated earlier is False) or previous variant sorted by name for the same iteration
|
||||
Only variants for which the iteration falls into their valid range are considered
|
||||
Return None if no such variant or image is found
|
||||
"""
|
||||
cmp = operator.lt if navigate_earlier else operator.gt
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if cmp(var_state.name, state.variant)
|
||||
and var_state.min_iteration <= state.iteration
|
||||
]
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"metric": state.metric}},
|
||||
{"terms": {"variant": [v.name for v in variants]}},
|
||||
{"term": {"iter": state.iteration}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": {"variant": "desc" if navigate_earlier else "asc"},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_next_for_current_iteration"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def _get_next_for_another_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the image for the first variant for the next iteration (if navigate_earlier is set to False)
|
||||
or from the last variant for the previous iteration (otherwise)
|
||||
The variants for which the image falls in invalid range are discarded
|
||||
If no suitable image is found then None is returned
|
||||
"""
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"metric": state.metric}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
|
||||
if navigate_earlier:
|
||||
range_operator = "lt"
|
||||
order = "desc"
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.min_iteration < state.iteration
|
||||
]
|
||||
else:
|
||||
range_operator = "gt"
|
||||
order = "asc"
|
||||
variants = state.variant_states
|
||||
|
||||
if not variants:
|
||||
return
|
||||
|
||||
variants_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"variant": v.name}},
|
||||
{"range": {"iter": {"gte": v.min_iteration}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for v in variants
|
||||
]
|
||||
must_conditions.append({"bool": {"should": variants_conditions}})
|
||||
must_conditions.append({"range": {"iter": {range_operator: state.iteration}}},)
|
||||
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"iter": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_next_for_another_iteration"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def get_debug_image_for_variant(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
variant: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> DebugSampleHistoryResult:
|
||||
"""
|
||||
Get the debug image for the requested iteration or the latest before it
|
||||
If the iteration is not passed then get the latest event
|
||||
"""
|
||||
res = DebugSampleHistoryResult()
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
|
||||
return res
|
||||
|
||||
def init_state(state_: DebugSampleHistoryState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
def validate_state(state_: DebugSampleHistoryState):
|
||||
if state_.task != task or state_.metric != metric:
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
state: DebugSampleHistoryState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res.scroll_id = state.id
|
||||
|
||||
var_state = first(s for s in state.variant_states if s.name == variant)
|
||||
if not var_state:
|
||||
return res
|
||||
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if iteration is not None:
|
||||
must_conditions.append(
|
||||
{
|
||||
"range": {
|
||||
"iter": {"lte": iteration, "gte": var_state.min_iteration}
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
must_conditions.append(
|
||||
{"range": {"iter": {"gte": var_state.min_iteration}}}
|
||||
)
|
||||
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": {"iter": "desc"},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_debug_image_for_variant"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(
|
||||
image=hits[0]["_source"], res=res, state=state
|
||||
)
|
||||
return res
|
||||
|
||||
def _reset_variant_states(self, company_id: str, state: DebugSampleHistoryState):
|
||||
variant_iterations = self._get_variant_iterations(
|
||||
company_id=company_id, task=state.task, metric=state.metric
|
||||
)
|
||||
state.variant_states = [
|
||||
VariantState(name=var_name, min_iteration=min_iter, max_iteration=max_iter)
|
||||
for var_name, min_iter, max_iter in variant_iterations
|
||||
]
|
||||
|
||||
def _get_variant_iterations(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
variants: Optional[Sequence[str]] = None,
|
||||
) -> Sequence[Tuple[str, int, int]]:
|
||||
"""
|
||||
Return valid min and max iterations that the task reported images
|
||||
The min iteration is the lowest iteration that contains non-recycled image url
|
||||
"""
|
||||
must = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if variants:
|
||||
must.append({"terms": {"variant": variants}})
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must}},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
# all variants that sent debug images
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"urls": {
|
||||
# group by urls and choose the minimal iteration
|
||||
# from all the maximal iterations per url
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "asc"},
|
||||
"size": 1,
|
||||
},
|
||||
"aggs": {
|
||||
# find max iteration for each url
|
||||
"max_iter": {"max": {"field": "iter"}}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_debug_image_iterations"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
|
||||
)
|
||||
|
||||
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
|
||||
variant = variant_bucket["key"]
|
||||
urls = nested_get(variant_bucket, ("urls", "buckets"))
|
||||
min_iter = int(urls[0]["max_iter"]["value"])
|
||||
max_iter = int(variant_bucket["last_iter"]["value"])
|
||||
return variant, min_iter, max_iter
|
||||
|
||||
return [
|
||||
get_variant_data(variant_bucket)
|
||||
for variant_bucket in nested_get(
|
||||
es_res, ("aggregations", "variants", "buckets")
|
||||
)
|
||||
]
|
||||
@@ -1,36 +1,59 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import zlib
|
||||
from collections import defaultdict
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple
|
||||
from typing import Sequence, Set, Tuple, Optional, Dict
|
||||
|
||||
import six
|
||||
from elasticsearch import helpers
|
||||
from mongoengine import Q
|
||||
from nested_dict import nested_dict
|
||||
|
||||
import database.utils as dbutils
|
||||
import es_factory
|
||||
from apierrors import errors
|
||||
from bll.event.debug_images_iterator import DebugImagesIterator
|
||||
from bll.event.event_metrics import EventMetrics, EventType
|
||||
from bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
|
||||
from bll.task import TaskBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.task import Task, TaskStatus
|
||||
from redis_manager import redman
|
||||
from timing_context import TimingContext
|
||||
from utilities.dicts import flatten_nested_items
|
||||
from apiserver.bll.event.debug_sample_history import DebugSampleHistory
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventType,
|
||||
EventSettings,
|
||||
get_index_name,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
delete_company_events,
|
||||
)
|
||||
from apiserver.bll.util import parallel_chunked_decorator
|
||||
from apiserver.database import utils as dbutils
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event.debug_images_iterator import DebugImagesIterator
|
||||
from apiserver.bll.event.event_metrics import EventMetrics
|
||||
from apiserver.bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import flatten_nested_items
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
from apiserver.utilities.json import loads
|
||||
|
||||
EVENT_TYPES = set(map(attrgetter("value"), EventType))
|
||||
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
|
||||
|
||||
|
||||
class PlotFields:
|
||||
valid_plot = "valid_plot"
|
||||
plot_len = "plot_len"
|
||||
plot_str = "plot_str"
|
||||
plot_data = "plot_data"
|
||||
|
||||
|
||||
class EventBLL(object):
|
||||
id_fields = ("task", "iter", "metric", "variant", "key")
|
||||
empty_scroll = "FFFF"
|
||||
|
||||
def __init__(self, events_es=None, redis=None):
|
||||
self.es = events_es or es_factory.connect("events")
|
||||
@@ -40,7 +63,8 @@ class EventBLL(object):
|
||||
)
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
|
||||
self.log_events_iterator = LogEventsIterator(es=self.es, redis=self.redis)
|
||||
self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
|
||||
self.log_events_iterator = LogEventsIterator(es=self.es)
|
||||
|
||||
@property
|
||||
def metrics(self) -> EventMetrics:
|
||||
@@ -79,6 +103,7 @@ class EventBLL(object):
|
||||
},
|
||||
allow_locked_tasks=allow_locked_tasks,
|
||||
)
|
||||
|
||||
for event in events:
|
||||
# remove spaces from event type
|
||||
event_type = event.get("type")
|
||||
@@ -130,21 +155,19 @@ class EventBLL(object):
|
||||
event["metric"] = event.get("metric") or ""
|
||||
event["variant"] = event.get("variant") or ""
|
||||
|
||||
index_name = EventMetrics.get_index_name(company_id, event_type)
|
||||
index_name = get_index_name(company_id, event_type)
|
||||
es_action = {
|
||||
"_op_type": "index", # overwrite if exists with same ID
|
||||
"_index": index_name,
|
||||
"_type": "event",
|
||||
"_source": event,
|
||||
}
|
||||
|
||||
# for "log" events, don't assing custom _id - whatever is sent, is written (not overwritten)
|
||||
if event_type != "log":
|
||||
if event_type != EventType.task_log.value:
|
||||
es_action["_id"] = self._get_event_id(event)
|
||||
else:
|
||||
es_action["_id"] = dbutils.id()
|
||||
|
||||
es_action["_routing"] = task_id
|
||||
task_ids.add(task_id)
|
||||
if (
|
||||
iter is not None
|
||||
@@ -162,6 +185,21 @@ class EventBLL(object):
|
||||
|
||||
actions.append(es_action)
|
||||
|
||||
action: Dict[dict]
|
||||
plot_actions = [
|
||||
action["_source"]
|
||||
for action in actions
|
||||
if action["_source"]["type"] == EventType.metrics_plot.value
|
||||
]
|
||||
if plot_actions:
|
||||
self.validate_and_compress_plots(
|
||||
plot_actions,
|
||||
validate_json=config.get("services.events.validate_plot_str", False),
|
||||
compression_threshold=config.get(
|
||||
"services.events.plot_compression_threshold", 100_000
|
||||
),
|
||||
)
|
||||
|
||||
added = 0
|
||||
if actions:
|
||||
chunk_size = 500
|
||||
@@ -178,7 +216,7 @@ class EventBLL(object):
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += chunk_size
|
||||
added += 1
|
||||
else:
|
||||
errors_per_type["Error when indexing events batch"] += 1
|
||||
|
||||
@@ -205,15 +243,58 @@ class EventBLL(object):
|
||||
remaining_tasks, company_id, last_update=now
|
||||
)
|
||||
|
||||
# Compensate for always adding chunk_size on success (last chunk is probably smaller)
|
||||
added = min(added, len(actions))
|
||||
|
||||
if not added:
|
||||
raise errors.bad_request.EventsNotAdded(**errors_per_type)
|
||||
|
||||
errors_count = sum(errors_per_type.values())
|
||||
return added, errors_count, errors_per_type
|
||||
|
||||
@parallel_chunked_decorator(chunk_size=10)
|
||||
def validate_and_compress_plots(
|
||||
self,
|
||||
plot_events: Sequence[dict],
|
||||
validate_json: bool,
|
||||
compression_threshold: int,
|
||||
):
|
||||
for event in plot_events:
|
||||
validate = validate_json and not event.pop("skip_validation", False)
|
||||
plot_str = event.get(PlotFields.plot_str)
|
||||
if not plot_str:
|
||||
event[PlotFields.plot_len] = 0
|
||||
if validate:
|
||||
event[PlotFields.valid_plot] = False
|
||||
continue
|
||||
|
||||
plot_len = len(plot_str)
|
||||
event[PlotFields.plot_len] = plot_len
|
||||
if validate:
|
||||
event[PlotFields.valid_plot] = self._is_valid_json(plot_str)
|
||||
if compression_threshold and plot_len >= compression_threshold:
|
||||
event[PlotFields.plot_data] = base64.encodebytes(
|
||||
zlib.compress(plot_str.encode(), level=1)
|
||||
).decode("ascii")
|
||||
event.pop(PlotFields.plot_str, None)
|
||||
|
||||
@parallel_chunked_decorator(chunk_size=10)
|
||||
def uncompress_plots(self, plot_events: Sequence[dict]):
|
||||
for event in plot_events:
|
||||
plot_data = event.pop(PlotFields.plot_data, None)
|
||||
if plot_data and event.get(PlotFields.plot_str) is None:
|
||||
event[PlotFields.plot_str] = zlib.decompress(
|
||||
base64.b64decode(plot_data)
|
||||
).decode()
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_json(text: str) -> bool:
|
||||
"""Check str for valid json"""
|
||||
if not text:
|
||||
return False
|
||||
try:
|
||||
loads(text)
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _update_last_scalar_events_for_task(self, last_events, event):
|
||||
"""
|
||||
Update last_events structure with the provided event details if this event is more
|
||||
@@ -315,24 +396,22 @@ class EventBLL(object):
|
||||
|
||||
def scroll_task_events(
|
||||
self,
|
||||
company_id,
|
||||
task_id,
|
||||
order,
|
||||
event_type=None,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
order: str,
|
||||
event_type: EventType,
|
||||
batch_size=10000,
|
||||
scroll_id=None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "task_log_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
size = min(batch_size, 10000)
|
||||
if event_type is None:
|
||||
event_type = "*"
|
||||
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type)
|
||||
|
||||
if not self.es.indices.exists(es_index):
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return [], None, 0
|
||||
|
||||
es_req = {
|
||||
@@ -342,20 +421,25 @@ class EventBLL(object):
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, scroll="1h", routing=task_id
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
scroll="1h",
|
||||
)
|
||||
|
||||
events = [hit["_source"] for hit in es_res["hits"]["hits"]]
|
||||
next_scroll_id = es_res["_scroll_id"]
|
||||
total_events = es_res["hits"]["total"]
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return events, next_scroll_id, total_events
|
||||
|
||||
def get_last_iterations_per_event_metric_variant(
|
||||
self, es_index: str, task_id: str, num_last_iterations: int, event_type: str
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
num_last_iterations: int,
|
||||
event_type: EventType,
|
||||
):
|
||||
if not self.es.indices.exists(es_index):
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return []
|
||||
|
||||
es_req: dict = {
|
||||
@@ -364,20 +448,22 @@ class EventBLL(object):
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": num_last_iterations,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -387,13 +473,14 @@ class EventBLL(object):
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
}
|
||||
if event_type:
|
||||
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "task_last_iter_metric_variant"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@@ -413,25 +500,41 @@ class EventBLL(object):
|
||||
size: int = 500,
|
||||
scroll_id: str = None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
event_type = "plot"
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type)
|
||||
if not self.es.indices.exists(es_index):
|
||||
event_type = EventType.metrics_plot
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return TaskEventsResult()
|
||||
|
||||
query = {"bool": defaultdict(list)}
|
||||
plot_valid_condition = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{"term": {PlotFields.valid_plot: True}},
|
||||
{
|
||||
"bool": {
|
||||
"must_not": {"exists": {"field": PlotFields.valid_plot}}
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
must = [plot_valid_condition]
|
||||
|
||||
if last_iterations_per_plot is None:
|
||||
must = query["bool"]["must"]
|
||||
must.append({"terms": {"task": tasks}})
|
||||
else:
|
||||
should = query["bool"]["should"]
|
||||
should = []
|
||||
for i, task_id in enumerate(tasks):
|
||||
last_iters = self.get_last_iterations_per_event_metric_variant(
|
||||
es_index, task_id, last_iterations_per_plot, event_type
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
num_last_iterations=last_iterations_per_plot,
|
||||
event_type=event_type,
|
||||
)
|
||||
if not last_iters:
|
||||
continue
|
||||
@@ -451,37 +554,52 @@ class EventBLL(object):
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
|
||||
|
||||
routing = ",".join(tasks)
|
||||
es_req = {
|
||||
"sort": sort,
|
||||
"size": min(size, 10000),
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_plots"):
|
||||
es_res = self.es.search(
|
||||
index=es_index,
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
routing=routing,
|
||||
scroll="1h",
|
||||
)
|
||||
|
||||
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
|
||||
# scroll id may be missing when queering a totally empty DB
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
total_events = es_res["hits"]["total"]
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
self.uncompress_plots(events)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
|
||||
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
|
||||
"""
|
||||
Return events and next scroll id from the scrolled query
|
||||
Release the scroll once it is exhausted
|
||||
"""
|
||||
total_events = safe_get(es_res, "hits/total/value", default=0)
|
||||
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
if next_scroll_id and not events:
|
||||
self.es.clear_scroll(scroll_id=next_scroll_id)
|
||||
next_scroll_id = self.empty_scroll
|
||||
|
||||
return events, total_events, next_scroll_id
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id,
|
||||
task_id,
|
||||
event_type=None,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
event_type: EventType,
|
||||
metric=None,
|
||||
variant=None,
|
||||
last_iter_count=None,
|
||||
@@ -489,36 +607,34 @@ class EventBLL(object):
|
||||
size=500,
|
||||
scroll_id=None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
|
||||
if event_type is None:
|
||||
event_type = "*"
|
||||
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type)
|
||||
if not self.es.indices.exists(es_index):
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return TaskEventsResult()
|
||||
|
||||
query = {"bool": defaultdict(list)}
|
||||
|
||||
if metric or variant:
|
||||
must = query["bool"]["must"]
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
must = []
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
|
||||
if last_iter_count is None:
|
||||
must = query["bool"]["must"]
|
||||
must.append({"terms": {"task": task_ids}})
|
||||
else:
|
||||
should = query["bool"]["should"]
|
||||
should = []
|
||||
for i, task_id in enumerate(task_ids):
|
||||
last_iters = self.get_last_iters(
|
||||
es_index, task_id, event_type, last_iter_count
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
task_id=task_id,
|
||||
iters=last_iter_count,
|
||||
)
|
||||
if not last_iters:
|
||||
continue
|
||||
@@ -534,36 +650,36 @@ class EventBLL(object):
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
|
||||
|
||||
routing = ",".join(task_ids)
|
||||
es_req = {
|
||||
"sort": sort,
|
||||
"size": min(size, 10000),
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.search(
|
||||
index=es_index,
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
routing=routing,
|
||||
scroll="1h",
|
||||
)
|
||||
|
||||
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
|
||||
next_scroll_id = es_res["_scroll_id"]
|
||||
total_events = es_res["hits"]["total"]
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
|
||||
def get_metrics_and_variants(self, company_id, task_id, event_type):
|
||||
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type)
|
||||
|
||||
if not self.es.indices.exists(es_index):
|
||||
def get_metrics_and_variants(
|
||||
self, company_id: str, task_id: str, event_type: EventType
|
||||
):
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return {}
|
||||
|
||||
es_req = {
|
||||
@@ -572,13 +688,15 @@ class EventBLL(object):
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -590,7 +708,9 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
|
||||
metrics = {}
|
||||
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
|
||||
@@ -601,10 +721,9 @@ class EventBLL(object):
|
||||
|
||||
return metrics
|
||||
|
||||
def get_task_latest_scalar_values(self, company_id, task_id):
|
||||
es_index = EventMetrics.get_index_name(company_id, "training_stats_scalar")
|
||||
|
||||
if not self.es.indices.exists(es_index):
|
||||
def get_task_latest_scalar_values(self, company_id: str, task_id: str):
|
||||
event_type = EventType.metrics_scalar
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return {}
|
||||
|
||||
es_req = {
|
||||
@@ -621,15 +740,15 @@ class EventBLL(object):
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
"order": {"_term": "asc"},
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
"order": {"_term": "asc"},
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_value": {
|
||||
@@ -659,7 +778,9 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
|
||||
metrics = []
|
||||
max_timestamp = 0
|
||||
@@ -686,9 +807,8 @@ class EventBLL(object):
|
||||
return metrics, max_timestamp
|
||||
|
||||
def get_vector_metrics_per_iter(self, company_id, task_id, metric, variant):
|
||||
|
||||
es_index = EventMetrics.get_index_name(company_id, "training_stats_vector")
|
||||
if not self.es.indices.exists(es_index):
|
||||
event_type = EventType.metrics_vector
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return [], []
|
||||
|
||||
es_req = {
|
||||
@@ -706,7 +826,9 @@ class EventBLL(object):
|
||||
"sort": ["iter"],
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
|
||||
vectors = []
|
||||
iterations = []
|
||||
@@ -716,8 +838,10 @@ class EventBLL(object):
|
||||
|
||||
return iterations, vectors
|
||||
|
||||
def get_last_iters(self, es_index, task_id, event_type, iters):
|
||||
if not self.es.indices.exists(es_index):
|
||||
def get_last_iters(
|
||||
self, company_id: str, event_type: EventType, task_id: str, iters: int
|
||||
):
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return []
|
||||
|
||||
es_req: dict = {
|
||||
@@ -727,17 +851,18 @@ class EventBLL(object):
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iters,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
}
|
||||
if event_type:
|
||||
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_last_iter"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@@ -756,11 +881,14 @@ class EventBLL(object):
|
||||
extra_msg, company=company_id, id=task_id
|
||||
)
|
||||
|
||||
es_index = EventMetrics.get_index_name(company_id, "*")
|
||||
es_req = {"query": {"term": {"task": task_id}}}
|
||||
with translate_errors_context(), TimingContext("es", "delete_task_events"):
|
||||
es_res = self.es.delete_by_query(
|
||||
index=es_index, body=es_req, routing=task_id, refresh=True
|
||||
es_res = delete_company_events(
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=EventType.all,
|
||||
body=es_req,
|
||||
refresh=True,
|
||||
)
|
||||
|
||||
return es_res.get("deleted", 0)
|
||||
66
apiserver/bll/event/event_common.py
Normal file
66
apiserver/bll/event/event_common.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from enum import Enum
|
||||
from typing import Union, Sequence
|
||||
|
||||
from boltons.typeutils import classproperty
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.config_repo import config
|
||||
|
||||
|
||||
class EventType(Enum):
|
||||
metrics_scalar = "training_stats_scalar"
|
||||
metrics_vector = "training_stats_vector"
|
||||
metrics_image = "training_debug_image"
|
||||
metrics_plot = "plot"
|
||||
task_log = "log"
|
||||
all = "*"
|
||||
|
||||
|
||||
class EventSettings:
|
||||
@classproperty
|
||||
def max_workers(self):
|
||||
return config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
|
||||
|
||||
@classproperty
|
||||
def state_expiration_sec(self):
|
||||
return config.get(
|
||||
f"services.events.events_retrieval.state_expiration_sec", 3600
|
||||
)
|
||||
|
||||
@classproperty
|
||||
def max_metrics_count(self):
|
||||
return config.get("services.events.events_retrieval.max_metrics_count", 100)
|
||||
|
||||
@classproperty
|
||||
def max_variants_count(self):
|
||||
return config.get("services.events.events_retrieval.max_variants_count", 100)
|
||||
|
||||
|
||||
def get_index_name(company_id: str, event_type: str):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
return f"events-{event_type}-{company_id}"
|
||||
|
||||
|
||||
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
if not es.indices.exists(es_index):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def search_company_events(
|
||||
es: Elasticsearch,
|
||||
company_id: Union[str, Sequence[str]],
|
||||
event_type: EventType,
|
||||
body: dict,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.search(index=es_index, body=body, **kwargs)
|
||||
|
||||
|
||||
def delete_company_events(
|
||||
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.delete_by_query(index=es_index, body=body, **kwargs)
|
||||
429
apiserver/bll/event/event_metrics.py
Normal file
429
apiserver/bll/event/event_metrics.py
Normal file
@@ -0,0 +1,429 @@
|
||||
import itertools
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventType,
|
||||
EventSettings,
|
||||
search_company_events,
|
||||
check_empty_data,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class EventMetrics:
|
||||
MAX_AGGS_ELEMENTS_COUNT = 50
|
||||
MAX_SAMPLE_BUCKETS = 6000
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_scalar_metrics_average_per_iter(
|
||||
self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
|
||||
) -> dict:
|
||||
"""
|
||||
Get scalar metric histogram per metric and variant
|
||||
The amount of points in each histogram should not exceed
|
||||
the requested samples
|
||||
"""
|
||||
event_type = EventType.metrics_scalar
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return {}
|
||||
|
||||
return self._get_scalar_average_per_iter_core(
|
||||
task_id, company_id, event_type, samples, ScalarKey.resolve(key)
|
||||
)
|
||||
|
||||
def _get_scalar_average_per_iter_core(
|
||||
self,
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
event_type: EventType,
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
run_parallel: bool = True,
|
||||
) -> dict:
|
||||
intervals = self._get_task_metric_intervals(
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
task_id=task_id,
|
||||
samples=samples,
|
||||
field=key.field,
|
||||
)
|
||||
if not intervals:
|
||||
return {}
|
||||
interval_groups = self._group_task_metric_intervals(intervals)
|
||||
|
||||
get_scalar_average = partial(
|
||||
self._get_scalar_average,
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
key=key,
|
||||
)
|
||||
if run_parallel:
|
||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
pool.map(get_scalar_average, interval_groups)
|
||||
)
|
||||
else:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
get_scalar_average(group) for group in interval_groups
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
for metric_key, metric_values in metrics:
|
||||
ret[metric_key].update(metric_values)
|
||||
|
||||
return ret
|
||||
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self,
|
||||
company_id,
|
||||
task_ids: Sequence[str],
|
||||
samples,
|
||||
key: ScalarKeyEnum,
|
||||
allow_public=True,
|
||||
):
|
||||
"""
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id__in=task_ids),
|
||||
allow_public=allow_public,
|
||||
override_projection=("id", "name", "company", "company_origin"),
|
||||
return_dicts=False,
|
||||
)
|
||||
if len(task_objs) < len(task_ids):
|
||||
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
|
||||
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
|
||||
task_name_by_id = {t.id: t.name for t in task_objs}
|
||||
|
||||
companies = {t.get_index_company() for t in task_objs}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
event_type = EventType.metrics_scalar
|
||||
company_id = next(iter(companies))
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return {}
|
||||
|
||||
get_scalar_average_per_iter = partial(
|
||||
self._get_scalar_average_per_iter_core,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
run_parallel=False,
|
||||
)
|
||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||
task_metrics = zip(
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
|
||||
)
|
||||
|
||||
res = defaultdict(lambda: defaultdict(dict))
|
||||
for task_id, task_data in task_metrics:
|
||||
task_name = task_name_by_id[task_id]
|
||||
for metric_key, metric_data in task_data.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
variant_data["name"] = task_name
|
||||
res[metric_key][variant_key][task_id] = variant_data
|
||||
|
||||
return res
|
||||
|
||||
MetricInterval = Tuple[str, str, int, int]
|
||||
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
|
||||
|
||||
@classmethod
|
||||
def _group_task_metric_intervals(
|
||||
cls, intervals: Sequence[MetricInterval]
|
||||
) -> Sequence[MetricIntervalGroup]:
|
||||
"""
|
||||
Group task metric intervals so that the following conditions are meat:
|
||||
- All the metrics in the same group have the same interval (with 10% rounding)
|
||||
- The amount of metrics in the group does not exceed MAX_AGGS_ELEMENTS_COUNT
|
||||
- The total count of samples in the group does not exceed MAX_SAMPLE_BUCKETS
|
||||
"""
|
||||
metric_interval_groups = []
|
||||
interval_group = []
|
||||
group_interval_upper_bound = 0
|
||||
group_max_interval = 0
|
||||
group_samples = 0
|
||||
for metric, variant, interval, size in sorted(intervals, key=itemgetter(2)):
|
||||
if (
|
||||
interval > group_interval_upper_bound
|
||||
or (group_samples + size) > cls.MAX_SAMPLE_BUCKETS
|
||||
or len(interval_group) >= cls.MAX_AGGS_ELEMENTS_COUNT
|
||||
):
|
||||
if interval_group:
|
||||
metric_interval_groups.append((group_max_interval, interval_group))
|
||||
interval_group = []
|
||||
group_max_interval = interval
|
||||
group_interval_upper_bound = interval + int(interval * 0.1)
|
||||
group_samples = 0
|
||||
interval_group.append((metric, variant))
|
||||
group_samples += size
|
||||
group_max_interval = max(group_max_interval, interval)
|
||||
if interval_group:
|
||||
metric_interval_groups.append((group_max_interval, interval_group))
|
||||
|
||||
return metric_interval_groups
|
||||
|
||||
def _get_task_metric_intervals(
|
||||
self,
|
||||
company_id: str,
|
||||
event_type: EventType,
|
||||
task_id: str,
|
||||
samples: int,
|
||||
field: str = "iter",
|
||||
) -> Sequence[MetricInterval]:
|
||||
"""
|
||||
Calculate interval per task metric variant so that the resulting
|
||||
amount of points does not exceed sample.
|
||||
Return the list og metric variant intervals as the following tuple:
|
||||
(metric, variant, interval, samples)
|
||||
"""
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"term": {"task": task_id}},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"count": {"value_count": {"field": field}},
|
||||
"min_index": {"min": {"field": field}},
|
||||
"max_index": {"max": {"field": field}},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return []
|
||||
|
||||
return [
|
||||
self._build_metric_interval(metric["key"], variant["key"], variant, samples)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
for variant in metric["variants"]["buckets"]
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _build_metric_interval(
|
||||
metric: str, variant: str, data: dict, samples: int
|
||||
) -> Tuple[str, str, int, int]:
|
||||
"""
|
||||
Calculate index interval per metric_variant variant so that the
|
||||
total amount of intervals does not exceeds the samples
|
||||
Return the interval and resulting amount of intervals
|
||||
"""
|
||||
count = safe_get(data, "count/value", default=0)
|
||||
if count < samples:
|
||||
return metric, variant, 1, count
|
||||
|
||||
min_index = safe_get(data, "min_index/value", default=0)
|
||||
max_index = safe_get(data, "max_index/value", default=min_index)
|
||||
index_range = max_index - min_index + 1
|
||||
interval = max(1, math.ceil(float(index_range) / samples))
|
||||
max_samples = math.ceil(float(index_range) / interval)
|
||||
return (
|
||||
metric,
|
||||
variant,
|
||||
interval,
|
||||
max_samples,
|
||||
)
|
||||
|
||||
MetricData = Tuple[str, dict]
|
||||
|
||||
def _get_scalar_average(
|
||||
self,
|
||||
metrics_interval: MetricIntervalGroup,
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
event_type: EventType,
|
||||
key: ScalarKey,
|
||||
) -> Sequence[MetricData]:
|
||||
"""
|
||||
Retrieve scalar histograms per several metric variants that share the same interval
|
||||
"""
|
||||
interval, metrics = metrics_interval
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
aggs_result = self._query_aggregation_for_task_metrics(
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
aggs=aggs,
|
||||
task_id=task_id,
|
||||
metrics=metrics,
|
||||
)
|
||||
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
metrics = [
|
||||
(
|
||||
metric["key"],
|
||||
{
|
||||
variant["key"]: {
|
||||
"name": variant["key"],
|
||||
**key.get_iterations_data(variant),
|
||||
}
|
||||
for variant in metric["variants"]["buckets"]
|
||||
},
|
||||
)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
]
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
def _add_aggregation_average(aggregation):
|
||||
average_agg = {"avg_val": {"avg": {"field": "value"}}}
|
||||
return {
|
||||
key: {**value, "aggs": {**value.get("aggs", {}), **average_agg}}
|
||||
for key, value in aggregation.items()
|
||||
}
|
||||
|
||||
def _query_aggregation_for_task_metrics(
|
||||
self,
|
||||
company_id: str,
|
||||
event_type: EventType,
|
||||
aggs: dict,
|
||||
task_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
) -> dict:
|
||||
"""
|
||||
Return the result of elastic search query for the given aggregation filtered
|
||||
by the given task_ids and metrics
|
||||
"""
|
||||
must = [{"term": {"task": task_id}}]
|
||||
if metrics:
|
||||
should = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for metric, variant in metrics
|
||||
]
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must}},
|
||||
"aggs": aggs,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
|
||||
return es_res.get("aggregations")
|
||||
|
||||
def get_tasks_metrics(
|
||||
self, company_id, task_ids: Sequence, event_type: EventType
|
||||
) -> Sequence:
|
||||
"""
|
||||
For the requested tasks return all the metrics that
|
||||
reported events of the requested types
|
||||
"""
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return {}
|
||||
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
res = pool.map(
|
||||
partial(
|
||||
self._get_task_metrics,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
),
|
||||
task_ids,
|
||||
)
|
||||
return list(zip(task_ids, res))
|
||||
|
||||
def _get_task_metrics(
|
||||
self, task_id: str, company_id: str, event_type: EventType
|
||||
) -> Sequence:
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
|
||||
return [
|
||||
metric["key"]
|
||||
for metric in safe_get(es_res, "aggregations/metrics/buckets", default=[])
|
||||
]
|
||||
127
apiserver/bll/event/log_events_iterator.py
Normal file
127
apiserver/bll/event/log_events_iterator.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from typing import Optional, Tuple, Sequence
|
||||
|
||||
import attr
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
)
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.Factory(list)
|
||||
|
||||
|
||||
class LogEventsIterator:
|
||||
EVENT_TYPE = EventType.task_log
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
from_timestamp: Optional[int] = None,
|
||||
) -> TaskEventsResult:
|
||||
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
|
||||
return TaskEventsResult()
|
||||
|
||||
res = TaskEventsResult()
|
||||
res.events, res.total_events = self._get_events(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_timestamp=from_timestamp,
|
||||
)
|
||||
return res
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
from_timestamp: Optional[int],
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous timestamp either in the
|
||||
direction of earlier events (navigate_earlier=True) or in the direction of later events.
|
||||
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
|
||||
For the last timestamp all the events are brought (even if the resulting size
|
||||
exceeds batch_size) so that this timestamp events will not be lost between the calls.
|
||||
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
|
||||
"""
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": {"term": {"task": task_id}},
|
||||
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if from_timestamp:
|
||||
es_req["search_after"] = [from_timestamp]
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_result = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"timestamp": events[-1]["timestamp"]}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
last_second_hits = es_result["hits"]["hits"]
|
||||
if not last_second_hits or len(last_second_hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
already_present_ids = set(hit["_id"] for hit in hits)
|
||||
last_second_events = [
|
||||
hit["_source"]
|
||||
for hit in last_second_hits
|
||||
if hit["_id"] not in already_present_ids
|
||||
]
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[*events, *last_second_events],
|
||||
hits_total,
|
||||
)
|
||||
@@ -4,9 +4,9 @@ Module for polymorphism over different types of X axes in scalar aggregations
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import auto
|
||||
|
||||
from utilities.stringenum import StringEnum
|
||||
from bll.util import extract_properties_to_lists
|
||||
from config import config
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
from apiserver.bll.util import extract_properties_to_lists
|
||||
from apiserver.config_repo import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -111,7 +111,7 @@ class TimestampKey(ScalarKey):
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{interval}ms",
|
||||
"fixed_interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
}
|
||||
}
|
||||
@@ -150,7 +150,7 @@ class ISOTimeKey(ScalarKey):
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{interval}ms",
|
||||
"fixed_interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
"format": "strict_date_time",
|
||||
}
|
||||
18
apiserver/bll/model/__init__.py
Normal file
18
apiserver/bll/model/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.utils import get_company_or_none_constraint
|
||||
|
||||
|
||||
class ModelBLL:
|
||||
def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
|
||||
"""
|
||||
Return the list of unique frameworks used by company and public models
|
||||
If project ids passed then only models from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
return Model.objects(query).distinct(field="framework")
|
||||
98
apiserver/bll/organization/__init__.py
Normal file
98
apiserver/bll/organization/__init__.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Dict, Optional
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from .tags_cache import _TagsCache
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class Tags(Enum):
|
||||
Task = "task"
|
||||
Model = "model"
|
||||
|
||||
|
||||
class OrgBLL:
|
||||
def __init__(self, redis=None):
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self._task_tags = _TagsCache(Task, self.redis)
|
||||
self._model_tags = _TagsCache(Model, self.redis)
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company_id: str,
|
||||
entity: Tags,
|
||||
include_system: bool = False,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
projects: Sequence[str] = None,
|
||||
) -> dict:
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
if not projects:
|
||||
return tags_cache.get_tags(
|
||||
company_id, include_system=include_system, filter_=filter_
|
||||
)
|
||||
|
||||
ret = defaultdict(set)
|
||||
for project in projects:
|
||||
project_tags = tags_cache.get_tags(
|
||||
company_id,
|
||||
include_system=include_system,
|
||||
filter_=filter_,
|
||||
project=project,
|
||||
)
|
||||
for field, tags in project_tags.items():
|
||||
ret[field] |= tags
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(
|
||||
self, company_id: str, entity: Tags, project: str, tags=None, system_tags=None,
|
||||
):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.update_tags(company_id, project, tags, system_tags)
|
||||
|
||||
def reset_tags(self, company_id: str, entity: Tags, projects: Sequence[str]):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.reset_tags(company_id, projects=projects)
|
||||
|
||||
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
|
||||
return self._task_tags if entity == Tags.Task else self._model_tags
|
||||
|
||||
@classmethod
|
||||
def get_parent_tasks(
|
||||
cls,
|
||||
company_id: str,
|
||||
projects: Sequence[str],
|
||||
state: Optional[EntityVisibility] = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Get list of unique parent tasks sorted by task name for the passed company projects
|
||||
If projects is None or empty then get parents for all the company tasks
|
||||
"""
|
||||
query = Q(company=company_id)
|
||||
if projects:
|
||||
query &= Q(project__in=projects)
|
||||
if state == EntityVisibility.archived:
|
||||
query &= Q(system_tags__in=[EntityVisibility.archived.value])
|
||||
elif state == EntityVisibility.active:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.archived.value])
|
||||
|
||||
parent_ids = set(Task.objects(query).distinct("parent"))
|
||||
if not parent_ids:
|
||||
return []
|
||||
|
||||
parents = Task.get_many_with_join(
|
||||
company_id,
|
||||
query=Q(id__in=parent_ids),
|
||||
allow_public=True,
|
||||
override_projection=("id", "name", "project.name"),
|
||||
)
|
||||
return sorted(parents, key=itemgetter("name"))
|
||||
@@ -1,17 +1,13 @@
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from typing import Sequence, Union, Type, Dict
|
||||
|
||||
from mongoengine import Q
|
||||
from redis import Redis
|
||||
|
||||
from config import config
|
||||
from database.model.base import GetMixin
|
||||
from database.model.model import Model
|
||||
from database.model.task.task import Task
|
||||
from redis_manager import redman
|
||||
from utilities import json
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
|
||||
log = config.logger(__file__)
|
||||
_settings_prefix = "services.organization"
|
||||
@@ -20,6 +16,8 @@ _settings_prefix = "services.organization"
|
||||
class _TagsCache:
|
||||
_tags_field = "tags"
|
||||
_system_tags_field = "system_tags"
|
||||
_dummy_tag = "__dummy__"
|
||||
# prepend our list in redis with this tag since empty lists are auto deleted
|
||||
|
||||
def __init__(self, db_cls: Union[Type[Model], Type[Task]], redis: Redis):
|
||||
self.db_cls = db_cls
|
||||
@@ -31,12 +29,12 @@ class _TagsCache:
|
||||
|
||||
def _get_tags_from_db(
|
||||
self,
|
||||
company: str,
|
||||
company_id: str,
|
||||
field: str,
|
||||
project: str = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
) -> set:
|
||||
query = Q(company=company)
|
||||
query = Q(company=company_id)
|
||||
if filter_:
|
||||
for name, vals in filter_.items():
|
||||
if vals:
|
||||
@@ -48,7 +46,7 @@ class _TagsCache:
|
||||
|
||||
def _get_tags_cache_key(
|
||||
self,
|
||||
company: str,
|
||||
company_id: str,
|
||||
field: str,
|
||||
project: str = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
@@ -64,12 +62,12 @@ class _TagsCache:
|
||||
filter_str = "_".join(
|
||||
["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())]
|
||||
)
|
||||
key_parts = [company, project, self.db_cls.__name__, field, filter_str]
|
||||
key_parts = [field, company_id, project, self.db_cls.__name__, filter_str]
|
||||
return "_".join(filter(None, key_parts))
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company: str,
|
||||
company_id: str,
|
||||
include_system: bool = False,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
project: str = None,
|
||||
@@ -83,27 +81,29 @@ class _TagsCache:
|
||||
fields = [self._tags_field]
|
||||
if include_system:
|
||||
fields.append(self._system_tags_field)
|
||||
redis_keys = [
|
||||
self._get_tags_cache_key(company, field=f, project=project, filter_=filter_)
|
||||
for f in fields
|
||||
]
|
||||
cached = self.redis.mget(redis_keys)
|
||||
|
||||
ret = {}
|
||||
for field, tag_data, key in zip(fields, cached, redis_keys):
|
||||
if tag_data is not None:
|
||||
tags = json.loads(tag_data)
|
||||
for field in fields:
|
||||
redis_key = self._get_tags_cache_key(
|
||||
company_id, field=field, project=project, filter_=filter_
|
||||
)
|
||||
cached_tags = self.redis.lrange(redis_key, 0, -1)
|
||||
if cached_tags:
|
||||
tags = [c.decode() for c in cached_tags[1:]]
|
||||
else:
|
||||
tags = list(self._get_tags_from_db(company, field, project, filter_))
|
||||
self.redis.setex(
|
||||
key,
|
||||
time=self._tags_cache_expiration_seconds,
|
||||
value=json.dumps(tags),
|
||||
tags = list(
|
||||
self._get_tags_from_db(
|
||||
company_id, field=field, project=project, filter_=filter_
|
||||
)
|
||||
)
|
||||
self.redis.rpush(redis_key, self._dummy_tag, *tags)
|
||||
self.redis.expire(redis_key, self._tags_cache_expiration_seconds)
|
||||
|
||||
ret[field] = set(tags)
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(self, company: str, project: str, tags=None, system_tags=None):
|
||||
def update_tags(self, company_id: str, project: str, tags=None, system_tags=None):
|
||||
"""
|
||||
Updates tags. If reset is set then both tags and system_tags
|
||||
are recalculated. Otherwise only those that are not 'None'
|
||||
@@ -119,22 +119,22 @@ class _TagsCache:
|
||||
if not fields:
|
||||
return
|
||||
|
||||
self._delete_redis_keys(company, projects=[project], fields=fields)
|
||||
self._delete_redis_keys(company_id, projects=[project], fields=fields)
|
||||
|
||||
def reset_tags(self, company: str, projects: Sequence[str]):
|
||||
def reset_tags(self, company_id: str, projects: Sequence[str]):
|
||||
self._delete_redis_keys(
|
||||
company,
|
||||
company_id,
|
||||
projects=projects,
|
||||
fields=(self._tags_field, self._system_tags_field),
|
||||
)
|
||||
|
||||
def _delete_redis_keys(
|
||||
self, company: str, projects: [Sequence[str]], fields: Sequence[str]
|
||||
self, company_id: str, projects: [Sequence[str]], fields: Sequence[str]
|
||||
):
|
||||
redis_keys = list(
|
||||
chain.from_iterable(
|
||||
self.redis.keys(
|
||||
self._get_tags_cache_key(company, field=f, project=p) + "*"
|
||||
self._get_tags_cache_key(company_id, field=f, project=p) + "*"
|
||||
)
|
||||
for f in fields
|
||||
for p in set(projects) | {None}
|
||||
@@ -142,52 +142,3 @@ class _TagsCache:
|
||||
)
|
||||
if redis_keys:
|
||||
self.redis.delete(*redis_keys)
|
||||
|
||||
|
||||
class Tags(Enum):
|
||||
Task = "task"
|
||||
Model = "model"
|
||||
|
||||
|
||||
class OrgBLL:
|
||||
def __init__(self, redis=None):
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self._task_tags = _TagsCache(Task, self.redis)
|
||||
self._model_tags = _TagsCache(Model, self.redis)
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company: str,
|
||||
entity: Tags,
|
||||
include_system: bool = False,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
projects: Sequence[str] = None,
|
||||
) -> dict:
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
if not projects:
|
||||
return tags_cache.get_tags(
|
||||
company, include_system=include_system, filter_=filter_
|
||||
)
|
||||
|
||||
ret = defaultdict(set)
|
||||
for project in projects:
|
||||
project_tags = tags_cache.get_tags(
|
||||
company, include_system=include_system, filter_=filter_, project=project
|
||||
)
|
||||
for field, tags in project_tags.items():
|
||||
ret[field] |= tags
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(
|
||||
self, company: str, entity: Tags, project: str, tags=None, system_tags=None,
|
||||
):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.update_tags(company, project, tags, system_tags)
|
||||
|
||||
def reset_tags(self, company: str, entity: Tags, projects: Sequence[str]):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.reset_tags(company, projects=projects)
|
||||
|
||||
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
|
||||
return self._task_tags if entity == Tags.Task else self._model_tags
|
||||
137
apiserver/bll/project/project_bll.py
Normal file
137
apiserver/bll/project/project_bll.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from datetime import datetime
|
||||
from typing import Sequence, Optional, Type
|
||||
|
||||
from mongoengine import Q, Document
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ProjectBLL:
|
||||
@classmethod
|
||||
def get_active_users(
|
||||
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
|
||||
) -> set:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
"""
|
||||
with TimingContext("mongo", "active_users_in_projects"):
|
||||
res = set()
|
||||
query = Q(company=company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
user: str,
|
||||
company: str,
|
||||
name: str,
|
||||
description: str,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
default_output_destination: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a new project.
|
||||
Returns project ID
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
project = Project(
|
||||
id=database.utils.id(),
|
||||
user=user,
|
||||
company=company,
|
||||
name=name,
|
||||
description=description,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
default_output_destination=default_output_destination,
|
||||
created=now,
|
||||
last_update=now,
|
||||
)
|
||||
project.save()
|
||||
return project.id
|
||||
|
||||
@classmethod
|
||||
def find_or_create(
|
||||
cls,
|
||||
user: str,
|
||||
company: str,
|
||||
project_name: str,
|
||||
description: str,
|
||||
project_id: str = None,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
default_output_destination: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Find a project named `project_name` or create a new one.
|
||||
Returns project ID
|
||||
"""
|
||||
if not project_id and not project_name:
|
||||
raise ValueError("project id or name required")
|
||||
|
||||
if project_id:
|
||||
project = Project.objects(company=company, id=project_id).only("id").first()
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
return project_id
|
||||
|
||||
project = Project.objects(company=company, name=project_name).only("id").first()
|
||||
if project:
|
||||
return project.id
|
||||
|
||||
return cls.create(
|
||||
user=user,
|
||||
company=company,
|
||||
name=project_name,
|
||||
description=description,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
default_output_destination=default_output_destination,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def move_under_project(
|
||||
cls,
|
||||
entity_cls: Type[Document],
|
||||
user: str,
|
||||
company: str,
|
||||
ids: Sequence[str],
|
||||
project: str = None,
|
||||
project_name: str = None,
|
||||
):
|
||||
"""
|
||||
Move a batch of entities to `project` or a project named `project_name` (create if does not exist)
|
||||
"""
|
||||
with TimingContext("mongo", "move_under_project"):
|
||||
project = cls.find_or_create(
|
||||
user=user,
|
||||
company=company,
|
||||
project_id=project,
|
||||
project_name=project_name,
|
||||
description="Auto-generated during move",
|
||||
)
|
||||
extra = (
|
||||
{"set__last_change": datetime.utcnow()}
|
||||
if hasattr(entity_cls, "last_change")
|
||||
else {}
|
||||
)
|
||||
entity_cls.objects(company=company, id__in=ids).update(set__project=project, **extra)
|
||||
|
||||
return project
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional, Sequence, Iterable, Union
|
||||
|
||||
from config import config
|
||||
from apiserver.config_repo import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -4,14 +4,14 @@ from typing import Callable, Sequence, Optional, Tuple
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
import database
|
||||
import es_factory
|
||||
from apierrors import errors
|
||||
from bll.queue.queue_metrics import QueueMetrics
|
||||
from bll.workers import WorkerBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.queue import Queue, Entry
|
||||
from apiserver import database
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.queue.queue_metrics import QueueMetrics
|
||||
from apiserver.bll.workers import WorkerBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.queue import Queue, Entry
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -5,20 +5,19 @@ from typing import Sequence
|
||||
import elasticsearch.helpers
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
import es_factory
|
||||
from apierrors.errors import bad_request
|
||||
from bll.query import Builder as QueryBuilder
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.queue import Queue, Entry
|
||||
from timing_context import TimingContext
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors.errors import bad_request
|
||||
from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.queue import Queue, Entry
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class QueueMetrics:
|
||||
class EsKeys:
|
||||
DOC_TYPE = "metrics"
|
||||
WAITING_TIME_FIELD = "average_waiting_time"
|
||||
QUEUE_LENGTH_FIELD = "queue_length"
|
||||
TIMESTAMP_FIELD = "timestamp"
|
||||
@@ -66,7 +65,6 @@ class QueueMetrics:
|
||||
entries = [e for e in queue.entries if e.added]
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_type=self.EsKeys.DOC_TYPE,
|
||||
_source={
|
||||
self.EsKeys.TIMESTAMP_FIELD: timestamp,
|
||||
self.EsKeys.QUEUE_FIELD: queue.id,
|
||||
@@ -93,7 +91,6 @@ class QueueMetrics:
|
||||
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
|
||||
doc_type=self.EsKeys.DOC_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@@ -109,7 +106,7 @@ class QueueMetrics:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": cls.EsKeys.TIMESTAMP_FIELD,
|
||||
"interval": f"{interval}s",
|
||||
"fixed_interval": f"{interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
@@ -3,8 +3,8 @@ from typing import Optional, TypeVar, Generic, Type, Callable
|
||||
|
||||
from redis import StrictRedis
|
||||
|
||||
import database
|
||||
from timing_context import TimingContext
|
||||
from apiserver import database
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -6,7 +6,7 @@ from time import sleep
|
||||
import attr
|
||||
import psutil
|
||||
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
|
||||
class ResourceMonitor(Thread):
|
||||
@@ -11,18 +11,18 @@ import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
|
||||
from bll.query import Builder as QueryBuilder
|
||||
from bll.util import get_server_uuid
|
||||
from bll.workers import WorkerStats, WorkerBLL
|
||||
from config import config
|
||||
from config.info import get_deployment_type
|
||||
from database.model import Company, User
|
||||
from database.model.queue import Queue
|
||||
from database.model.task.task import Task
|
||||
from utilities import safe_get
|
||||
from utilities.json import dumps
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from version import __version__ as current_version
|
||||
from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.bll.util import get_server_uuid
|
||||
from apiserver.bll.workers import WorkerStats, WorkerBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.config.info import get_deployment_type
|
||||
from apiserver.database.model import Company, User
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.json import dumps
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
from apiserver.version import __version__ as current_version
|
||||
from .resource_monitor import ResourceMonitor
|
||||
|
||||
log = config.logger(__file__)
|
||||
@@ -237,7 +237,6 @@ class StatisticsReporter:
|
||||
def _run_worker_stats_query(cls, company_id, es_req) -> dict:
|
||||
return worker_bll.es_client.search(
|
||||
index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*",
|
||||
doc_type="stat",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@@ -4,5 +4,4 @@ from .utils import (
|
||||
update_project_time,
|
||||
validate_status_change,
|
||||
split_by,
|
||||
ParameterKeyEscaper,
|
||||
)
|
||||
97
apiserver/bll/task/artifacts.py
Normal file
97
apiserver/bll/task/artifacts.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from hashlib import md5
|
||||
from operator import itemgetter
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
|
||||
|
||||
|
||||
def get_artifact_id(artifact: dict):
|
||||
"""
|
||||
Calculate id from 'key' and 'mode' fields
|
||||
Return hash on on the id so that it will not contain mongo illegal characters
|
||||
"""
|
||||
key_hash: str = md5(artifact["key"].encode()).hexdigest()
|
||||
mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
|
||||
return f"{key_hash}_{mode}"
|
||||
|
||||
|
||||
def artifacts_prepare_for_save(fields: dict):
|
||||
artifacts_field = ("execution", "artifacts")
|
||||
artifacts = nested_get(fields, artifacts_field)
|
||||
if artifacts is None:
|
||||
return
|
||||
|
||||
nested_set(
|
||||
fields, artifacts_field, value={get_artifact_id(a): a for a in artifacts}
|
||||
)
|
||||
|
||||
|
||||
def artifacts_unprepare_from_saved(fields):
|
||||
artifacts_field = ("execution", "artifacts")
|
||||
artifacts = nested_get(fields, artifacts_field)
|
||||
if artifacts is None:
|
||||
return
|
||||
|
||||
nested_set(
|
||||
fields,
|
||||
artifacts_field,
|
||||
value=sorted(artifacts.values(), key=itemgetter("key", "mode")),
|
||||
)
|
||||
|
||||
|
||||
class Artifacts:
|
||||
@classmethod
|
||||
def add_or_update_artifacts(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
artifacts: Sequence[ApiArtifact],
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "update_artifacts"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
force=force,
|
||||
)
|
||||
|
||||
artifacts = {
|
||||
get_artifact_id(a): Artifact(**a)
|
||||
for a in (api_artifact.to_struct() for api_artifact in artifacts)
|
||||
}
|
||||
|
||||
update_cmds = {
|
||||
f"set__execution__artifacts__{mongoengine_safe(name)}": value
|
||||
for name, value in artifacts.items()
|
||||
}
|
||||
return update_task(task, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_artifacts(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
artifact_ids: Sequence[ArtifactId],
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "delete_artifacts"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
force=force,
|
||||
)
|
||||
|
||||
artifact_ids = [
|
||||
get_artifact_id(a)
|
||||
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
|
||||
]
|
||||
delete_cmds = {
|
||||
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
|
||||
}
|
||||
|
||||
return update_task(task, update_cmds=delete_cmds)
|
||||
245
apiserver/bll/task/hyperparams.py
Normal file
245
apiserver/bll/task/hyperparams.py
Normal file
@@ -0,0 +1,245 @@
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Dict
|
||||
|
||||
from boltons import iterutils
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.tasks import (
|
||||
HyperParamKey,
|
||||
HyperParamItem,
|
||||
ReplaceHyperparams,
|
||||
Configuration,
|
||||
)
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
)
|
||||
|
||||
log = config.logger(__file__)
|
||||
task_bll = TaskBLL()
|
||||
|
||||
|
||||
class HyperParams:
|
||||
_properties_section = "properties"
|
||||
|
||||
@classmethod
|
||||
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
only = ("id", "hyperparams")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
task.id: {"hyperparams": cls._get_params_list(items=task.hyperparams)}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_params_list(
|
||||
cls, items: Dict[str, Dict[str, ParamsItem]]
|
||||
) -> Sequence[dict]:
|
||||
ret = list(chain.from_iterable(v.values() for v in items.values()))
|
||||
return [
|
||||
p.to_proper_dict() for p in sorted(ret, key=attrgetter("section", "name"))
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _normalize_params(cls, params: Sequence) -> bool:
|
||||
"""
|
||||
Lower case properties section and return True if it is the only section
|
||||
"""
|
||||
for p in params:
|
||||
if p.section.lower() == cls._properties_section:
|
||||
p.section = cls._properties_section
|
||||
|
||||
return all(p.section == cls._properties_section for p in params)
|
||||
|
||||
@classmethod
|
||||
def delete_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamKey],
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "delete_hyperparams"):
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
)
|
||||
|
||||
with_param, without_param = iterutils.partition(
|
||||
hyperparams, key=lambda p: bool(p.name)
|
||||
)
|
||||
sections_to_delete = {p.section for p in without_param}
|
||||
delete_cmds = {
|
||||
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
|
||||
for section in sections_to_delete
|
||||
}
|
||||
|
||||
for item in with_param:
|
||||
section = ParameterKeyEscaper.escape(item.section)
|
||||
if item.section in sections_to_delete:
|
||||
raise errors.bad_request.FieldsConflict(
|
||||
"Cannot delete section field if the whole section was scheduled for deletion"
|
||||
)
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
|
||||
return update_task(
|
||||
task, update_cmds=delete_cmds, set_last_update=not properties_only
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def edit_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamItem],
|
||||
replace_hyperparams: str,
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "edit_hyperparams"):
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
hyperparams = cls._db_dicts_from_list(hyperparams)
|
||||
if replace_hyperparams == ReplaceHyperparams.all:
|
||||
update_cmds["set__hyperparams"] = hyperparams
|
||||
elif replace_hyperparams == ReplaceHyperparams.section:
|
||||
for section, value in hyperparams.items():
|
||||
update_cmds[
|
||||
f"set__hyperparams__{mongoengine_safe(section)}"
|
||||
] = value
|
||||
else:
|
||||
for section, section_params in hyperparams.items():
|
||||
for name, value in section_params.items():
|
||||
update_cmds[
|
||||
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
|
||||
] = value
|
||||
|
||||
return update_task(
|
||||
task, update_cmds=update_cmds, set_last_update=not properties_only
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
|
||||
sections = iterutils.bucketize(items, key=attrgetter("section"))
|
||||
return {
|
||||
ParameterKeyEscaper.escape(section): {
|
||||
ParameterKeyEscaper.escape(param.name): ParamsItem(**param.to_struct())
|
||||
for param in params
|
||||
}
|
||||
for section, params in sections.items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_configurations(
|
||||
cls, company_id: str, task_ids: Sequence[str], names: Sequence[str]
|
||||
) -> Dict[str, dict]:
|
||||
only = ["id"]
|
||||
if names:
|
||||
only.extend(
|
||||
f"configuration.{ParameterKeyEscaper.escape(name)}" for name in names
|
||||
)
|
||||
else:
|
||||
only.append("configuration")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
task.id: {
|
||||
"configuration": [
|
||||
c.to_proper_dict()
|
||||
for c in sorted(task.configuration.values(), key=attrgetter("name"))
|
||||
]
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_configuration_names(
|
||||
cls, company_id: str, task_ids: Sequence[str]
|
||||
) -> Dict[str, list]:
|
||||
with TimingContext("mongo", "get_configuration_names"):
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"_id": {"$in": task_ids},
|
||||
}
|
||||
},
|
||||
{"$project": {"items": {"$objectToArray": "$configuration"}}},
|
||||
{"$unwind": "$items"},
|
||||
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
|
||||
]
|
||||
|
||||
tasks = Task.aggregate(pipeline)
|
||||
|
||||
return {
|
||||
task["_id"]: {
|
||||
"names": sorted(
|
||||
ParameterKeyEscaper.unescape(name) for name in task["names"]
|
||||
)
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_configuration(
|
||||
cls,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
configuration: Sequence[Configuration],
|
||||
replace_configuration: bool,
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "edit_configuration"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
|
||||
for c in configuration
|
||||
}
|
||||
if replace_configuration:
|
||||
update_cmds["set__configuration"] = configuration
|
||||
else:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
|
||||
|
||||
return update_task(task, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_configuration(
|
||||
cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
|
||||
) -> int:
|
||||
with TimingContext("mongo", "delete_configuration"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force
|
||||
)
|
||||
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
|
||||
return update_task(task, update_cmds=delete_cmds)
|
||||
@@ -1,11 +1,10 @@
|
||||
from datetime import timedelta, datetime
|
||||
from time import sleep
|
||||
|
||||
from apierrors import errors
|
||||
from bll.task import ChangeStatusRequest
|
||||
from config import config
|
||||
from database.model.task.task import TaskStatus, Task
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from apiserver.bll.task import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import TaskStatus, Task
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -71,19 +70,29 @@ class NonResponsiveTasksWatchdog:
|
||||
return 0
|
||||
|
||||
err_count = 0
|
||||
project_ids = set()
|
||||
now = datetime.utcnow()
|
||||
for task in tasks:
|
||||
log.info(
|
||||
f"Stopping {task.id} ({task.name}), last updated at {task.last_update}"
|
||||
)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.stopped,
|
||||
updated = Task.objects(id=task.id, status=task.status).update(
|
||||
status=TaskStatus.stopped,
|
||||
status_reason="Forced stop (non-responsive)",
|
||||
status_message="Forced stop (non-responsive)",
|
||||
force=True,
|
||||
).execute()
|
||||
except errors.bad_request.FailedChangingTaskStatus:
|
||||
err_count += 1
|
||||
status_changed=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
)
|
||||
if updated:
|
||||
project_ids.add(task.project)
|
||||
else:
|
||||
err_count += 1
|
||||
except Exception as ex:
|
||||
log.error("Failed setting status: %s", str(ex))
|
||||
|
||||
update_project_time(list(project_ids))
|
||||
|
||||
return len(tasks) - err_count
|
||||
201
apiserver/bll/task/param_utils.py
Normal file
201
apiserver/bll/task/param_utils.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import itertools
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import dpath
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
|
||||
hyperparams_default_section = "Args"
|
||||
hyperparams_legacy_type = "legacy"
|
||||
tf_define_section = "TF_DEFINE"
|
||||
|
||||
|
||||
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Return parameter section and name. The section is either TF_DEFINE or the default one
|
||||
"""
|
||||
if default_section is None:
|
||||
return None, full_name
|
||||
|
||||
section, _, name = full_name.partition("/")
|
||||
if section != tf_define_section:
|
||||
return default_section, full_name
|
||||
|
||||
if not name:
|
||||
raise errors.bad_request.ValidationError("Parameter name cannot be empty")
|
||||
return section, name
|
||||
|
||||
|
||||
def _get_full_param_name(param: dict) -> str:
|
||||
section = param.get("section")
|
||||
if section != tf_define_section:
|
||||
return param["name"]
|
||||
|
||||
return "/".join((section, param["name"]))
|
||||
|
||||
|
||||
def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
"""
|
||||
removed = 0
|
||||
if not data:
|
||||
return removed
|
||||
|
||||
if with_sections:
|
||||
for section, section_data in list(data.items()):
|
||||
removed += _remove_legacy_params(section_data)
|
||||
if not section_data:
|
||||
"""If section is empty after removing legacy params then delete it"""
|
||||
del data[section]
|
||||
else:
|
||||
for key, param in list(data.items()):
|
||||
if param.get("type") == hyperparams_legacy_type:
|
||||
removed += 1
|
||||
del data[key]
|
||||
|
||||
return removed
|
||||
|
||||
|
||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
"""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
if with_sections:
|
||||
return itertools.chain.from_iterable(
|
||||
_get_legacy_params(section_data) for section_data in data.values()
|
||||
)
|
||||
|
||||
return [
|
||||
param for param in data.values() if param.get("type") == hyperparams_legacy_type
|
||||
]
|
||||
|
||||
|
||||
def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
"""
|
||||
If legacy hyper params or configuration is passed then replace the corresponding section in the new structure
|
||||
Escape all the section and param names for hyper params and configuration to make it mongo sage
|
||||
"""
|
||||
for old_params_field, new_params_field, default_section in (
|
||||
("execution/parameters", "hyperparams", hyperparams_default_section),
|
||||
("execution/model_desc", "configuration", None),
|
||||
):
|
||||
legacy_params = safe_get(fields, old_params_field)
|
||||
if legacy_params is None:
|
||||
continue
|
||||
|
||||
if (
|
||||
not safe_get(fields, new_params_field)
|
||||
and previous_task
|
||||
and previous_task[new_params_field]
|
||||
):
|
||||
previous_data = previous_task.to_proper_dict().get(new_params_field)
|
||||
removed = _remove_legacy_params(
|
||||
previous_data, with_sections=default_section is not None
|
||||
)
|
||||
if not legacy_params and not removed:
|
||||
# if we only need to delete legacy fields from the db
|
||||
# but they are not there then there is no point to proceed
|
||||
continue
|
||||
|
||||
fields_update = {new_params_field: previous_data}
|
||||
params_unprepare_from_saved(fields_update)
|
||||
fields.update(fields_update)
|
||||
|
||||
for full_name, value in legacy_params.items():
|
||||
section, name = split_param_name(full_name, default_section)
|
||||
new_path = list(filter(None, (new_params_field, section, name)))
|
||||
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
dpath.new(fields, new_path, new_param)
|
||||
dpath.delete(fields, old_params_field)
|
||||
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(key): {
|
||||
ParameterKeyEscaper.escape(k): v for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, escaped_params)
|
||||
|
||||
|
||||
def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
"""
|
||||
Unescape all section and param names for hyper params and configuration
|
||||
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
|
||||
"""
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
if params:
|
||||
unescaped_params = {
|
||||
ParameterKeyEscaper.unescape(key): {
|
||||
ParameterKeyEscaper.unescape(k): v for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, unescaped_params)
|
||||
|
||||
if copy_to_legacy:
|
||||
for new_params_field, old_params_field, use_sections in (
|
||||
(f"hyperparams", "execution/parameters", True),
|
||||
(f"configuration", "execution/model_desc", False),
|
||||
):
|
||||
legacy_params = _get_legacy_params(
|
||||
safe_get(fields, new_params_field), with_sections=use_sections
|
||||
)
|
||||
if legacy_params:
|
||||
dpath.new(
|
||||
fields,
|
||||
old_params_field,
|
||||
{_get_full_param_name(p): p["value"] for p in legacy_params},
|
||||
)
|
||||
|
||||
|
||||
def _process_path(path: str):
|
||||
"""
|
||||
Frontend does a partial escaping on the path so the all '.' in section and key names are escaped
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
raise errors.bad_request.ValidationError("invalid task field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
)
|
||||
|
||||
|
||||
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
|
||||
for old_prefix, new_prefix in (
|
||||
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
|
||||
("execution.model_desc", f"configuration"),
|
||||
):
|
||||
path: str
|
||||
paths = [path.replace(old_prefix, new_prefix) for path in paths]
|
||||
|
||||
for prefix in (
|
||||
"hyperparams.",
|
||||
"-hyperparams.",
|
||||
"configuration.",
|
||||
"-configuration.",
|
||||
):
|
||||
paths = [
|
||||
_process_path(path) if path.startswith(prefix) else path for path in paths
|
||||
]
|
||||
return paths
|
||||
@@ -1,47 +1,49 @@
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from random import random
|
||||
from time import sleep
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
|
||||
|
||||
import pymongo.results
|
||||
import dpath
|
||||
import six
|
||||
from mongoengine import Q
|
||||
from six import string_types
|
||||
|
||||
import database.utils as dbutils
|
||||
import es_factory
|
||||
from apierrors import errors
|
||||
from apimodels.tasks import Artifact as ApiArtifact
|
||||
from bll.organization import OrgBLL, Tags
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.model import Model
|
||||
from database.model.project import Project
|
||||
from database.model.task.metrics import EventStats, MetricEventStats
|
||||
from database.model.task.output import Output
|
||||
from database.model.task.task import (
|
||||
import apiserver.database.utils as dbutils
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.metrics import EventStats, MetricEventStats
|
||||
from apiserver.database.model.task.output import Output
|
||||
from apiserver.database.model.task.task import (
|
||||
Task,
|
||||
TaskStatus,
|
||||
TaskStatusMessage,
|
||||
TaskSystemTags,
|
||||
ArtifactModes,
|
||||
Artifact,
|
||||
external_task_types,
|
||||
)
|
||||
from database.utils import get_company_or_none_constraint, id as create_id
|
||||
from service_repo import APICall
|
||||
from services.utils import validate_tags
|
||||
from timing_context import TimingContext
|
||||
from utilities.dicts import deep_merge
|
||||
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.services.utils import validate_tags
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import ChangeStatusRequest, validate_status_change, update_project_time
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
queue_bll = QueueBLL()
|
||||
project_bll = ProjectBLL()
|
||||
|
||||
|
||||
class TaskBLL(object):
|
||||
class TaskBLL:
|
||||
def __init__(self, events_es=None):
|
||||
self.events_es = (
|
||||
events_es if events_es is not None else es_factory.connect("events")
|
||||
@@ -83,25 +85,24 @@ class TaskBLL(object):
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(
|
||||
company_id,
|
||||
task_id,
|
||||
required_status=None,
|
||||
required_dataset=None,
|
||||
only_fields=None,
|
||||
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
|
||||
):
|
||||
if only_fields:
|
||||
if isinstance(only_fields, string_types):
|
||||
only_fields = [only_fields]
|
||||
else:
|
||||
only_fields = list(only_fields)
|
||||
only_fields = only_fields + ["status"]
|
||||
|
||||
with TimingContext("mongo", "task_by_id_all"):
|
||||
qs = Task.objects(id=task_id, company=company_id)
|
||||
if only_fields:
|
||||
qs = (
|
||||
qs.only(only_fields)
|
||||
if isinstance(only_fields, string_types)
|
||||
else qs.only(*only_fields)
|
||||
)
|
||||
qs = qs.only(
|
||||
"status", "input"
|
||||
) # make sure all fields we rely on here are also returned
|
||||
task = qs.first()
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id=task_id),
|
||||
allow_public=allow_public,
|
||||
override_projection=only_fields,
|
||||
return_dicts=False,
|
||||
)
|
||||
task = None if not tasks else tasks[0]
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
@@ -109,17 +110,12 @@ class TaskBLL(object):
|
||||
if required_status and not task.status == required_status:
|
||||
raise errors.bad_request.InvalidTaskStatus(expected=required_status)
|
||||
|
||||
if required_dataset and required_dataset not in (
|
||||
entry.dataset for entry in task.input.view.entries
|
||||
):
|
||||
raise errors.bad_request.InvalidId(
|
||||
"not in input view", dataset=required_dataset
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def assert_exists(company_id, task_ids, only=None, allow_public=False):
|
||||
def assert_exists(
|
||||
company_id, task_ids, only=None, allow_public=False, return_tasks=True
|
||||
) -> Optional[Sequence[Task]]:
|
||||
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
|
||||
with translate_errors_context(), TimingContext("mongo", "task_exists"):
|
||||
ids = set(task_ids)
|
||||
@@ -130,14 +126,15 @@ class TaskBLL(object):
|
||||
return_dicts=False,
|
||||
)
|
||||
if only:
|
||||
res = q.only(*only)
|
||||
count = len(res)
|
||||
else:
|
||||
count = q.count()
|
||||
res = q.first()
|
||||
if count != len(ids):
|
||||
# Make sure to reset fields filters (some fields are excluded by default) since this
|
||||
# is an internal call and specific fields were requested.
|
||||
q = q.all_fields().only(*only)
|
||||
|
||||
if q.count() != len(ids):
|
||||
raise errors.bad_request.InvalidTaskId(ids=task_ids)
|
||||
return res
|
||||
|
||||
if return_tasks:
|
||||
return list(q)
|
||||
|
||||
@staticmethod
|
||||
def create(call: APICall, fields: dict):
|
||||
@@ -149,6 +146,7 @@ class TaskBLL(object):
|
||||
company=identity.company,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
**fields,
|
||||
)
|
||||
|
||||
@@ -170,57 +168,103 @@ class TaskBLL(object):
|
||||
@classmethod
|
||||
def clone_task(
|
||||
cls,
|
||||
company_id,
|
||||
user_id,
|
||||
task_id,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
task_id: str,
|
||||
name: Optional[str] = None,
|
||||
comment: Optional[str] = None,
|
||||
parent: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
hyperparams: Optional[dict] = None,
|
||||
configuration: Optional[dict] = None,
|
||||
execution_overrides: Optional[dict] = None,
|
||||
validate_references: bool = False,
|
||||
) -> Task:
|
||||
new_project_name: str = None,
|
||||
) -> Tuple[Task, dict]:
|
||||
validate_tags(tags, system_tags)
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id)
|
||||
params_dict = {
|
||||
field: value
|
||||
for field, value in (
|
||||
("hyperparams", hyperparams),
|
||||
("configuration", configuration),
|
||||
)
|
||||
if value is not None
|
||||
}
|
||||
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
execution_model_overriden = False
|
||||
if execution_overrides:
|
||||
parameters = execution_overrides.get("parameters")
|
||||
if parameters is not None:
|
||||
execution_overrides["parameters"] = {
|
||||
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
|
||||
}
|
||||
execution_dict = deep_merge(execution_dict, execution_overrides)
|
||||
execution_model_overriden = execution_overrides.get("model") is not None
|
||||
artifacts_prepare_for_save({"execution": execution_overrides})
|
||||
|
||||
params_dict["execution"] = {}
|
||||
for legacy_param in ("parameters", "configuration"):
|
||||
legacy_value = execution_overrides.pop(legacy_param, None)
|
||||
if legacy_value is not None:
|
||||
params_dict["execution"] = legacy_value
|
||||
|
||||
execution_dict.update(execution_overrides)
|
||||
|
||||
params_prepare_for_save(params_dict, previous_task=task)
|
||||
|
||||
artifacts = execution_dict.get("artifacts")
|
||||
if artifacts:
|
||||
execution_dict["artifacts"] = [
|
||||
a for a in artifacts if a.get("mode") != ArtifactModes.output
|
||||
]
|
||||
execution_dict["artifacts"] = {
|
||||
k: a
|
||||
for k, a in artifacts.items()
|
||||
if a.get("mode") != ArtifactModes.output
|
||||
}
|
||||
execution_dict.pop("queue", None)
|
||||
|
||||
new_project_data = None
|
||||
if not project and new_project_name:
|
||||
# Use a project with the provided name, or create a new project
|
||||
project = ProjectBLL.find_or_create(
|
||||
project_name=new_project_name,
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
description="Auto-generated while cloning",
|
||||
)
|
||||
new_project_data = {"id": project, "name": new_project_name}
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
with translate_errors_context():
|
||||
def clean_system_tags(input_tags: Sequence[str]) -> Sequence[str]:
|
||||
if not input_tags:
|
||||
return input_tags
|
||||
|
||||
return [
|
||||
tag
|
||||
for tag in input_tags
|
||||
if tag not in [TaskSystemTags.development, EntityVisibility.archived.value]
|
||||
]
|
||||
|
||||
with TimingContext("mongo", "clone task"):
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or task.parent,
|
||||
project=project or task.project,
|
||||
tags=tags or task.tags,
|
||||
system_tags=system_tags or [],
|
||||
system_tags=system_tags or clean_system_tags(task.system_tags),
|
||||
type=task.type,
|
||||
script=task.script,
|
||||
output=Output(destination=task.output.destination)
|
||||
if task.output
|
||||
else None,
|
||||
execution=execution_dict,
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
@@ -243,8 +287,9 @@ class TaskBLL(object):
|
||||
tags=updated_tags,
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
update_project_time(new_task.project)
|
||||
|
||||
return new_task
|
||||
return new_task, new_project_data
|
||||
|
||||
@classmethod
|
||||
def validate(
|
||||
@@ -254,6 +299,11 @@ class TaskBLL(object):
|
||||
validate_parent=True,
|
||||
validate_project=True,
|
||||
):
|
||||
"""
|
||||
Validate task properties according to the flag
|
||||
Task project is always checked for being writable
|
||||
in order to disable the modification of public projects
|
||||
"""
|
||||
if (
|
||||
validate_parent
|
||||
and task.parent
|
||||
@@ -263,12 +313,10 @@ class TaskBLL(object):
|
||||
):
|
||||
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
|
||||
|
||||
if (
|
||||
validate_project
|
||||
and task.project
|
||||
and not Project.get_for_writing(company=task.company, id=task.project)
|
||||
):
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
if task.project:
|
||||
project = Project.get_for_writing(company=task.company, id=task.project)
|
||||
if validate_project and not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
|
||||
if validate_model:
|
||||
cls.validate_execution_model(task)
|
||||
@@ -278,7 +326,7 @@ class TaskBLL(object):
|
||||
pipeline = [
|
||||
{
|
||||
"$match": dict(
|
||||
company=company_id,
|
||||
company={"$in": [None, "", company_id]},
|
||||
**({"project": {"$in": project_ids}} if project_ids else {}),
|
||||
)
|
||||
},
|
||||
@@ -316,11 +364,29 @@ class TaskBLL(object):
|
||||
|
||||
@staticmethod
|
||||
def set_last_update(
|
||||
task_ids: Collection[str], company_id: str, last_update: datetime
|
||||
task_ids: Collection[str],
|
||||
company_id: str,
|
||||
last_update: datetime,
|
||||
**extra_updates,
|
||||
):
|
||||
return Task.objects(id__in=task_ids, company=company_id).update(
|
||||
upsert=False, last_update=last_update
|
||||
tasks = Task.objects(id__in=task_ids, company=company_id).only(
|
||||
"status", "started"
|
||||
)
|
||||
for task in tasks:
|
||||
updates = extra_updates
|
||||
if task.status == TaskStatus.in_progress and task.started:
|
||||
updates = {
|
||||
"active_duration": (
|
||||
datetime.utcnow() - task.started
|
||||
).total_seconds(),
|
||||
**extra_updates,
|
||||
}
|
||||
Task.objects(id=task.id, company=company_id).update(
|
||||
upsert=False,
|
||||
last_update=last_update,
|
||||
last_change=last_update,
|
||||
**updates,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_statistics(
|
||||
@@ -383,8 +449,11 @@ class TaskBLL(object):
|
||||
}
|
||||
extra_updates["metric_stats"] = metric_stats
|
||||
|
||||
Task.objects(id=task_id, company=company_id).update(
|
||||
upsert=False, last_update=last_update, **extra_updates
|
||||
TaskBLL.set_last_update(
|
||||
task_ids=[task_id],
|
||||
company_id=company_id,
|
||||
last_update=last_update,
|
||||
**extra_updates,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -534,120 +603,35 @@ class TaskBLL(object):
|
||||
force=force,
|
||||
).execute()
|
||||
|
||||
@classmethod
|
||||
def add_or_update_artifacts(
|
||||
cls, task_id: str, company_id: str, artifacts: List[ApiArtifact]
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
key = attrgetter("key", "mode")
|
||||
|
||||
if not artifacts:
|
||||
return [], []
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "update_artifacts"):
|
||||
artifacts: List[Artifact] = [
|
||||
Artifact(**artifact.to_struct()) for artifact in artifacts
|
||||
]
|
||||
|
||||
attempts = int(config.get("services.tasks.artifacts.update_attempts", 10))
|
||||
|
||||
for retry in range(attempts):
|
||||
task = cls.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
)
|
||||
|
||||
current = list(map(key, task.execution.artifacts))
|
||||
updated = [a for a in artifacts if key(a) in current]
|
||||
added = [a for a in artifacts if a not in updated]
|
||||
|
||||
filter = {"_id": task_id, "company": company_id}
|
||||
update = {}
|
||||
array_filters = None
|
||||
if current:
|
||||
filter["execution.artifacts"] = {
|
||||
"$size": len(current),
|
||||
"$all": [
|
||||
*(
|
||||
{"$elemMatch": {"key": key, "mode": mode}}
|
||||
for key, mode in current
|
||||
)
|
||||
],
|
||||
}
|
||||
else:
|
||||
filter["$or"] = [
|
||||
{"execution.artifacts": {"$exists": False}},
|
||||
{"execution.artifacts": {"$size": 0}},
|
||||
]
|
||||
|
||||
if added:
|
||||
update["$push"] = {
|
||||
"execution.artifacts": {"$each": [a.to_mongo() for a in added]}
|
||||
}
|
||||
if updated:
|
||||
update["$set"] = {
|
||||
f"execution.artifacts.$[artifact{index}]": artifact.to_mongo()
|
||||
for index, artifact in enumerate(updated)
|
||||
}
|
||||
array_filters = [
|
||||
{
|
||||
f"artifact{index}.key": artifact.key,
|
||||
f"artifact{index}.mode": artifact.mode,
|
||||
}
|
||||
for index, artifact in enumerate(updated)
|
||||
]
|
||||
|
||||
if not update:
|
||||
return [], []
|
||||
|
||||
result: pymongo.results.UpdateResult = Task._get_collection().update_one(
|
||||
filter=filter,
|
||||
update=update,
|
||||
array_filters=array_filters,
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
if result.matched_count >= 1:
|
||||
break
|
||||
|
||||
wait_msec = random() * int(
|
||||
config.get("services.tasks.artifacts.update_retry_msec", 500)
|
||||
)
|
||||
|
||||
log.warning(
|
||||
f"Failed to update artifacts for task {task_id} (updated by another party),"
|
||||
f" retrying {retry+1}/{attempts} in {wait_msec}ms"
|
||||
)
|
||||
|
||||
sleep(wait_msec / 1000)
|
||||
else:
|
||||
raise errors.server_error.UpdateFailed(
|
||||
"task artifacts updated by another party"
|
||||
)
|
||||
|
||||
return [a.key for a in added], [a.key for a in updated]
|
||||
|
||||
@staticmethod
|
||||
def get_aggregated_project_execution_parameters(
|
||||
def get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids: Sequence[str] = None,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[str]]:
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": company_id,
|
||||
"execution.parameters": {"$exists": True, "$gt": {}},
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
**({"project": {"$in": project_ids}} if project_ids else {}),
|
||||
}
|
||||
},
|
||||
{"$project": {"parameters": {"$objectToArray": "$execution.parameters"}}},
|
||||
{"$unwind": "$parameters"},
|
||||
{"$group": {"_id": "$parameters.k"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
@@ -673,9 +657,59 @@ class TaskBLL(object):
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
ParameterKeyEscaper.unescape(r["_id"])
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
dpath.get(r, "_id/section")
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
@classmethod
|
||||
def dequeue_and_change_status(
|
||||
cls, task: Task, company_id: str, status_message: str, status_reason: str,
|
||||
):
|
||||
cls.dequeue(task, company_id)
|
||||
|
||||
return ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.created,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
).execute(unset__execution__queue=1)
|
||||
|
||||
@classmethod
|
||||
def dequeue(cls, task: Task, company_id: str, silent_fail=False):
|
||||
"""
|
||||
Dequeue the task from the queue
|
||||
:param task: task to dequeue
|
||||
:param company_id: task's company ID.
|
||||
:param silent_fail: do not throw exceptions. APIError is still thrown
|
||||
:raise errors.bad_request.InvalidTaskId: if the task's status is not queued
|
||||
:raise errors.bad_request.MissingRequiredFields: if the task is not queued
|
||||
:raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails
|
||||
:return: the result of queues.remove_task call. None in case of silent failure
|
||||
"""
|
||||
if task.status not in (TaskStatus.queued,):
|
||||
if silent_fail:
|
||||
return
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
status=task.status, expected=TaskStatus.queued
|
||||
)
|
||||
|
||||
if not task.execution or not task.execution.queue:
|
||||
if silent_fail:
|
||||
return
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"task has no queue value", field="execution.queue"
|
||||
)
|
||||
|
||||
return {
|
||||
"removed": queue_bll.remove_task(
|
||||
company_id=company_id, queue_id=task.execution.queue, task_id=task.id
|
||||
)
|
||||
}
|
||||
@@ -1,17 +1,16 @@
|
||||
from datetime import datetime
|
||||
from typing import TypeVar, Callable, Tuple, Sequence
|
||||
from typing import TypeVar, Callable, Tuple, Sequence, Union
|
||||
|
||||
import attr
|
||||
import six
|
||||
from boltons.dictutils import OneToOne
|
||||
|
||||
from apierrors import errors
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.project import Project
|
||||
from database.model.task.task import Task, TaskStatus, TaskSystemTags
|
||||
from database.utils import get_options
|
||||
from timing_context import TimingContext
|
||||
from utilities.attrs import typed_attrs
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.attrs import typed_attrs
|
||||
|
||||
valid_statuses = get_options(TaskStatus)
|
||||
|
||||
@@ -44,6 +43,7 @@ class ChangeStatusRequest(object):
|
||||
status_message=self.status_message,
|
||||
status_changed=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
)
|
||||
|
||||
if self.new_status == TaskStatus.queued:
|
||||
@@ -153,9 +153,14 @@ def get_possible_status_changes(current_status):
|
||||
return possible
|
||||
|
||||
|
||||
def update_project_time(project_id):
|
||||
if project_id:
|
||||
Project.objects(id=project_id).update(last_update=datetime.utcnow())
|
||||
def update_project_time(project_ids: Union[str, Sequence[str]]):
|
||||
if not project_ids:
|
||||
return
|
||||
|
||||
if isinstance(project_ids, str):
|
||||
project_ids = [project_ids]
|
||||
|
||||
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -174,24 +179,32 @@ def split_by(
|
||||
)
|
||||
|
||||
|
||||
class ParameterKeyEscaper:
|
||||
_mapping = OneToOne({".": "%2E", "$": "%24"})
|
||||
def get_task_for_update(
|
||||
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
|
||||
) -> Task:
|
||||
"""
|
||||
Loads only task id and return the task only if it is updatable (status == 'created')
|
||||
"""
|
||||
task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
|
||||
@classmethod
|
||||
def escape(cls, value):
|
||||
""" Quote a parameter key """
|
||||
value = value.strip().replace("%", "%%")
|
||||
for c, r in cls._mapping.items():
|
||||
value = value.replace(c, r)
|
||||
return value
|
||||
if allow_all_statuses:
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
def _unescape(cls, value):
|
||||
for c, r in cls._mapping.inv.items():
|
||||
value = value.replace(c, r)
|
||||
return value
|
||||
allowed_statuses = (
|
||||
[TaskStatus.created, TaskStatus.in_progress] if force else [TaskStatus.created]
|
||||
)
|
||||
if task.status not in allowed_statuses:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
expected=TaskStatus.created, status=task.status
|
||||
)
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
def unescape(cls, value):
|
||||
""" Unquote a quoted parameter key """
|
||||
return "%".join(map(cls._unescape, value.split("%%")))
|
||||
|
||||
def update_task(task: Task, update_cmds: dict, set_last_update: bool = True):
|
||||
now = datetime.utcnow()
|
||||
last_updates = dict(last_change=now)
|
||||
if set_last_update:
|
||||
last_updates.update(last_update=now)
|
||||
return task.update(**update_cmds, **last_updates)
|
||||
@@ -1,7 +1,7 @@
|
||||
from apierrors import errors
|
||||
from apimodels.users import CreateRequest
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.user import User
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.users import CreateRequest
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.user import User
|
||||
|
||||
|
||||
class UserBLL:
|
||||
@@ -1,9 +1,13 @@
|
||||
import functools
|
||||
import itertools
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set
|
||||
from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set, Iterable
|
||||
|
||||
from database.model import AttributedDocument
|
||||
from database.model.settings import Settings
|
||||
from boltons import iterutils
|
||||
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.settings import Settings
|
||||
|
||||
|
||||
def extract_properties_to_lists(
|
||||
@@ -35,14 +39,21 @@ class SetFieldsResolver:
|
||||
SET_MODIFIERS = ("min", "max")
|
||||
|
||||
def __init__(self, set_fields: Dict[str, Any]):
|
||||
self.orig_fields = set_fields
|
||||
self.fields = {
|
||||
f: fname
|
||||
for f, modifier, dunder, fname in (
|
||||
(f,) + f.partition("__") for f in set_fields.keys()
|
||||
)
|
||||
if dunder and modifier in self.SET_MODIFIERS
|
||||
}
|
||||
self.orig_fields = {}
|
||||
self.fields = {}
|
||||
self.add_fields(**set_fields)
|
||||
|
||||
def add_fields(self, **set_fields: Any):
|
||||
self.orig_fields.update(set_fields)
|
||||
self.fields.update(
|
||||
{
|
||||
f: fname
|
||||
for f, modifier, dunder, fname in (
|
||||
(f,) + f.partition("__") for f in set_fields.keys()
|
||||
)
|
||||
if dunder and modifier in self.SET_MODIFIERS
|
||||
}
|
||||
)
|
||||
|
||||
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
|
||||
if name in self.fields and doc.get_field_value(self.fields[name]) is None:
|
||||
@@ -71,3 +82,36 @@ class SetFieldsResolver:
|
||||
@functools.lru_cache()
|
||||
def get_server_uuid() -> Optional[str]:
|
||||
return Settings.get_by_key("server.uuid")
|
||||
|
||||
|
||||
def parallel_chunked_decorator(func: Callable = None, chunk_size: int = 100):
|
||||
"""
|
||||
Decorates a method for parallel chunked execution. The method should have
|
||||
one positional parameter (that is used for breaking into chunks)
|
||||
and arbitrary number of keyword params. The return value should be iterable
|
||||
The results are concatenated in the same order as the passed params
|
||||
"""
|
||||
if func is None:
|
||||
return functools.partial(parallel_chunked_decorator, chunk_size=chunk_size)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, iterable: Iterable, **kwargs):
|
||||
assert iterutils.is_collection(
|
||||
iterable
|
||||
), "The positional parameter should be an iterable for breaking into chunks"
|
||||
|
||||
func_with_params = functools.partial(func, self, **kwargs)
|
||||
with ThreadPoolExecutor() as pool:
|
||||
return list(
|
||||
itertools.chain.from_iterable(
|
||||
filter(
|
||||
None,
|
||||
pool.map(
|
||||
func_with_params,
|
||||
iterutils.chunked_iter(iterable, chunk_size),
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return wrapper
|
||||
@@ -5,10 +5,10 @@ from typing import Sequence, Set, Optional
|
||||
import attr
|
||||
import elasticsearch.helpers
|
||||
|
||||
import es_factory
|
||||
from apierrors import APIError
|
||||
from apierrors.errors import bad_request, server_error
|
||||
from apimodels.workers import (
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import APIError
|
||||
from apiserver.apierrors.errors import bad_request, server_error
|
||||
from apiserver.apimodels.workers import (
|
||||
DEFAULT_TIMEOUT,
|
||||
IdNameEntry,
|
||||
WorkerEntry,
|
||||
@@ -17,15 +17,16 @@ from apimodels.workers import (
|
||||
QueueEntry,
|
||||
MachineStats,
|
||||
)
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.auth import User
|
||||
from database.model.company import Company
|
||||
from database.model.queue import Queue
|
||||
from database.model.task.task import Task
|
||||
from redis_manager import redman
|
||||
from timing_context import TimingContext
|
||||
from tools import safe_get
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.auth import User
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
from .stats import WorkerStats
|
||||
|
||||
log = config.logger(__file__)
|
||||
@@ -49,6 +50,7 @@ class WorkerBLL:
|
||||
ip: str = "",
|
||||
queues: Sequence[str] = None,
|
||||
timeout: int = 0,
|
||||
tags: Sequence[str] = None,
|
||||
) -> WorkerEntry:
|
||||
"""
|
||||
Register a worker
|
||||
@@ -58,6 +60,7 @@ class WorkerBLL:
|
||||
:param ip: the real ip of the worker
|
||||
:param queues: queues reported as being monitored by the worker
|
||||
:param timeout: registration expiration timeout in seconds
|
||||
:param tags: a list of tags for this worker
|
||||
:raise bad_request.InvalidUserId: in case the calling user or company does not exist
|
||||
:return: worker entry instance
|
||||
"""
|
||||
@@ -91,6 +94,7 @@ class WorkerBLL:
|
||||
register_time=now,
|
||||
register_timeout=timeout,
|
||||
last_activity_time=now,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
|
||||
@@ -113,12 +117,15 @@ class WorkerBLL:
|
||||
raise bad_request.WorkerNotRegistered(worker=worker)
|
||||
|
||||
def status_report(
|
||||
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest
|
||||
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Write worker status report
|
||||
:param company_id: worker's company ID
|
||||
:param user_id: user_id ID under which this worker is running
|
||||
:param ip: worker IP
|
||||
:param report: the report itself
|
||||
:param tags: tags for this worker
|
||||
:raise bad_request.InvalidTaskId: the reported task was not found
|
||||
:return: worker entry instance
|
||||
"""
|
||||
@@ -129,6 +136,9 @@ class WorkerBLL:
|
||||
now = datetime.utcnow()
|
||||
entry.last_activity_time = now
|
||||
|
||||
if tags is not None:
|
||||
entry.tags = tags
|
||||
|
||||
if report.machine_stats:
|
||||
self._log_stats_to_es(
|
||||
company_id=company_id,
|
||||
@@ -146,6 +156,7 @@ class WorkerBLL:
|
||||
|
||||
if not report.task:
|
||||
entry.task = None
|
||||
entry.project = None
|
||||
else:
|
||||
with translate_errors_context():
|
||||
query = dict(id=report.task, company=company_id)
|
||||
@@ -153,6 +164,7 @@ class WorkerBLL:
|
||||
last_worker=report.worker,
|
||||
last_worker_report=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
)
|
||||
# modify(new=True, ...) returns the modified object
|
||||
task = Task.objects(**query).modify(new=True, **update)
|
||||
@@ -160,6 +172,12 @@ class WorkerBLL:
|
||||
raise bad_request.InvalidTaskId(**query)
|
||||
entry.task = IdNameEntry(id=task.id, name=task.name)
|
||||
|
||||
entry.project = None
|
||||
if task.project:
|
||||
project = Project.objects(id=task.project).only("name").first()
|
||||
if project:
|
||||
entry.project = IdNameEntry(id=project.id, name=project.name)
|
||||
|
||||
entry.last_report_time = now
|
||||
except APIError:
|
||||
raise
|
||||
@@ -369,7 +387,6 @@ class WorkerBLL:
|
||||
def make_doc(category, metric, variant, value) -> dict:
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_type="stat",
|
||||
_source=dict(
|
||||
timestamp=timestamp,
|
||||
worker=worker,
|
||||
@@ -3,12 +3,12 @@ from typing import Optional, Sequence
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
|
||||
from apierrors.errors import bad_request
|
||||
from apimodels.workers import AggregationType, GetStatsRequest, StatItem
|
||||
from bll.query import Builder as QueryBuilder
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from timing_context import TimingContext
|
||||
from apiserver.apierrors.errors import bad_request
|
||||
from apiserver.apimodels.workers import AggregationType, GetStatsRequest, StatItem
|
||||
from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -25,7 +25,6 @@ class WorkerStats:
|
||||
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self.worker_stats_prefix_for_company(company_id)}*",
|
||||
doc_type="stat",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@@ -53,7 +52,7 @@ class WorkerStats:
|
||||
|
||||
res = self._search_company_stats(company_id, es_req)
|
||||
|
||||
if not res["hits"]["total"]:
|
||||
if not res["hits"]["total"]["value"]:
|
||||
raise bad_request.WorkerStatsNotFound(
|
||||
f"No statistic metrics found for the company {company_id} and workers {worker_ids}"
|
||||
)
|
||||
@@ -87,7 +86,7 @@ class WorkerStats:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{request.interval}s",
|
||||
"fixed_interval": f"{request.interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
@@ -216,7 +215,7 @@ class WorkerStats:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{interval}s",
|
||||
"fixed_interval": f"{interval}s",
|
||||
},
|
||||
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
|
||||
}
|
||||
1
apiserver/config/__init__.py
Normal file
1
apiserver/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .basic import BasicConfig, ConfigurationError
|
||||
@@ -1,10 +1,12 @@
|
||||
import logging
|
||||
import logging.config
|
||||
import os
|
||||
import platform
|
||||
from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
from pathlib import Path
|
||||
from typing import List, Any, TypeVar
|
||||
|
||||
from pyhocon import ConfigTree, ConfigFactory
|
||||
from pyparsing import (
|
||||
@@ -14,82 +16,104 @@ from pyparsing import (
|
||||
ParseSyntaxException,
|
||||
)
|
||||
|
||||
DEFAULT_EXTRA_CONFIG_PATH = "/opt/trains/config"
|
||||
EXTRA_CONFIG_PATH_ENV_KEY = "TRAINS_CONFIG_DIR"
|
||||
EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ';'
|
||||
from apiserver.utilities import json
|
||||
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_SEP = "__"
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX = f"TRAINS{EXTRA_CONFIG_VALUES_ENV_KEY_SEP}"
|
||||
EXTRA_CONFIG_PATHS = ("/opt/trains/config",)
|
||||
EXTRA_CONFIG_PATH_OVERRIDE_VAR = "TRAINS_CONFIG_DIR"
|
||||
EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ";"
|
||||
|
||||
|
||||
class BasicConfig:
|
||||
NotSet = object()
|
||||
|
||||
def __init__(self, folder):
|
||||
self.folder = Path(folder)
|
||||
if not self.folder.is_dir():
|
||||
extra_config_values_env_key_sep = "__"
|
||||
default_config_dir = "default"
|
||||
|
||||
def __init__(
|
||||
self, folder: str = None, verbose: bool = True, prefix: str = "trains"
|
||||
):
|
||||
folder = (
|
||||
Path(folder)
|
||||
if folder
|
||||
else Path(__file__).with_name(self.default_config_dir)
|
||||
)
|
||||
if not folder.is_dir():
|
||||
raise ValueError("Invalid configuration folder")
|
||||
|
||||
self.prefix = "trains"
|
||||
self.verbose = verbose
|
||||
self.prefix = prefix
|
||||
self.extra_config_values_env_key_prefix = f"{self.prefix.upper()}__"
|
||||
|
||||
self._load()
|
||||
self._paths = [folder, *self._get_paths()]
|
||||
self._config = self._reload()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._config[key]
|
||||
|
||||
def get(self, key, default=NotSet):
|
||||
def get(self, key: str, default: Any = NotSet) -> Any:
|
||||
value = self._config.get(key, default)
|
||||
if value is self.NotSet and not default:
|
||||
if value is self.NotSet:
|
||||
raise KeyError(
|
||||
f"Unable to find value for key '{key}' and default value was not provided."
|
||||
)
|
||||
return value
|
||||
|
||||
def logger(self, name):
|
||||
def to_dict(self) -> dict:
|
||||
return self._config.as_plain_ordered_dict()
|
||||
|
||||
def as_json(self) -> str:
|
||||
return json.dumps(self.to_dict(), indent=2)
|
||||
|
||||
def logger(self, name: str) -> logging.Logger:
|
||||
if Path(name).is_file():
|
||||
name = Path(name).stem
|
||||
path = ".".join((self.prefix, name))
|
||||
return logging.getLogger(path)
|
||||
|
||||
def _read_extra_env_config_values(self):
|
||||
def _read_extra_env_config_values(self) -> ConfigTree:
|
||||
""" Loads extra configuration from environment-injected values """
|
||||
result = ConfigTree()
|
||||
prefix = EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX
|
||||
prefix = self.extra_config_values_env_key_prefix
|
||||
|
||||
keys = sorted(k for k in os.environ if k.startswith(prefix))
|
||||
for key in keys:
|
||||
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".").lower()
|
||||
path = (
|
||||
key[len(prefix) :]
|
||||
.replace(self.extra_config_values_env_key_sep, ".")
|
||||
.lower()
|
||||
)
|
||||
result = ConfigTree.merge_configs(
|
||||
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _read_env_paths(self, key):
|
||||
value = getenv(EXTRA_CONFIG_PATH_ENV_KEY, DEFAULT_EXTRA_CONFIG_PATH)
|
||||
if value is None:
|
||||
return
|
||||
def _get_paths(self) -> List[Path]:
|
||||
default_paths = EXTRA_CONFIG_PATH_SEP.join(EXTRA_CONFIG_PATHS)
|
||||
value = getenv(EXTRA_CONFIG_PATH_OVERRIDE_VAR, default_paths)
|
||||
|
||||
paths = [
|
||||
Path(expandvars(v)).expanduser() for v in value.split(EXTRA_CONFIG_PATH_SEP)
|
||||
]
|
||||
invalid = [
|
||||
path
|
||||
for path in paths
|
||||
if not path.is_dir() and str(path) != DEFAULT_EXTRA_CONFIG_PATH
|
||||
]
|
||||
if invalid:
|
||||
print(f"WARNING: Invalid paths in {key} env var: {' '.join(map(str, invalid))}")
|
||||
|
||||
if value is not default_paths:
|
||||
invalid = [path for path in paths if not path.is_dir()]
|
||||
if invalid:
|
||||
print(
|
||||
f"WARNING: Invalid paths in {EXTRA_CONFIG_PATH_OVERRIDE_VAR} env var: {' '.join(map(str, invalid))}"
|
||||
)
|
||||
|
||||
return [path for path in paths if path.is_dir()]
|
||||
|
||||
def _load(self, verbose=True):
|
||||
extra_config_paths = self._read_env_paths(EXTRA_CONFIG_PATH_ENV_KEY) or []
|
||||
extra_config_values = self._read_extra_env_config_values()
|
||||
configs = [
|
||||
self._read_recursive(path, verbose=verbose)
|
||||
for path in [self.folder] + extra_config_paths
|
||||
]
|
||||
def reload(self):
|
||||
self._config = self._reload()
|
||||
|
||||
self._config = reduce(
|
||||
def _reload(self) -> ConfigTree:
|
||||
extra_config_values = self._read_extra_env_config_values()
|
||||
|
||||
configs = [self._read_recursive(path) for path in self._paths]
|
||||
|
||||
return reduce(
|
||||
lambda last, config: ConfigTree.merge_configs(
|
||||
last, config, copy_trees=True
|
||||
),
|
||||
@@ -97,32 +121,31 @@ class BasicConfig:
|
||||
ConfigTree(),
|
||||
)
|
||||
|
||||
def _read_recursive(self, conf_root, verbose=True):
|
||||
def _read_recursive(self, conf_root) -> ConfigTree:
|
||||
conf = ConfigTree()
|
||||
|
||||
if not conf_root:
|
||||
return conf
|
||||
|
||||
if not conf_root.is_dir():
|
||||
if verbose:
|
||||
if self.verbose:
|
||||
if not conf_root.exists():
|
||||
print(f"No config in {conf_root}")
|
||||
else:
|
||||
print(f"Not a directory: {conf_root}")
|
||||
return conf
|
||||
|
||||
if verbose:
|
||||
if self.verbose:
|
||||
print(f"Loading config from {conf_root}")
|
||||
|
||||
for file in conf_root.rglob("*.conf"):
|
||||
key = ".".join(file.relative_to(conf_root).with_suffix("").parts)
|
||||
conf.put(key, self._read_single_file(file, verbose=verbose))
|
||||
conf.put(key, self._read_single_file(file))
|
||||
|
||||
return conf
|
||||
|
||||
@staticmethod
|
||||
def _read_single_file(file_path, verbose=True):
|
||||
if verbose:
|
||||
def _read_single_file(self, file_path):
|
||||
if self.verbose:
|
||||
print(f"Loading config from file {file_path}")
|
||||
|
||||
try:
|
||||
@@ -137,8 +160,17 @@ class BasicConfig:
|
||||
print(f"Failed loading {file_path}: {ex}")
|
||||
raise
|
||||
|
||||
def initialize_logging(self):
|
||||
logging_config = self.get("logging", None)
|
||||
if not logging_config:
|
||||
return
|
||||
logging.config.dictConfig(logging_config)
|
||||
|
||||
|
||||
class ConfigurationError(Exception):
|
||||
def __init__(self, msg, file_path=None, *args):
|
||||
super(ConfigurationError, self).__init__(msg, *args)
|
||||
super().__init__(msg, *args)
|
||||
self.file_path = file_path
|
||||
|
||||
|
||||
ConfigType = TypeVar("ConfigType", bound=BasicConfig)
|
||||
@@ -26,6 +26,17 @@
|
||||
check_max_version: false
|
||||
}
|
||||
|
||||
pre_populate {
|
||||
enabled: false
|
||||
zip_files: ["/path/to/export.zip"]
|
||||
fail_on_error: false
|
||||
# artifacts_path: "/mnt/fileserver"
|
||||
}
|
||||
|
||||
# time in seconds to take an exclusive lock to init es and mongodb
|
||||
# not including the pre_populate
|
||||
db_init_timout: 120
|
||||
|
||||
mongo {
|
||||
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
|
||||
# but not declared in a data model
|
||||
@@ -34,11 +45,16 @@
|
||||
aggregate {
|
||||
allow_disk_use: true
|
||||
}
|
||||
}
|
||||
|
||||
pre_populate {
|
||||
enabled: false
|
||||
zip_file: "/path/to/export.zip"
|
||||
fail_on_error: false
|
||||
elastic {
|
||||
probing {
|
||||
# settings for inital probing of elastic connection
|
||||
max_retries: 4
|
||||
timeout: 30
|
||||
}
|
||||
upgrade_monitoring {
|
||||
v16_migration_verification: true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ elastic {
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
max_retries: 5
|
||||
max_retries: 3
|
||||
retry_on_timeout: true
|
||||
}
|
||||
index_version: "1"
|
||||
@@ -15,7 +15,7 @@ elastic {
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
max_retries: 5
|
||||
max_retries: 3
|
||||
retry_on_timeout: true
|
||||
}
|
||||
index_version: "1"
|
||||
16
apiserver/config/default/services/auth.conf
Normal file
16
apiserver/config/default/services/auth.conf
Normal file
@@ -0,0 +1,16 @@
|
||||
fixed_users {
|
||||
guest {
|
||||
enabled: false
|
||||
|
||||
default_company: "025315a9321f49f8be07f5ac48fbcf92"
|
||||
|
||||
name: "Guest"
|
||||
username: "guest"
|
||||
password: "guest"
|
||||
|
||||
# Allow access only to the following endpoints when using user/pass credentials
|
||||
allow_endpoints: [
|
||||
"auth.login"
|
||||
]
|
||||
}
|
||||
}
|
||||
27
apiserver/config/default/services/events.conf
Normal file
27
apiserver/config/default/services/events.conf
Normal file
@@ -0,0 +1,27 @@
|
||||
es_index_prefix: "events"
|
||||
|
||||
ignore_iteration {
|
||||
metrics: [":monitor:machine", ":monitor:gpu"]
|
||||
}
|
||||
|
||||
|
||||
events_retrieval {
|
||||
state_expiration_sec: 3600
|
||||
|
||||
# max number of concurrent queries to ES when calculating events metrics
|
||||
# should not exceed the amount of concurrent connections set in the ES driver
|
||||
max_metrics_concurrency: 4
|
||||
|
||||
# the max amount of metrics to aggregate on
|
||||
max_metrics_count: 100
|
||||
|
||||
# the max amount of variants to aggregate on
|
||||
max_variants_count: 100
|
||||
}
|
||||
|
||||
# if set then plot str will be checked for the valid json on plot add
|
||||
# and the result of the check is written to the db
|
||||
validate_plot_str: false
|
||||
|
||||
# If not 0 then the plots equal or greater to the size will be stored compressed in the DB
|
||||
plot_compression_threshold: 100000
|
||||
13
apiserver/config/default/services/projects.conf
Normal file
13
apiserver/config/default/services/projects.conf
Normal file
@@ -0,0 +1,13 @@
|
||||
# Order of featured projects, by name or ID
|
||||
featured {
|
||||
order: [
|
||||
# {id: "<project-id>"}
|
||||
# OR
|
||||
# {name: "<project-name>"}
|
||||
# OR
|
||||
# {name_regex: "<python-regex>"}
|
||||
]
|
||||
|
||||
# default featured index for public projects not specified in the order
|
||||
public_default: 9999
|
||||
}
|
||||
@@ -8,7 +8,4 @@ non_responsive_tasks_watchdog {
|
||||
watch_interval_sec: 900
|
||||
}
|
||||
|
||||
artifacts {
|
||||
update_attempts: 10
|
||||
update_retry_msec: 500
|
||||
}
|
||||
multi_task_histogram_limit: 100
|
||||
@@ -1,9 +1,9 @@
|
||||
from functools import lru_cache
|
||||
from os import getenv
|
||||
from pathlib import Path
|
||||
from version import __version__
|
||||
|
||||
from config import config
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.version import __version__
|
||||
|
||||
root = Path(__file__).parent.parent
|
||||
|
||||
@@ -41,3 +41,7 @@ def get_deployment_type() -> str:
|
||||
|
||||
def get_default_company():
|
||||
return config.get("apiserver.default_company")
|
||||
|
||||
|
||||
missed_es_upgrade = False
|
||||
es_connection_error = False
|
||||
4
apiserver/config_repo.py
Normal file
4
apiserver/config_repo.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from apiserver.config import BasicConfig
|
||||
|
||||
config = BasicConfig()
|
||||
config.initialize_logging()
|
||||
98
apiserver/database/__init__.py
Normal file
98
apiserver/database/__init__.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from os import getenv
|
||||
|
||||
from boltons.iterutils import first
|
||||
from furl import furl
|
||||
from jsonmodels import models
|
||||
from jsonmodels.errors import ValidationError
|
||||
from jsonmodels.fields import StringField
|
||||
from mongoengine import register_connection
|
||||
from mongoengine.connection import get_connection, disconnect
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from .defs import Database
|
||||
from .utils import get_items
|
||||
|
||||
log = config.logger("database")
|
||||
|
||||
strict = config.get("apiserver.mongo.strict", True)
|
||||
|
||||
OVERRIDE_HOST_ENV_KEY = (
|
||||
"TRAINS_MONGODB_SERVICE_HOST",
|
||||
"MONGODB_SERVICE_HOST",
|
||||
"MONGODB_SERVICE_SERVICE_HOST",
|
||||
)
|
||||
OVERRIDE_PORT_ENV_KEY = ("TRAINS_MONGODB_SERVICE_PORT", "MONGODB_SERVICE_PORT")
|
||||
|
||||
|
||||
class DatabaseEntry(models.Base):
|
||||
host = StringField(required=True)
|
||||
alias = StringField()
|
||||
|
||||
|
||||
class DatabaseFactory:
|
||||
_entries = []
|
||||
|
||||
@classmethod
|
||||
def initialize(cls):
|
||||
db_entries = config.get("hosts.mongo", {})
|
||||
missing = []
|
||||
log.info("Initializing database connections")
|
||||
|
||||
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
|
||||
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
|
||||
for key, alias in get_items(Database).items():
|
||||
if key not in db_entries:
|
||||
missing.append(key)
|
||||
continue
|
||||
|
||||
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
|
||||
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
|
||||
try:
|
||||
entry.validate()
|
||||
log.info(
|
||||
"Registering connection to %(alias)s (%(host)s)" % entry.to_struct()
|
||||
)
|
||||
register_connection(alias=alias, host=entry.host)
|
||||
|
||||
cls._entries.append(entry)
|
||||
except ValidationError as ex:
|
||||
raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0]))
|
||||
if missing:
|
||||
raise ValueError("Missing database configuration for %s" % ", ".join(missing))
|
||||
|
||||
@classmethod
|
||||
def get_entries(cls):
|
||||
return cls._entries
|
||||
|
||||
@classmethod
|
||||
def get_hosts(cls):
|
||||
return [entry.host for entry in cls.get_entries()]
|
||||
|
||||
@classmethod
|
||||
def get_aliases(cls):
|
||||
return [entry.alias for entry in cls.get_entries()]
|
||||
|
||||
@classmethod
|
||||
def reconnect(cls):
|
||||
for entry in cls.get_entries():
|
||||
# there is bug in the current implementation that prevents
|
||||
# reconnection from work so workaround this
|
||||
# get_connection(entry.alias, reconnect=True)
|
||||
disconnect(entry.alias)
|
||||
register_connection(alias=entry.alias, host=entry.host)
|
||||
get_connection(entry.alias)
|
||||
|
||||
|
||||
db = DatabaseFactory()
|
||||
@@ -1,6 +1,7 @@
|
||||
import re
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from textwrap import shorten
|
||||
|
||||
import dpath
|
||||
from dpath.exceptions import InvalidKeyName
|
||||
@@ -17,7 +18,7 @@ from mongoengine.errors import (
|
||||
)
|
||||
from pymongo.errors import PyMongoError, NotMasterError
|
||||
|
||||
from apierrors import errors
|
||||
from apiserver.apierrors import errors
|
||||
|
||||
|
||||
class MakeGetAllQueryError(Exception):
|
||||
@@ -33,7 +34,7 @@ class ParseCallError(Exception):
|
||||
self.params = kwargs
|
||||
|
||||
|
||||
def throws_default_error(err_cls):
|
||||
def throws_default_error(err_cls, shorten_width: int = None):
|
||||
"""
|
||||
Used to make functions (Exception, str) -> Optional[str] searching for specialized error messages raise those
|
||||
messages in ``err_cls``. If the decorated function does not find a suitable error message,
|
||||
@@ -45,25 +46,49 @@ def throws_default_error(err_cls):
|
||||
@wraps(func)
|
||||
def wrapper(self, e, message, **kwargs):
|
||||
extra_info = func(self, e, message, **kwargs)
|
||||
raise err_cls(message, err=e, extra_info=extra_info)
|
||||
err = str(e)
|
||||
if shorten_width:
|
||||
err = shorten(err, shorten_width, placeholder="...")
|
||||
raise err_cls(message, err=err, extra_info=extra_info)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# noinspection RegExpRedundantEscape
|
||||
class ElasticErrorsHandler(object):
|
||||
@classmethod
|
||||
@throws_default_error(errors.server_error.DataError)
|
||||
def _bulk_meta_error(cls, error):
|
||||
try:
|
||||
_, err_type = next(dpath.search(error, "*/error/type", yielded=True))
|
||||
_, reason = next(dpath.search(error, "*/error/reason", yielded=True))
|
||||
if err_type == "cluster_block_exception":
|
||||
raise errors.server_error.LowDiskSpace(
|
||||
"metrics, logs and all indexed data is in read-only mode!",
|
||||
reason=re.sub(r"^index\s\[.*?\]\s", "", reason) if reason else ""
|
||||
)
|
||||
return
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@throws_default_error(errors.server_error.DataError, shorten_width=200)
|
||||
def bulk_error(cls, e, _, **__):
|
||||
if not e.errors:
|
||||
return
|
||||
|
||||
# Currently we only handle the first error
|
||||
error = e.errors[0]
|
||||
|
||||
cls._bulk_meta_error(error)
|
||||
|
||||
# Else try returning a better error string
|
||||
for _, reason in dpath.search(e.errors[0], "*/error/reason", yielded=True):
|
||||
return reason
|
||||
|
||||
|
||||
# noinspection RegExpRedundantEscape
|
||||
class MongoEngineErrorsHandler(object):
|
||||
# NotUniqueError
|
||||
__not_unique_regex = re.compile(
|
||||
@@ -81,6 +106,7 @@ class MongoEngineErrorsHandler(object):
|
||||
def validation_error(cls, e: ValidationError, message, **_):
|
||||
# Thrown when a document is validated. Documents are validated by default on save and on update
|
||||
err_dict = e.errors or {e.field_name: e.message}
|
||||
err_dict = {key: str(value) for key, value in err_dict.items()}
|
||||
raise errors.bad_request.DataValidationError(message, **err_dict)
|
||||
|
||||
@classmethod
|
||||
@@ -14,7 +14,7 @@ from mongoengine import (
|
||||
DictField,
|
||||
DynamicField,
|
||||
)
|
||||
from mongoengine.fields import key_not_string, key_starts_with_dollar
|
||||
from mongoengine.fields import key_not_string, key_starts_with_dollar, EmailField
|
||||
|
||||
NoneType = type(None)
|
||||
|
||||
@@ -93,6 +93,24 @@ class CustomFloatField(FloatField):
|
||||
self.error("Float value must be greater than %s" % str(self.greater_than))
|
||||
|
||||
|
||||
class CanonicEmailField(EmailField):
|
||||
"""email field that is always lower cased"""
|
||||
def __set__(self, instance, value: str):
|
||||
if value is not None:
|
||||
try:
|
||||
value = value.lower()
|
||||
except AttributeError:
|
||||
pass
|
||||
super().__set__(instance, value)
|
||||
|
||||
def prepare_query_value(self, op, value):
|
||||
if not isinstance(op, six.string_types):
|
||||
return value
|
||||
if value is not None:
|
||||
value = value.lower()
|
||||
return super().prepare_query_value(op, value)
|
||||
|
||||
|
||||
class StrippedStringField(StringField):
|
||||
def __init__(
|
||||
self, regex=None, max_length=None, min_length=None, strip_chars=None, **kwargs
|
||||
@@ -2,10 +2,10 @@ from enum import Enum
|
||||
|
||||
from mongoengine import Document, StringField
|
||||
|
||||
from apierrors import errors
|
||||
from database.model.base import DbModelMixin, ABSTRACT_FLAG
|
||||
from database.model.company import Company
|
||||
from database.model.user import User
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model.base import DbModelMixin, ABSTRACT_FLAG
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.user import User
|
||||
|
||||
|
||||
class AttributedDocument(DbModelMixin, Document):
|
||||
@@ -6,10 +6,10 @@ from mongoengine import (
|
||||
DateTimeField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import AuthDocument
|
||||
from database.utils import get_options
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.model.base import AuthDocument
|
||||
from apiserver.database.utils import get_options
|
||||
|
||||
|
||||
class Entities(object):
|
||||
@@ -32,6 +32,8 @@ class Role(object):
|
||||
""" Company user """
|
||||
annotator = "annotator"
|
||||
""" Annotator with limited access"""
|
||||
guest = "guest"
|
||||
""" Guest user. Read Only."""
|
||||
|
||||
@classmethod
|
||||
def get_system_roles(cls) -> set:
|
||||
@@ -70,5 +72,5 @@ class User(DbModelMixin, AuthDocument):
|
||||
credentials = EmbeddedDocumentListField(Credentials, default=list)
|
||||
""" Credentials generated for this user """
|
||||
|
||||
email = EmailField(unique=True, required=True)
|
||||
email = EmailField(unique=True, sparse=True)
|
||||
""" Email uniquely identifying the user """
|
||||
@@ -1,20 +1,21 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple
|
||||
|
||||
from boltons.iterutils import first, bucketize
|
||||
from boltons.iterutils import first, bucketize, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apierrors import errors
|
||||
from config import config
|
||||
from database.errors import MakeGetAllQueryError
|
||||
from database.projection import project_dict, ProjectionHelper
|
||||
from database.props import PropsMixin
|
||||
from database.query import RegexQ, RegexWrapper
|
||||
from database.utils import (
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import MakeGetAllQueryError
|
||||
from apiserver.database.projection import project_dict, ProjectionHelper
|
||||
from apiserver.database.props import PropsMixin
|
||||
from apiserver.database.query import RegexQ, RegexWrapper
|
||||
from apiserver.database.utils import (
|
||||
get_company_or_none_constraint,
|
||||
get_fields_choices,
|
||||
field_does_not_exist,
|
||||
@@ -103,8 +104,13 @@ class GetMixin(PropsMixin):
|
||||
legacy_exclude_prefix = "-"
|
||||
|
||||
_default = "in"
|
||||
_ops = {"not": "nin"}
|
||||
_ops = {
|
||||
"not": ("nin", False),
|
||||
"all": ("all", True),
|
||||
"and": ("all", True),
|
||||
}
|
||||
_next = _default
|
||||
_sticky = False
|
||||
|
||||
def __init__(self, legacy=False):
|
||||
self._legacy = legacy
|
||||
@@ -115,13 +121,16 @@ class GetMixin(PropsMixin):
|
||||
return self._default
|
||||
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
|
||||
self._next = self._default
|
||||
return self._ops["not"]
|
||||
return self._ops["not"][0]
|
||||
elif v.startswith(self.op_prefix):
|
||||
self._next = self._ops.get(v[len(self.op_prefix) :], self._default)
|
||||
self._next, self._sticky = self._ops.get(
|
||||
v[len(self.op_prefix) :], (self._default, self._sticky)
|
||||
)
|
||||
return None
|
||||
|
||||
next_ = self._next
|
||||
self._next = self._default
|
||||
if not self._sticky:
|
||||
self._next = self._default
|
||||
return next_
|
||||
|
||||
def value_transform(self, v):
|
||||
@@ -259,6 +268,7 @@ class GetMixin(PropsMixin):
|
||||
|
||||
- Exclusion can be specified by a leading "-" for each value (API versions <2.8)
|
||||
or by a preceding "__$not" value (operator)
|
||||
- AND can be achieved using a preceding "__$all" or "__$and" value (operator)
|
||||
"""
|
||||
if not isinstance(data, (list, tuple)):
|
||||
raise MakeGetAllQueryError("expected list", field)
|
||||
@@ -347,6 +357,20 @@ class GetMixin(PropsMixin):
|
||||
return []
|
||||
return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
|
||||
|
||||
@classmethod
|
||||
def split_projection(
|
||||
cls, projection: Sequence[str]
|
||||
) -> Tuple[Collection[str], Collection[str]]:
|
||||
"""Return include and exclude lists based on passed projection and class definition"""
|
||||
if projection:
|
||||
include, exclude = partition(
|
||||
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
|
||||
)
|
||||
else:
|
||||
include, exclude = [], []
|
||||
exclude = {x.lstrip(ProjectionHelper.exclusion_prefix) for x in exclude}
|
||||
return include, set(cls.get_exclude_fields()).union(exclude).difference(include)
|
||||
|
||||
@classmethod
|
||||
def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
|
||||
parameters.pop("only_fields", None)
|
||||
@@ -483,10 +507,25 @@ class GetMixin(PropsMixin):
|
||||
query=_query, parameters=parameters, override_projection=override_projection
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_many_public(
|
||||
cls, query: Q = None, projection: Collection[str] = None,
|
||||
):
|
||||
"""
|
||||
Fetch all public documents matching a provided query.
|
||||
:param query: Optional query object (mongoengine.Q).
|
||||
:param projection: A list of projection fields.
|
||||
:return: A list of documents matching the query.
|
||||
"""
|
||||
q = get_company_or_none_constraint()
|
||||
_query = (q & query) if query else q
|
||||
|
||||
return cls._get_many_no_company(query=_query, override_projection=projection)
|
||||
|
||||
@classmethod
|
||||
def _get_many_no_company(
|
||||
cls: Union["GetMixin", Document],
|
||||
query,
|
||||
query: Q,
|
||||
parameters=None,
|
||||
override_projection=None,
|
||||
):
|
||||
@@ -509,7 +548,9 @@ class GetMixin(PropsMixin):
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
only = cls.get_projection(parameters, override_projection)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
qs = cls.objects(query)
|
||||
if search_text:
|
||||
@@ -517,13 +558,14 @@ class GetMixin(PropsMixin):
|
||||
if order_by:
|
||||
# add ordering
|
||||
qs = qs.order_by(*order_by)
|
||||
if only:
|
||||
|
||||
if include:
|
||||
# add projection
|
||||
qs = qs.only(*only)
|
||||
else:
|
||||
exclude = set(cls.get_exclude_fields()).difference(only)
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
qs = qs.only(*include)
|
||||
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
|
||||
if page is not None and page_size:
|
||||
# add paging
|
||||
qs = qs.skip(page * page_size).limit(page_size)
|
||||
@@ -559,7 +601,9 @@ class GetMixin(PropsMixin):
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
only = cls.get_projection(parameters, override_projection)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
query_sets = [cls.objects(query)]
|
||||
if order_by:
|
||||
@@ -596,16 +640,15 @@ class GetMixin(PropsMixin):
|
||||
if search_text:
|
||||
query_sets = [qs.search_text(search_text) for qs in query_sets]
|
||||
|
||||
if only:
|
||||
if include:
|
||||
# add projection
|
||||
query_sets = [qs.only(*only) for qs in query_sets]
|
||||
else:
|
||||
exclude = set(cls.get_exclude_fields())
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
query_sets = [qs.only(*include) for qs in query_sets]
|
||||
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
|
||||
if page is None or not page_size:
|
||||
return [obj.to_proper_dict(only=only) for qs in query_sets for obj in qs]
|
||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||
|
||||
# add paging
|
||||
ret = []
|
||||
@@ -616,7 +659,8 @@ class GetMixin(PropsMixin):
|
||||
start -= qs_size
|
||||
continue
|
||||
ret.extend(
|
||||
obj.to_proper_dict(only=only) for obj in qs.skip(start).limit(page_size)
|
||||
obj.to_proper_dict(only=include)
|
||||
for obj in qs.skip(start).limit(page_size)
|
||||
)
|
||||
if len(ret) >= page_size:
|
||||
break
|
||||
@@ -657,14 +701,24 @@ class GetMixin(PropsMixin):
|
||||
|
||||
|
||||
class UpdateMixin(object):
|
||||
__user_set_allowed_fields = None
|
||||
__locked_when_published_fields = None
|
||||
|
||||
@classmethod
|
||||
def user_set_allowed(cls):
|
||||
res = getattr(cls, "__user_set_allowed_fields", None)
|
||||
if res is None:
|
||||
res = cls.__user_set_allowed_fields = get_fields_choices(
|
||||
cls, "user_set_allowed"
|
||||
if cls.__user_set_allowed_fields is None:
|
||||
cls.__user_set_allowed_fields = dict(
|
||||
get_fields_choices(cls, "user_set_allowed")
|
||||
)
|
||||
return res
|
||||
return cls.__user_set_allowed_fields
|
||||
|
||||
@classmethod
|
||||
def locked_when_published(cls):
|
||||
if cls.__locked_when_published_fields is None:
|
||||
cls.__locked_when_published_fields = dict(
|
||||
get_fields_choices(cls, "locked_when_published")
|
||||
)
|
||||
return cls.__locked_when_published_fields
|
||||
|
||||
@classmethod
|
||||
def get_safe_update_dict(cls, fields):
|
||||
@@ -728,6 +782,31 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
)
|
||||
return cls.objects.aggregate(pipeline, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def set_public(
|
||||
cls: Type[Document],
|
||||
company_id: str,
|
||||
ids: Sequence[str],
|
||||
invalid_cls: Type[BaseError],
|
||||
enabled: bool = True,
|
||||
):
|
||||
if enabled:
|
||||
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
|
||||
update = dict(set__company_origin=company_id, set__company="")
|
||||
else:
|
||||
items = list(
|
||||
cls.objects(
|
||||
id__in=ids, company__in=(None, ""), company_origin=company_id
|
||||
).only("id")
|
||||
)
|
||||
update = dict(set__company=company_id, unset__company_origin=1)
|
||||
|
||||
if len(items) < len(ids):
|
||||
missing = tuple(set(ids).difference(i.id for i in items))
|
||||
raise invalid_cls(ids=missing)
|
||||
|
||||
return {"updated": cls.objects(id__in=ids).update(**update)}
|
||||
|
||||
|
||||
def validate_id(cls, company, **kwargs):
|
||||
"""
|
||||
@@ -8,9 +8,9 @@ from mongoengine import (
|
||||
DateTimeField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField
|
||||
from database.model import DbModelMixin
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField
|
||||
from apiserver.database.model import DbModelMixin
|
||||
|
||||
|
||||
class ReportStatsOption(EmbeddedDocument):
|
||||
@@ -29,7 +29,7 @@ class Company(DbModelMixin, Document):
|
||||
meta = {"db_alias": Database.backend, "strict": strict}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(unique=True, min_length=3)
|
||||
name = StrippedStringField(min_length=3)
|
||||
defaults = EmbeddedDocumentField(CompanyDefaults, default=CompanyDefaults)
|
||||
|
||||
@classmethod
|
||||
@@ -1,14 +1,14 @@
|
||||
from mongoengine import Document, StringField, DateTimeField, BooleanField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeDictField, SafeSortedListField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import GetMixin
|
||||
from database.model.model_labels import ModelLabels
|
||||
from database.model.company import Company
|
||||
from database.model.project import Project
|
||||
from database.model.task.task import Task
|
||||
from database.model.user import User
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField, SafeDictField, SafeSortedListField
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.model_labels import ModelLabels
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.database.model.user import User
|
||||
|
||||
|
||||
class Model(DbModelMixin, Document):
|
||||
@@ -19,6 +19,7 @@ class Model(DbModelMixin, Document):
|
||||
"parent",
|
||||
"project",
|
||||
"task",
|
||||
("company", "framework"),
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
{
|
||||
@@ -71,3 +72,4 @@ class Model(DbModelMixin, Document):
|
||||
ui_cache = SafeDictField(
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
@@ -1,4 +1,4 @@
|
||||
from database.fields import NoneType, UnionField, SafeMapField
|
||||
from apiserver.database.fields import NoneType, UnionField, SafeMapField
|
||||
|
||||
|
||||
class ModelLabels(SafeMapField):
|
||||
@@ -1,9 +1,9 @@
|
||||
from mongoengine import StringField, DateTimeField
|
||||
from mongoengine import StringField, DateTimeField, IntField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
from database.model import AttributedDocument
|
||||
from database.model.base import GetMixin
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField, SafeSortedListField
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.base import GetMixin
|
||||
|
||||
|
||||
class Project(AttributedDocument):
|
||||
@@ -40,3 +40,7 @@ class Project(AttributedDocument):
|
||||
system_tags = SafeSortedListField(StringField(required=True))
|
||||
default_output_destination = StrippedStringField()
|
||||
last_update = DateTimeField()
|
||||
featured = IntField(default=9999)
|
||||
logo_url = StringField()
|
||||
logo_blob = StringField(exclude_by_default=True)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
@@ -6,12 +6,12 @@ from mongoengine import (
|
||||
EmbeddedDocumentListField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import ProperDictMixin, GetMixin
|
||||
from database.model.company import Company
|
||||
from database.model.task.task import Task
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField, SafeSortedListField
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.task.task import Task
|
||||
|
||||
|
||||
class Entry(EmbeddedDocument, ProperDictMixin):
|
||||
@@ -3,8 +3,8 @@ from typing import Any, Optional, Sequence, Tuple
|
||||
from mongoengine import Document, StringField, DynamicField, Q
|
||||
from mongoengine.errors import NotUniqueError
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.model import DbModelMixin
|
||||
|
||||
|
||||
class SettingKeys:
|
||||
@@ -6,7 +6,7 @@ from mongoengine import (
|
||||
EmbeddedDocumentField,
|
||||
)
|
||||
|
||||
from database.fields import SafeMapField
|
||||
from apiserver.database.fields import SafeMapField
|
||||
|
||||
|
||||
class MetricEvent(EmbeddedDocument):
|
||||
@@ -1,7 +1,7 @@
|
||||
from mongoengine import EmbeddedDocument, StringField
|
||||
|
||||
from database.fields import StrippedStringField
|
||||
from database.utils import get_options
|
||||
from apiserver.database.fields import StrippedStringField
|
||||
from apiserver.database.utils import get_options
|
||||
|
||||
|
||||
class Result(object):
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Dict
|
||||
|
||||
from mongoengine import (
|
||||
StringField,
|
||||
EmbeddedDocumentField,
|
||||
@@ -8,20 +10,19 @@ from mongoengine import (
|
||||
LongField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import (
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import (
|
||||
StrippedStringField,
|
||||
SafeMapField,
|
||||
SafeDictField,
|
||||
UnionField,
|
||||
EmbeddedDocumentSortedListField,
|
||||
SafeSortedListField,
|
||||
)
|
||||
from database.model import AttributedDocument
|
||||
from database.model.base import ProperDictMixin, GetMixin
|
||||
from database.model.model_labels import ModelLabels
|
||||
from database.model.project import Project
|
||||
from database.utils import get_options
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
from apiserver.database.model.model_labels import ModelLabels
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.utils import get_options
|
||||
from .metrics import MetricEvent, MetricEventStats
|
||||
from .output import Output
|
||||
|
||||
@@ -49,14 +50,14 @@ class TaskSystemTags(object):
|
||||
development = "development"
|
||||
|
||||
|
||||
class Script(EmbeddedDocument):
|
||||
binary = StringField(default="python")
|
||||
repository = StringField(required=True)
|
||||
tag = StringField()
|
||||
branch = StringField()
|
||||
version_num = StringField()
|
||||
entry_point = StringField(required=True)
|
||||
working_dir = StringField()
|
||||
class Script(EmbeddedDocument, ProperDictMixin):
|
||||
binary = StringField(default="python", strip=True)
|
||||
repository = StringField(default="", strip=True)
|
||||
tag = StringField(strip=True)
|
||||
branch = StringField(strip=True)
|
||||
version_num = StringField(strip=True)
|
||||
entry_point = StringField(default="", strip=True)
|
||||
working_dir = StringField(strip=True)
|
||||
requirements = SafeDictField()
|
||||
diff = StringField()
|
||||
|
||||
@@ -72,10 +73,13 @@ class ArtifactModes:
|
||||
output = "output"
|
||||
|
||||
|
||||
DEFAULT_ARTIFACT_MODE = ArtifactModes.output
|
||||
|
||||
|
||||
class Artifact(EmbeddedDocument):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
mode = StringField(choices=get_options(ArtifactModes), default=ArtifactModes.output)
|
||||
mode = StringField(choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE)
|
||||
uri = StringField()
|
||||
hash = StringField()
|
||||
content_size = LongField()
|
||||
@@ -84,14 +88,30 @@ class Artifact(EmbeddedDocument):
|
||||
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
|
||||
|
||||
|
||||
class ParamsItem(EmbeddedDocument, ProperDictMixin):
|
||||
section = StringField(required=True)
|
||||
name = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
|
||||
name = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class Execution(EmbeddedDocument, ProperDictMixin):
|
||||
meta = {"strict": strict}
|
||||
test_split = IntField(default=0)
|
||||
parameters = SafeDictField(default=dict)
|
||||
model = StringField(reference_field="Model")
|
||||
model_desc = SafeMapField(StringField(default=""))
|
||||
model_labels = ModelLabels()
|
||||
framework = StringField()
|
||||
artifacts = EmbeddedDocumentSortedListField(Artifact)
|
||||
artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact))
|
||||
docker_cmd = StringField()
|
||||
queue = StringField()
|
||||
""" Queue ID where task was queued """
|
||||
@@ -115,9 +135,12 @@ external_task_types = set(get_options(TaskType))
|
||||
|
||||
|
||||
class Task(AttributedDocument):
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {
|
||||
"execution.parameters.": {"locale": "en_US", "numericOrdering": True},
|
||||
"last_metrics.": {"locale": "en_US", "numericOrdering": True}
|
||||
"execution.parameters.": _numeric_locale,
|
||||
"last_metrics.": _numeric_locale,
|
||||
"hyperparams.": _numeric_locale,
|
||||
"configuration.": _numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
@@ -127,10 +150,13 @@ class Task(AttributedDocument):
|
||||
"created",
|
||||
"started",
|
||||
"completed",
|
||||
"active_duration",
|
||||
"parent",
|
||||
"project",
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
("company", "status", "type"),
|
||||
("company", "system_tags", "last_update"),
|
||||
("company", "type", "system_tags", "status"),
|
||||
("company", "project", "type", "system_tags", "status"),
|
||||
("status", "last_update"), # for maintenance tasks
|
||||
@@ -159,10 +185,9 @@ class Task(AttributedDocument):
|
||||
],
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
|
||||
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project", "parent"),
|
||||
datetime_fields=("status_changed",),
|
||||
pattern_fields=("name", "comment"),
|
||||
fields=("parent",),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
@@ -180,16 +205,34 @@ class Task(AttributedDocument):
|
||||
started = DateTimeField()
|
||||
completed = DateTimeField()
|
||||
published = DateTimeField()
|
||||
parent = StringField()
|
||||
active_duration = IntField(default=None)
|
||||
parent = StringField(reference_field="Task")
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
output: Output = EmbeddedDocumentField(Output, default=Output)
|
||||
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
script: Script = EmbeddedDocumentField(Script)
|
||||
script: Script = EmbeddedDocumentField(Script, default=Script)
|
||||
last_worker = StringField()
|
||||
last_worker_report = DateTimeField()
|
||||
last_update = DateTimeField()
|
||||
last_change = DateTimeField()
|
||||
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
duration = IntField() # task duration in seconds
|
||||
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
|
||||
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
|
||||
runtime = SafeDictField(default=dict)
|
||||
docker_init_script = StringField()
|
||||
|
||||
def get_index_company(self) -> str:
|
||||
"""
|
||||
Returns the company ID used for locating indices containing task data.
|
||||
In case the task has a valid company, this is the company ID.
|
||||
Otherwise, if the task has a company_origin, this is a task that has been made public and the
|
||||
origin company should be used.
|
||||
Otherwise, an empty company is used.
|
||||
"""
|
||||
return self.company or self.company_origin or ""
|
||||
@@ -1,9 +1,9 @@
|
||||
from mongoengine import Document, StringField, DynamicField
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import GetMixin
|
||||
from database.model.company import Company
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.company import Company
|
||||
|
||||
|
||||
class User(DbModelMixin, Document):
|
||||
@@ -1,7 +1,7 @@
|
||||
from mongoengine import Document, DateTimeField, StringField
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.model import DbModelMixin
|
||||
|
||||
|
||||
class Version(DbModelMixin, Document):
|
||||
@@ -5,8 +5,8 @@ from typing import Sequence, Dict, Callable, Tuple, Any, Type
|
||||
|
||||
import dpath.path
|
||||
|
||||
from apierrors import errors
|
||||
from database.props import PropsMixin
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.props import PropsMixin
|
||||
|
||||
SEP = "."
|
||||
|
||||
@@ -45,7 +45,7 @@ def project_dict(data, projection, separator=SEP):
|
||||
)
|
||||
|
||||
dst[path_part] = [
|
||||
copy_path(path_parts[depth + 1:], s, d)
|
||||
copy_path(path_parts[depth + 1 :], s, d)
|
||||
for s, d in zip(src_part, dst[path_part])
|
||||
]
|
||||
|
||||
@@ -96,6 +96,7 @@ class _ProxyManager:
|
||||
|
||||
class ProjectionHelper(object):
|
||||
pool = ThreadPoolExecutor()
|
||||
exclusion_prefix = "-"
|
||||
|
||||
@property
|
||||
def doc_projection(self):
|
||||
@@ -128,20 +129,28 @@ class ProjectionHelper(object):
|
||||
[]
|
||||
) # Projection information for reference fields (used in join queries)
|
||||
for field in projection:
|
||||
field_ = field.lstrip(self.exclusion_prefix)
|
||||
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
|
||||
if not field.startswith(ref_field):
|
||||
if not field_.startswith(ref_field):
|
||||
# Doesn't start with a reference field
|
||||
continue
|
||||
if field == ref_field:
|
||||
if field_ == ref_field:
|
||||
# Field is exactly a reference field. In this case we won't perform any inner projection (for that,
|
||||
# use '<reference field name>.*')
|
||||
continue
|
||||
subfield = field[len(ref_field):]
|
||||
subfield = field_[len(ref_field) :]
|
||||
if not subfield.startswith(SEP):
|
||||
# Starts with something that looks like a reference field, but isn't
|
||||
continue
|
||||
|
||||
ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
|
||||
ref_projection_info.append(
|
||||
(
|
||||
ref_field,
|
||||
ref_field_cls,
|
||||
("" if field_[0] == field[0] else self.exclusion_prefix)
|
||||
+ subfield[1:],
|
||||
)
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Not a reference field, just add to the top-level projection
|
||||
@@ -149,7 +158,7 @@ class ProjectionHelper(object):
|
||||
orig_field = field
|
||||
if field.endswith(".*"):
|
||||
field = field[:-2]
|
||||
if not field:
|
||||
if not field.lstrip(self.exclusion_prefix):
|
||||
raise errors.bad_request.InvalidFields(
|
||||
field=orig_field, object=doc_cls.__name__
|
||||
)
|
||||
@@ -199,7 +208,7 @@ class ProjectionHelper(object):
|
||||
# Make sure this doesn't contain any reference field we'll join anyway
|
||||
# (i.e. in case only_fields=[project, project.name])
|
||||
doc_projection = normalize_cls_projection(
|
||||
doc_cls, doc_projection.difference(ref_projection).union({"id"})
|
||||
doc_cls, doc_projection.difference(ref_projection)
|
||||
)
|
||||
|
||||
# Make sure that in case one or more field is a subfield of another field, we only use the the top-level field.
|
||||
@@ -218,7 +227,10 @@ class ProjectionHelper(object):
|
||||
|
||||
# Make sure we didn't get any invalid projection fields for this class
|
||||
invalid_fields = [
|
||||
f for f in doc_projection if f.split(SEP)[0] not in doc_cls.get_fields()
|
||||
f
|
||||
for f in doc_projection
|
||||
if f.partition(SEP)[0].lstrip(self.exclusion_prefix)
|
||||
not in doc_cls.get_fields()
|
||||
]
|
||||
if invalid_fields:
|
||||
raise errors.bad_request.InvalidFields(
|
||||
@@ -234,6 +246,13 @@ class ProjectionHelper(object):
|
||||
doc_projection.add(field)
|
||||
doc_projection = list(doc_projection)
|
||||
|
||||
# If there are include fields (not only exclude) then add an id field
|
||||
if (
|
||||
not all(p.startswith(self.exclusion_prefix) for p in doc_projection)
|
||||
and "id" not in doc_projection
|
||||
):
|
||||
doc_projection.append("id")
|
||||
|
||||
self._doc_projection = doc_projection
|
||||
self._ref_projection = ref_projection
|
||||
|
||||
@@ -314,6 +333,7 @@ class ProjectionHelper(object):
|
||||
]
|
||||
|
||||
if items:
|
||||
|
||||
def do_projection(item):
|
||||
ref_field_name, data, ids = item
|
||||
|
||||
@@ -8,12 +8,12 @@ import six
|
||||
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
|
||||
from mongoengine.base import get_document, BaseField
|
||||
|
||||
from database.fields import (
|
||||
from apiserver.database.fields import (
|
||||
LengthRangeEmbeddedDocumentListField,
|
||||
UniqueEmbeddedDocumentListField,
|
||||
EmbeddedDocumentSortedListField,
|
||||
)
|
||||
from database.utils import get_fields, get_fields_attr
|
||||
from apiserver.database.utils import get_fields, get_fields_attr
|
||||
|
||||
|
||||
class PropsMixin(object):
|
||||
58
apiserver/elastic/apply_mappings.py
Executable file
58
apiserver/elastic/apply_mappings.py
Executable file
@@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Apply elasticsearch mappings to given hosts.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
HERE = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def apply_mappings_to_cluster(
|
||||
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
|
||||
):
|
||||
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
|
||||
|
||||
def _send_template(f):
|
||||
with f.open() as json_data:
|
||||
data = json.load(json_data)
|
||||
template_name = f.stem
|
||||
res = es.indices.put_template(template_name, body=data)
|
||||
return {"mapping": template_name, "result": res}
|
||||
|
||||
p = HERE / "mappings"
|
||||
if key:
|
||||
files = (p / key).glob("*.json")
|
||||
else:
|
||||
files = p.glob("**/*.json")
|
||||
|
||||
es = Elasticsearch(hosts=hosts, **(es_args or {}))
|
||||
return [_send_template(f) for f in files]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
parser.add_argument("--key", help="host key, e.g. events, datasets etc.")
|
||||
parser.add_argument(
|
||||
"--hosts",
|
||||
nargs="+",
|
||||
help="list of es hosts from the same cluster, where each host is http[s]://[user:password@]host:port",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
print(">>>>> Applying mapping to " + str(args.hosts))
|
||||
res = apply_mappings_to_cluster(args.hosts, args.key)
|
||||
print(res)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
113
apiserver/elastic/initialize.py
Normal file
113
apiserver/elastic/initialize.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import logging
|
||||
from time import sleep
|
||||
from typing import Type, Optional, Sequence, Any, Union
|
||||
|
||||
import urllib3.exceptions
|
||||
from elasticsearch import Elasticsearch, exceptions
|
||||
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.elastic.apply_mappings import apply_mappings_to_cluster
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class MissingElasticConfiguration(Exception):
|
||||
"""
|
||||
Exception when cluster configuration is not found in config files
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ElasticConnectionError(Exception):
|
||||
"""
|
||||
Exception when could not connect to elastic during init
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ConnectionErrorFilter(logging.Filter):
|
||||
def __init__(
|
||||
self,
|
||||
level: Optional[Union[int, str]] = None,
|
||||
err_type: Optional[Type] = None,
|
||||
args_prefix: Optional[Sequence[Any]] = None,
|
||||
):
|
||||
super(ConnectionErrorFilter, self).__init__()
|
||||
if level is None:
|
||||
self.level = None
|
||||
else:
|
||||
try:
|
||||
self.level = int(level)
|
||||
except ValueError:
|
||||
self.level = logging.getLevelName(level)
|
||||
|
||||
self.err_type = err_type
|
||||
self.args = args_prefix and tuple(args_prefix)
|
||||
self.last_blocked = None
|
||||
|
||||
def filter(self, record):
|
||||
try:
|
||||
filter_out = (
|
||||
(self.err_type is None or record.exc_info[0] == self.err_type)
|
||||
and (self.level is None or record.levelno == self.level)
|
||||
and (self.args is None or record.args[: len(self.args)] == self.args)
|
||||
)
|
||||
if filter_out:
|
||||
self.last_blocked = record
|
||||
return not filter_out
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
def check_elastic_empty() -> bool:
|
||||
"""
|
||||
Check for elasticsearch connection
|
||||
Use probing settings and not the default es cluster ones
|
||||
so that we can handle correctly the connection rejects due to ES not fully started yet
|
||||
:return:
|
||||
"""
|
||||
cluster_conf = es_factory.get_cluster_config("events")
|
||||
max_retries = config.get("apiserver.elastic.probing.max_retries", 4)
|
||||
timeout = config.get("apiserver.elastic.probing.timeout", 30)
|
||||
|
||||
es_logger = logging.getLogger("elasticsearch")
|
||||
log_filter = ConnectionErrorFilter(
|
||||
err_type=urllib3.exceptions.NewConnectionError, args_prefix=("GET",)
|
||||
)
|
||||
|
||||
try:
|
||||
es_logger.addFilter(log_filter)
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
|
||||
return not es.indices.get_template(name="events*")
|
||||
except exceptions.NotFoundError as ex:
|
||||
log.error(ex)
|
||||
return True
|
||||
except exceptions.ConnectionError as ex:
|
||||
if retry >= max_retries - 1:
|
||||
raise ElasticConnectionError(
|
||||
f"Error connecting to Elasticsearch: {str(ex)}"
|
||||
)
|
||||
log.warn(
|
||||
f"Could not connect to ElasticSearch Service. Retry {retry+1} of {max_retries}. Waiting for {timeout}sec"
|
||||
)
|
||||
sleep(timeout)
|
||||
finally:
|
||||
es_logger.removeFilter(log_filter)
|
||||
|
||||
|
||||
def init_es_data():
|
||||
for name in es_factory.get_all_cluster_names():
|
||||
cluster_conf = es_factory.get_cluster_config(name)
|
||||
hosts_config = cluster_conf.get("hosts")
|
||||
if not hosts_config:
|
||||
raise MissingElasticConfiguration(f"for cluster '{name}'")
|
||||
|
||||
log.info(f"Applying mappings to ES host: {hosts_config}")
|
||||
args = cluster_conf.get("args", {})
|
||||
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
|
||||
log.info(res)
|
||||
40
apiserver/elastic/mappings/events/events.json
Normal file
40
apiserver/elastic/mappings/events/events.json
Normal file
@@ -0,0 +1,40 @@
|
||||
{
|
||||
"index_patterns": "events-*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"@timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"task": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"type": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"worker": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"iter": {
|
||||
"type": "long"
|
||||
},
|
||||
"metric": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"variant": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user