mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
192 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b6f24b24d | ||
|
|
d03a931d84 | ||
|
|
5cc7199661 | ||
|
|
6537e9ef69 | ||
|
|
930aaff791 | ||
|
|
1999fb2479 | ||
|
|
9db14cc31d | ||
|
|
e3cc689528 | ||
|
|
9e0adc77dd | ||
|
|
58d9a64537 | ||
|
|
d397d2ae20 | ||
|
|
2d711e1500 | ||
|
|
97992b0d9e | ||
|
|
bc23f1b0cf | ||
|
|
6b3eff1426 | ||
|
|
caaf801cd0 | ||
|
|
c23e8a90d0 | ||
|
|
fa5b28ca0e | ||
|
|
bfb55a9463 | ||
|
|
37e485e1f2 | ||
|
|
3451ff441f | ||
|
|
53c9b5525e | ||
|
|
e5230edac3 | ||
|
|
a54dd8030c | ||
|
|
482a5c34bc | ||
|
|
ee2a72c70f | ||
|
|
a0d8aaf3b9 | ||
|
|
de1f823213 | ||
|
|
0c9e2f92ee | ||
|
|
6c49e96ff0 | ||
|
|
81e3fc6577 | ||
|
|
e6dc4b7557 | ||
|
|
238a47a197 | ||
|
|
04e7076628 | ||
|
|
0531612bf4 | ||
|
|
3ae410a1e9 | ||
|
|
98ed3075dd | ||
|
|
b871bf4224 | ||
|
|
8d4c02fc3c | ||
|
|
b986980c75 | ||
|
|
a4fa567be2 | ||
|
|
ddb91f226a | ||
|
|
7772f47773 | ||
|
|
9c118d14e0 | ||
|
|
efd56e085e | ||
|
|
4dff163af4 | ||
|
|
242a78a0fe | ||
|
|
78989fea91 | ||
|
|
5de7c12062 | ||
|
|
3f79c19079 | ||
|
|
fe29743c54 | ||
|
|
d760cf5835 | ||
|
|
3695f25a5f | ||
|
|
c6f1beafdd | ||
|
|
68a54c34f3 | ||
|
|
ab495ae586 | ||
|
|
b058770af1 | ||
|
|
f7e833bf6f | ||
|
|
36b9ab0453 | ||
|
|
ec0436d0da | ||
|
|
0f6c4e75b7 | ||
|
|
a41ae112a1 | ||
|
|
c28f478ea8 | ||
|
|
c18eb99d06 | ||
|
|
3a60f00d93 | ||
|
|
ee87778548 | ||
|
|
52c0c4d438 | ||
|
|
d117a4f022 | ||
|
|
6683d2d7a9 | ||
|
|
05357fe25e | ||
|
|
adc1825843 | ||
|
|
0c15169668 | ||
|
|
123dc1dcfb | ||
|
|
b2feafac09 | ||
|
|
b41ab8c550 | ||
|
|
62d5779bd5 | ||
|
|
f8b9d9802e | ||
|
|
dd8a1503b0 | ||
|
|
cff98ae900 | ||
|
|
9b108740da | ||
|
|
08a7bc7c9f | ||
|
|
fb256d7e5b | ||
|
|
710443b078 | ||
|
|
e0cde2f7c9 | ||
|
|
60b9c8de14 | ||
|
|
ecffe26be4 | ||
|
|
2570bd9e26 | ||
|
|
174f84514a | ||
|
|
65cb8d7b43 | ||
|
|
5f8ef808a3 | ||
|
|
4941ac70e0 | ||
|
|
67cd461145 | ||
|
|
92b5fc6f9a | ||
|
|
b90165b4e4 | ||
|
|
6c2dcb5c8a | ||
|
|
3efed32934 | ||
|
|
69737308fe | ||
|
|
a6dbea808a | ||
|
|
5131b17901 | ||
|
|
5f21c3a56d | ||
|
|
2350ac64ed | ||
|
|
d146127c18 | ||
|
|
abd65e103e | ||
|
|
bf65ea7bd0 | ||
|
|
73e278a8ed | ||
|
|
d92dfbbdb7 | ||
|
|
5c1e419eb5 | ||
|
|
124684f53f | ||
|
|
455b5d6758 | ||
|
|
c04e2e498b | ||
|
|
da8a45072f | ||
|
|
e1992e2054 | ||
|
|
c17cedd93a | ||
|
|
b6ad8f8790 | ||
|
|
5acc7eebc3 | ||
|
|
941927dfcd | ||
|
|
02933a9c93 | ||
|
|
e537651f29 | ||
|
|
af09fba755 | ||
|
|
04ea9018a3 | ||
|
|
ff7e1be24f | ||
|
|
fc4fd9e61c | ||
|
|
8908c7dcf9 | ||
|
|
b9996e2c1a | ||
|
|
afdc56f37c | ||
|
|
a25cd5dae8 | ||
|
|
447adb9090 | ||
|
|
92fd98d5ad | ||
|
|
c4001b4037 | ||
|
|
970a32287a | ||
|
|
17cd48dada | ||
|
|
ea3b6e955f | ||
|
|
843450bb9b | ||
|
|
e149af58b1 | ||
|
|
604a38035b | ||
|
|
cae38a365b | ||
|
|
e334246b46 | ||
|
|
36e013b40c | ||
|
|
f20cd6536e | ||
|
|
446bd35006 | ||
|
|
a377a7e315 | ||
|
|
3d046ac282 | ||
|
|
a08fa9a0e1 | ||
|
|
5856ed2836 | ||
|
|
d295355d99 | ||
|
|
77350f6119 | ||
|
|
bc2c2ebbfd | ||
|
|
1502e02a1a | ||
|
|
d0e2313a24 | ||
|
|
d8ba1a8ea7 | ||
|
|
ca7937fc4e | ||
|
|
df89bcceef | ||
|
|
cfccbe05c1 | ||
|
|
e352a6a1e7 | ||
|
|
8a3d992aaf | ||
|
|
c37f3d8d5b | ||
|
|
a96870e092 | ||
|
|
6bf1032237 | ||
|
|
3d816c747d | ||
|
|
3f2b96266b | ||
|
|
22b16d12eb | ||
|
|
c55b6f30df | ||
|
|
b7045d3d28 | ||
|
|
e31a404885 | ||
|
|
643588b71a | ||
|
|
a64c4d264d | ||
|
|
567780e188 | ||
|
|
1bc8529d83 | ||
|
|
6b480d7e87 | ||
|
|
083fd315e9 | ||
|
|
ef20e76174 | ||
|
|
8c8910808e | ||
|
|
f6ad379310 | ||
|
|
c5d6ce3e65 | ||
|
|
694dbc31c4 | ||
|
|
6488dc54e6 | ||
|
|
158da9b480 | ||
|
|
ec2e071ab7 | ||
|
|
465e270342 | ||
|
|
6705aff56f | ||
|
|
9069cfe1da | ||
|
|
677bb3ba6d | ||
|
|
cb253cff9e | ||
|
|
39ceb5ac5c | ||
|
|
d4edeaaf1b | ||
|
|
56aea1ffb8 | ||
|
|
09ab2af34c | ||
|
|
8bb26a6b0b | ||
|
|
3f2304549d | ||
|
|
ad72a435f1 | ||
|
|
f34332344e | ||
|
|
d324b57dd7 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,7 +12,6 @@ test-reports
|
||||
.pytest_cache
|
||||
venv
|
||||
*.noseids
|
||||
build
|
||||
*.egg-info
|
||||
.cache
|
||||
.mypy_cache
|
||||
|
||||
62
README.md
62
README.md
@@ -8,28 +8,43 @@
|
||||
[](https://img.shields.io/badge/license-SSPL-green.svg)
|
||||
[](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
|
||||
[](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
|
||||
[](https://artifacthub.io/packages/search?repo=allegroai)
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
<div align="center">
|
||||
|
||||
**v0.16 Upgrade Notice**
|
||||
**Note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31**
|
||||
|
||||
</div>
|
||||
|
||||
In v0.16, the Elasticsearch subsystem of ClearML Server has been upgraded from version 5.6 to version 7.6. This change necessitates the migration of the database contents to accommodate the change in index structure across the different versions.
|
||||
According to [ElasticSearch's latest report](https://discuss.elastic.co/t/apache-log4j2-remote-code-execution-rce-vulnerability-cve-2021-44228-esa-2021-31/291476),
|
||||
supported versions of Elasticsearch (6.8.9+, 7.8+) used with recent versions of the JDK (JDK9+) **are not susceptible to either remote code execution or information leakage**
|
||||
due to Elasticsearch’s usage of the Java Security Manager.
|
||||
|
||||
Follow [this procedure](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_es7_migration.html) to migrate existing data.
|
||||
**As the latest version of ClearML Server uses Elasticsearch 7.10+ with JDK15, it is not affected by these vulnerabilities.**
|
||||
|
||||
As a precaution, we've upgraded the ES version to 7.16.2 and added the mitigation recommended by ElasticSearch to our latest [docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
|
||||
|
||||
While previous Elasticsearch versions (5.6.11+, 6.4.0+ and 7.0.0+) used by older ClearML Server versions are only susceptible to the information leakage vulnerability
|
||||
(which in any case **does not permit access to data within the Elasticsearch cluster**),
|
||||
we still recommend upgrading to the latest version of ClearML Server. Alternatively, you can apply the mitigation as implemented in our latest
|
||||
[docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
|
||||
|
||||
**Update 15 December**: A further vulnerability (CVE-2021-45046) was disclosed on December 14th.
|
||||
ElasticSearch's guidance for Elasticsearch remains unchanged by this new vulnerability, thus **not affecting ClearML Server**.
|
||||
|
||||
**Update 22 December**: To keep with ElasticSearch's recommendations, we've upgraded the ES version to the newly released 7.16.2
|
||||
|
||||
---
|
||||
|
||||
### ClearML Server
|
||||
## ClearML Server
|
||||
#### *Formerly known as Trains Server*
|
||||
|
||||
The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/allegroai/clearml).
|
||||
It allows multiple users to collaborate and manage their experiments.
|
||||
By default, **ClearML** is set up to work with the **ClearML** demo server, which is open to anyone and resets periodically.
|
||||
**ClearML** offers a [free hosted service](https://app.clear.ml/), which is maintained by **ClearML** and open to anyone.
|
||||
In order to host your own server, you will need to launch the **ClearML Server** and point **ClearML** to it.
|
||||
|
||||
The **ClearML Server** contains the following components:
|
||||
@@ -45,7 +60,7 @@ You can quickly [deploy](#launching-the-clearml-server) your **ClearML Server**
|
||||
## System design
|
||||
|
||||
|
||||

|
||||

|
||||
|
||||
The **ClearML Server** has two supported configurations:
|
||||
- Single IP (domain) with the following open ports
|
||||
@@ -78,20 +93,19 @@ For example, to see if port `8080` is in use:
|
||||
|
||||
Launch The **ClearML Server** in any of the following formats:
|
||||
|
||||
- Pre-built [AWS EC2 AMI](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_aws_ec2_ami.html)
|
||||
- Pre-built [GCP Custom Image](hhttps://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_gcp.html)
|
||||
- Pre-built [AWS EC2 AMI](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_aws_ec2_ami)
|
||||
- Pre-built [GCP Custom Image](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_gcp)
|
||||
- Pre-built Docker Image
|
||||
- [Linux](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_linux_mac.html)
|
||||
- [macOS](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_linux_mac.html)
|
||||
- [Windows 10](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_win.html)
|
||||
- [Linux](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
|
||||
- [macOS](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
|
||||
- [Windows 10](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_win)
|
||||
- Kubernetes
|
||||
- [Kubernetes Helm](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_kubernetes_helm.html)
|
||||
- Manual [Kubernetes installation](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_kubernetes.html)
|
||||
- [Kubernetes Helm](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes_helm)
|
||||
- Manual [Kubernetes installation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes)
|
||||
|
||||
## Connecting ClearML to your ClearML Server
|
||||
|
||||
By default, the **ClearML** client is set up to work with the [**ClearML** demo server](https://demoapp.demo.clear.ml/).
|
||||
To have the **ClearML** client use your **ClearML Server** instead:
|
||||
In order to set up the **ClearML** client to work with your **ClearML Server**:
|
||||
- Run the `clearml-init` command for an interactive setup.
|
||||
- Or manually edit `~/clearml.conf` file, making sure the server settings (`api_server`, `web_server`, `file_server`) are configured correctly, for example:
|
||||
|
||||
@@ -138,8 +152,8 @@ Do not enqueue training / inference tasks into the `services` queue, as it will
|
||||
|
||||
The **ClearML Server** provides a few additional useful features, which can be manually enabled:
|
||||
|
||||
* [Web login authentication](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#web-login-authentication)
|
||||
* [Non-responsive experiments watchdog](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#task_watchdog)
|
||||
* [Web login authentication](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#web-login-authentication)
|
||||
* [Non-responsive experiments watchdog](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#non-responsive-task-watchdog)
|
||||
|
||||
## Restarting ClearML Server
|
||||
|
||||
@@ -189,14 +203,14 @@ To upgrade your existing **ClearML Server** deployment:
|
||||
```
|
||||
|
||||
1. Configure the ClearML-Agent Services (not supported on Windows installation).
|
||||
If `TRAINS_HOST_IP` is not provided, ClearML-Agent Services will use the external
|
||||
public address of the **ClearML Server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
|
||||
If `CLEARML_HOST_IP` is not provided, ClearML-Agent Services will use the external
|
||||
public address of the **ClearML Server**. If `CLEARML_AGENT_GIT_USER` / `CLEARML_AGENT_GIT_PASS` are not provided,
|
||||
the ClearML-Agent Services will not be able to access any private repositories for running service tasks.
|
||||
|
||||
```bash
|
||||
export TRAINS_HOST_IP=server_host_ip_here
|
||||
export TRAINS_AGENT_GIT_USER=git_username_here
|
||||
export TRAINS_AGENT_GIT_PASS=git_password_here
|
||||
export CLEARML_HOST_IP=server_host_ip_here
|
||||
export CLEARML_AGENT_GIT_USER=git_username_here
|
||||
export CLEARML_AGENT_GIT_PASS=git_password_here
|
||||
```
|
||||
|
||||
1. Spin up the docker containers, it will automatically pull the latest **ClearML Server** build
|
||||
@@ -205,12 +219,12 @@ To upgrade your existing **ClearML Server** deployment:
|
||||
docker-compose -f docker-compose.yml up
|
||||
```
|
||||
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/clearml/docs/docs/faq/faq.html).**
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://clear.ml/docs/latest/docs/faq/).**
|
||||
|
||||
|
||||
## Community & Support
|
||||
|
||||
If you have any questions, look to the ClearML [FAQ](https://allegro.ai/clearml/docs/docs/faq/faq.html), or
|
||||
If you have any questions, look to the ClearML [FAQ](https://clear.ml/docs/latest/docs/faq), or
|
||||
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/clearml) with '**clearml**' tag.
|
||||
|
||||
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/clearml-server/issues).
|
||||
|
||||
@@ -26,6 +26,9 @@
|
||||
23: ["invalid_domain_name", "malformed domain name"]
|
||||
24: ["not_public_object", "object is not public"]
|
||||
|
||||
# Auth / Login
|
||||
75: ["invalid_access_key", "access key not found for user"]
|
||||
|
||||
# Tasks
|
||||
100: ["task_error", "general task error"]
|
||||
101: ["invalid_task_id", "invalid task id"]
|
||||
@@ -71,6 +74,8 @@
|
||||
408: ["cannot_update_project_location", "Cannot update project location. Use projects.move instead"]
|
||||
409: ["project_path_exceeds_max", "Project path exceed the maximum allowed depth"]
|
||||
410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"]
|
||||
411: ["project_cannot_be_moved_under_itself", "Project can not be moved under itself in the projects hierarchy"]
|
||||
412: ["project_cannot_be_merged_into_its_child", "Project can not be merged into its own child"]
|
||||
|
||||
# Queues
|
||||
701: ["invalid_queue_id", "invalid queue id"]
|
||||
@@ -84,7 +89,7 @@
|
||||
|
||||
# Database
|
||||
800: ["data_validation_error", "data validation error"]
|
||||
801: ["expected_unique_data", "value combination already exists"]
|
||||
801: ["expected_unique_data", "value combination already exists (unique field already contains this value)"]
|
||||
|
||||
# Workers
|
||||
1001: ["invalid_worker_id", "invalid worker id"]
|
||||
|
||||
@@ -61,12 +61,6 @@ class ListField(fields.ListField):
|
||||
item.validate()
|
||||
|
||||
|
||||
# since there is no distinction between None and empty DictField
|
||||
# this value can be used as sentinel in order to distinguish
|
||||
# between not set and empty DictField
|
||||
DictFieldNotSet = {}
|
||||
|
||||
|
||||
class DictField(fields.BaseField):
|
||||
types = (dict,)
|
||||
|
||||
|
||||
@@ -75,11 +75,17 @@ class CreateUserResponse(Base):
|
||||
class Credentials(Base):
|
||||
access_key = StringField(required=True)
|
||||
secret_key = StringField(required=True)
|
||||
label = StringField()
|
||||
|
||||
|
||||
class CredentialsResponse(Credentials):
|
||||
secret_key = StringField()
|
||||
last_used = DateTimeField(default=None)
|
||||
last_used_from = StringField()
|
||||
|
||||
|
||||
class CreateCredentialsRequest(Base):
|
||||
label = StringField()
|
||||
|
||||
|
||||
class CreateCredentialsResponse(Base):
|
||||
@@ -90,6 +96,11 @@ class GetCredentialsResponse(Base):
|
||||
credentials = ListField(CredentialsResponse)
|
||||
|
||||
|
||||
class EditCredentialsRequest(Base):
|
||||
access_key = StringField(required=True)
|
||||
label = StringField()
|
||||
|
||||
|
||||
class RevokeCredentialsRequest(Base):
|
||||
access_key = StringField(required=True)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from enum import auto
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.fields import StringField, BoolField, EmbeddedField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length, Min, Max
|
||||
|
||||
@@ -14,12 +14,19 @@ from apiserver.utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class HistogramRequestBase(Base):
|
||||
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
|
||||
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
|
||||
|
||||
class MetricVariants(Base):
|
||||
metric: str = StringField(required=True)
|
||||
variants: Sequence[str] = ListField(items_types=str)
|
||||
|
||||
|
||||
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
task: str = StringField(required=True)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
@@ -34,14 +41,16 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
)
|
||||
],
|
||||
)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class TaskMetric(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(default=None)
|
||||
variants: Sequence[str] = ListField(items_types=str)
|
||||
|
||||
|
||||
class DebugImagesRequest(Base):
|
||||
class MetricEventsRequest(Base):
|
||||
metrics: Sequence[TaskMetric] = ListField(
|
||||
items_types=TaskMetric, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
@@ -49,24 +58,36 @@ class DebugImagesRequest(Base):
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
model_events: bool = BoolField()
|
||||
|
||||
|
||||
class TaskMetricVariant(Base):
|
||||
class GetVariantSampleRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
variant: str = StringField(required=True)
|
||||
|
||||
|
||||
class GetDebugImageSampleRequest(TaskMetricVariant):
|
||||
iteration: Optional[int] = IntField()
|
||||
scroll_id: Optional[str] = StringField()
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_current_metric: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class NextDebugImageSampleRequest(Base):
|
||||
class GetMetricSamplesRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
iteration: Optional[int] = IntField()
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_current_metric: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class NextHistorySampleRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
next_iteration: bool = BoolField(default=False)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class LogOrderEnum(StringEnum):
|
||||
@@ -74,14 +95,36 @@ class LogOrderEnum(StringEnum):
|
||||
desc = auto()
|
||||
|
||||
|
||||
class LogEventsRequest(Base):
|
||||
class TaskEventsRequestBase(Base):
|
||||
task: str = StringField(required=True)
|
||||
batch_size: int = IntField(default=500)
|
||||
|
||||
|
||||
class TaskEventsRequest(TaskEventsRequestBase):
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
|
||||
scroll_id: str = StringField()
|
||||
count_total: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class LogEventsRequest(TaskEventsRequestBase):
|
||||
batch_size: int = IntField(default=5000)
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
from_timestamp: Optional[int] = IntField()
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum)
|
||||
|
||||
|
||||
class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
|
||||
batch_size: int = IntField()
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
|
||||
count_total: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class IterationEvents(Base):
|
||||
iter: int = IntField()
|
||||
events: Sequence[dict] = ListField(items_types=dict)
|
||||
@@ -92,13 +135,40 @@ class MetricEvents(Base):
|
||||
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
|
||||
|
||||
|
||||
class DebugImageResponse(Base):
|
||||
class MetricEventsResponse(Base):
|
||||
metrics: Sequence[MetricEvents] = ListField(items_types=MetricEvents)
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class TaskMetricsRequest(Base):
|
||||
class MultiTasksRequestBase(Base):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class SingleValueMetricsRequest(MultiTasksRequestBase):
|
||||
pass
|
||||
|
||||
|
||||
class TaskMetricsRequest(MultiTasksRequestBase):
|
||||
event_type: EventType = ActualEnumField(EventType, required=True)
|
||||
|
||||
|
||||
class TaskPlotsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
iters: int = IntField(default=1)
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class ClearScrollRequest(Base):
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class ClearTaskLogRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
threshold_sec = IntField()
|
||||
allow_locked = BoolField(default=False)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
@@ -21,3 +21,4 @@ class AddOrUpdateMetadata(Base):
|
||||
metadata: Sequence[MetadataItem] = ListField(
|
||||
[MetadataItem], validators=validators.Length(minimum_value=1)
|
||||
)
|
||||
replace_metadata = BoolField(default=False)
|
||||
|
||||
@@ -30,7 +30,7 @@ class CreateModelRequest(models.Base):
|
||||
ready = fields.BoolField(default=True)
|
||||
ui_cache = DictField()
|
||||
task = fields.StringField()
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class CreateModelResponse(models.Base):
|
||||
@@ -75,3 +75,7 @@ class DeleteMetadataRequest(DeleteMetadata):
|
||||
|
||||
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
|
||||
class ModelsGetRequest(models.Base):
|
||||
include_stats = fields.BoolField(default=False)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from jsonmodels import fields, models
|
||||
|
||||
from apiserver.apimodels import DictField
|
||||
|
||||
|
||||
class Filter(models.Base):
|
||||
tags = fields.ListField([str])
|
||||
@@ -9,3 +11,13 @@ class Filter(models.Base):
|
||||
class TagsRequest(models.Base):
|
||||
include_system = fields.BoolField(default=False)
|
||||
filter = fields.EmbeddedField(Filter)
|
||||
|
||||
|
||||
class EntitiesCountRequest(models.Base):
|
||||
projects = DictField()
|
||||
tasks = DictField()
|
||||
models = DictField()
|
||||
pipelines = DictField()
|
||||
datasets = DictField()
|
||||
active_users = fields.ListField(str)
|
||||
search_hidden = fields.BoolField(default=False)
|
||||
|
||||
19
apiserver/apimodels/pipelines.py
Normal file
19
apiserver/apimodels/pipelines.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
|
||||
|
||||
class Arg(models.Base):
|
||||
name = fields.StringField(required=True)
|
||||
value = fields.StringField(required=True)
|
||||
|
||||
|
||||
class StartPipelineRequest(models.Base):
|
||||
task = fields.StringField(required=True)
|
||||
queue = fields.StringField(required=True)
|
||||
args = ListField(Arg)
|
||||
|
||||
|
||||
class StartPipelineResponse(models.Base):
|
||||
pipeline = fields.StringField(required=True)
|
||||
enqueued = fields.BoolField(required=True)
|
||||
@@ -1,6 +1,6 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apiserver.apimodels import ListField, ActualEnumField
|
||||
from apiserver.apimodels import ListField, ActualEnumField, DictField
|
||||
from apiserver.apimodels.organization import TagsRequest
|
||||
from apiserver.database.model import EntityVisibility
|
||||
|
||||
@@ -27,7 +27,7 @@ class ProjectOrNoneRequest(models.Base):
|
||||
include_subprojects = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class GetHyperParamRequest(ProjectOrNoneRequest):
|
||||
class GetParamsRequest(ProjectOrNoneRequest):
|
||||
page = fields.IntField(default=0)
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
@@ -51,10 +51,19 @@ class ProjectHyperparamValuesRequest(MultiProjectRequest):
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectModelMetadataValuesRequest(MultiProjectRequest):
|
||||
key = fields.StringField(required=True)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectsGetRequest(models.Base):
|
||||
include_dataset_stats = fields.BoolField(default=False)
|
||||
include_stats = fields.BoolField(default=False)
|
||||
include_stats_filter = DictField()
|
||||
stats_with_children = fields.BoolField(default=True)
|
||||
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
|
||||
non_public = fields.BoolField(default=False)
|
||||
active_users = fields.ListField(str)
|
||||
check_own_contents = fields.BoolField(default=False)
|
||||
shallow_search = fields.BoolField(default=False)
|
||||
search_hidden = fields.BoolField(default=False)
|
||||
|
||||
@@ -2,7 +2,7 @@ from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
from apiserver.apimodels import ListField, DictField
|
||||
from apiserver.apimodels.metadata import (
|
||||
MetadataItem,
|
||||
DeleteMetadata,
|
||||
@@ -19,13 +19,28 @@ class CreateRequest(Base):
|
||||
name = StringField(required=True)
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class QueueRequest(Base):
|
||||
queue = StringField(required=True)
|
||||
|
||||
|
||||
class GetByIdRequest(QueueRequest):
|
||||
max_task_entries = IntField()
|
||||
|
||||
|
||||
class GetAllRequest(Base):
|
||||
max_task_entries = IntField()
|
||||
search_hidden = BoolField(default=False)
|
||||
|
||||
|
||||
class GetNextTaskRequest(QueueRequest):
|
||||
queue = StringField(required=True)
|
||||
get_task_info = BoolField(default=False)
|
||||
task = StringField()
|
||||
|
||||
|
||||
class DeleteRequest(QueueRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
@@ -34,7 +49,7 @@ class UpdateRequest(QueueRequest):
|
||||
name = StringField()
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class TaskRequest(QueueRequest):
|
||||
@@ -54,6 +69,7 @@ class GetMetricsRequest(Base):
|
||||
from_date = FloatField(required=True, validators=validators.Min(0))
|
||||
to_date = FloatField(required=True, validators=validators.Min(0))
|
||||
interval = IntField(required=True, validators=validators.Min(1))
|
||||
refresh = BoolField(default=False)
|
||||
|
||||
|
||||
class QueueMetrics(Base):
|
||||
|
||||
@@ -42,6 +42,7 @@ class StartedResponse(UpdateResponse):
|
||||
|
||||
class EnqueueResponse(UpdateResponse):
|
||||
queued = IntField()
|
||||
queue_watched = BoolField()
|
||||
|
||||
|
||||
class EnqueueBatchItem(UpdateBatchItem):
|
||||
@@ -50,6 +51,7 @@ class EnqueueBatchItem(UpdateBatchItem):
|
||||
|
||||
class EnqueueManyResponse(BatchResponse):
|
||||
succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem)
|
||||
queue_watched = BoolField()
|
||||
|
||||
|
||||
class DequeueResponse(UpdateResponse):
|
||||
@@ -96,18 +98,29 @@ class UpdateRequest(TaskUpdateRequest):
|
||||
|
||||
class EnqueueRequest(UpdateRequest):
|
||||
queue = StringField()
|
||||
queue_name = StringField()
|
||||
verify_watched_queue = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteRequest(UpdateRequest):
|
||||
move_to_trash = BoolField(default=True)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class SetRequirementsRequest(TaskRequest):
|
||||
requirements = DictField(required=True)
|
||||
|
||||
|
||||
class CompletedRequest(UpdateRequest):
|
||||
publish = BoolField(default=False)
|
||||
|
||||
|
||||
class CompletedResponse(UpdateResponse):
|
||||
published = IntField(default=0)
|
||||
|
||||
|
||||
class PublishRequest(UpdateRequest):
|
||||
publish_model = BoolField(default=True)
|
||||
|
||||
@@ -171,6 +184,7 @@ class ResetRequest(UpdateRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class MultiTaskRequest(models.Base):
|
||||
@@ -262,7 +276,9 @@ class StopManyRequest(TaskBatchRequest):
|
||||
|
||||
class EnqueueManyRequest(TaskBatchRequest):
|
||||
queue = StringField()
|
||||
queue_name = StringField()
|
||||
validate_tasks = BoolField(default=False)
|
||||
verify_watched_queue = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteManyRequest(TaskBatchRequest):
|
||||
@@ -270,6 +286,7 @@ class DeleteManyRequest(TaskBatchRequest):
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
force = BoolField(default=False)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class ResetManyRequest(TaskBatchRequest):
|
||||
@@ -277,6 +294,7 @@ class ResetManyRequest(TaskBatchRequest):
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
force = BoolField(default=False)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class PublishManyRequest(TaskBatchRequest):
|
||||
|
||||
@@ -20,6 +20,7 @@ DEFAULT_TIMEOUT = 10 * 60
|
||||
class WorkerRequest(Base):
|
||||
worker = StringField(required=True)
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class RegisterRequest(WorkerRequest):
|
||||
@@ -76,6 +77,7 @@ class WorkerEntry(Base, JsonSerializableMixin):
|
||||
last_activity_time = DateTimeField(required=True)
|
||||
last_report_time = DateTimeField()
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class CurrentTaskEntry(IdNameEntry):
|
||||
@@ -96,6 +98,8 @@ class WorkerResponseEntry(WorkerEntry):
|
||||
|
||||
class GetAllRequest(Base):
|
||||
last_seen = IntField(default=3600)
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class GetAllResponse(Base):
|
||||
|
||||
@@ -2,7 +2,11 @@ from datetime import datetime
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.auth import GetTokenResponse, CreateUserRequest, Credentials as CredModel
|
||||
from apiserver.apimodels.auth import (
|
||||
GetTokenResponse,
|
||||
CreateUserRequest,
|
||||
Credentials as CredModel,
|
||||
)
|
||||
from apiserver.apimodels.users import CreateRequest as Users_CreateRequest
|
||||
from apiserver.bll.user import UserBLL
|
||||
from apiserver.config_repo import config
|
||||
@@ -57,9 +61,10 @@ class AuthBLL:
|
||||
api_version=str(ServiceRepo.max_endpoint_version()),
|
||||
server_version=str(get_version()),
|
||||
server_build=str(get_build_number()),
|
||||
feature_set="basic",
|
||||
)
|
||||
|
||||
return GetTokenResponse(token=token.decode("ascii"))
|
||||
return GetTokenResponse(token=token)
|
||||
|
||||
@staticmethod
|
||||
def create_user(request: CreateUserRequest, call: APICall = None) -> str:
|
||||
@@ -144,7 +149,7 @@ class AuthBLL:
|
||||
|
||||
@classmethod
|
||||
def create_credentials(
|
||||
cls, user_id: str, company_id: str, role: str = None
|
||||
cls, user_id: str, company_id: str, role: str = None, label: str = None,
|
||||
) -> CredModel:
|
||||
|
||||
with translate_errors_context():
|
||||
@@ -153,9 +158,11 @@ class AuthBLL:
|
||||
if not user:
|
||||
raise errors.bad_request.InvalidUserId(**query)
|
||||
|
||||
cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key())
|
||||
cred = CredModel(
|
||||
access_key=get_client_id(), secret_key=get_secret_key(), label=label
|
||||
)
|
||||
user.credentials.append(
|
||||
Credentials(key=cred.access_key, secret=cred.secret_key)
|
||||
Credentials(key=cred.access_key, secret=cred.secret_key, label=label)
|
||||
)
|
||||
user.save()
|
||||
|
||||
|
||||
@@ -1,375 +0,0 @@
|
||||
import operator
|
||||
from typing import Sequence, Tuple, Optional
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
EventType,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
|
||||
class VariantState(Base):
|
||||
name: str = StringField(required=True)
|
||||
min_iteration: int = IntField()
|
||||
max_iteration: int = IntField()
|
||||
|
||||
|
||||
class DebugSampleHistoryState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
iteration: int = IntField()
|
||||
variant: str = StringField()
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
reached_first: bool = BoolField()
|
||||
reached_last: bool = BoolField()
|
||||
variant_states: Sequence[VariantState] = ListField([VariantState])
|
||||
warning: str = StringField()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class DebugSampleHistoryResult(object):
|
||||
scroll_id: str = None
|
||||
event: dict = None
|
||||
min_iteration: int = None
|
||||
max_iteration: int = None
|
||||
|
||||
|
||||
class DebugSampleHistory:
|
||||
EVENT_TYPE = EventType.metrics_image
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugSampleHistoryState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_next_debug_image(
|
||||
self, company_id: str, task: str, state_id: str, navigate_earlier: bool
|
||||
) -> DebugSampleHistoryResult:
|
||||
"""
|
||||
Get the debug image for next/prev variant on the current iteration
|
||||
If does not exist then try getting image for the first/last variant from next/prev iteration
|
||||
"""
|
||||
res = DebugSampleHistoryResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
|
||||
return res
|
||||
|
||||
image = self._get_next_for_current_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
) or self._get_next_for_another_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
)
|
||||
if not image:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(image=image, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
def _fill_res_and_update_state(
|
||||
self, image: dict, res: DebugSampleHistoryResult, state: DebugSampleHistoryState
|
||||
):
|
||||
state.variant = image["variant"]
|
||||
state.iteration = image["iter"]
|
||||
res.event = image
|
||||
var_state = first(s for s in state.variant_states if s.name == state.variant)
|
||||
if var_state:
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
def _get_next_for_current_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the image for next (if navigated earlier is False) or previous variant sorted by name for the same iteration
|
||||
Only variants for which the iteration falls into their valid range are considered
|
||||
Return None if no such variant or image is found
|
||||
"""
|
||||
cmp = operator.lt if navigate_earlier else operator.gt
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if cmp(var_state.name, state.variant)
|
||||
and var_state.min_iteration <= state.iteration
|
||||
]
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"metric": state.metric}},
|
||||
{"terms": {"variant": [v.name for v in variants]}},
|
||||
{"term": {"iter": state.iteration}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": {"variant": "desc" if navigate_earlier else "asc"},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_next_for_current_iteration"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def _get_next_for_another_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the image for the first variant for the next iteration (if navigate_earlier is set to False)
|
||||
or from the last variant for the previous iteration (otherwise)
|
||||
The variants for which the image falls in invalid range are discarded
|
||||
If no suitable image is found then None is returned
|
||||
"""
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"metric": state.metric}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
|
||||
if navigate_earlier:
|
||||
range_operator = "lt"
|
||||
order = "desc"
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.min_iteration < state.iteration
|
||||
]
|
||||
else:
|
||||
range_operator = "gt"
|
||||
order = "asc"
|
||||
variants = state.variant_states
|
||||
|
||||
if not variants:
|
||||
return
|
||||
|
||||
variants_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"variant": v.name}},
|
||||
{"range": {"iter": {"gte": v.min_iteration}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for v in variants
|
||||
]
|
||||
must_conditions.append({"bool": {"should": variants_conditions}})
|
||||
must_conditions.append({"range": {"iter": {range_operator: state.iteration}}},)
|
||||
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"iter": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_next_for_another_iteration"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def get_debug_image_for_variant(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
variant: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> DebugSampleHistoryResult:
|
||||
"""
|
||||
Get the debug image for the requested iteration or the latest before it
|
||||
If the iteration is not passed then get the latest event
|
||||
"""
|
||||
res = DebugSampleHistoryResult()
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
|
||||
return res
|
||||
|
||||
def init_state(state_: DebugSampleHistoryState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
def validate_state(state_: DebugSampleHistoryState):
|
||||
if state_.task != task or state_.metric != metric:
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
state: DebugSampleHistoryState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res.scroll_id = state.id
|
||||
|
||||
var_state = first(s for s in state.variant_states if s.name == variant)
|
||||
if not var_state:
|
||||
return res
|
||||
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if iteration is not None:
|
||||
must_conditions.append(
|
||||
{
|
||||
"range": {
|
||||
"iter": {"lte": iteration, "gte": var_state.min_iteration}
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
must_conditions.append(
|
||||
{"range": {"iter": {"gte": var_state.min_iteration}}}
|
||||
)
|
||||
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": {"iter": "desc"},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_debug_image_for_variant"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(
|
||||
image=hits[0]["_source"], res=res, state=state
|
||||
)
|
||||
return res
|
||||
|
||||
def _reset_variant_states(self, company_id: str, state: DebugSampleHistoryState):
|
||||
variant_iterations = self._get_variant_iterations(
|
||||
company_id=company_id, task=state.task, metric=state.metric
|
||||
)
|
||||
state.variant_states = [
|
||||
VariantState(name=var_name, min_iteration=min_iter, max_iteration=max_iter)
|
||||
for var_name, min_iter, max_iter in variant_iterations
|
||||
]
|
||||
|
||||
def _get_variant_iterations(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
variants: Optional[Sequence[str]] = None,
|
||||
) -> Sequence[Tuple[str, int, int]]:
|
||||
"""
|
||||
Return valid min and max iterations that the task reported images
|
||||
The min iteration is the lowest iteration that contains non-recycled image url
|
||||
"""
|
||||
must = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if variants:
|
||||
must.append({"terms": {"variant": variants}})
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must}},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
# all variants that sent debug images
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"urls": {
|
||||
# group by urls and choose the minimal iteration
|
||||
# from all the maximal iterations per url
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "asc"},
|
||||
"size": 1,
|
||||
},
|
||||
"aggs": {
|
||||
# find max iteration for each url
|
||||
"max_iter": {"max": {"field": "iter"}}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_debug_image_iterations"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
|
||||
)
|
||||
|
||||
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
|
||||
variant = variant_bucket["key"]
|
||||
urls = nested_get(variant_bucket, ("urls", "buckets"))
|
||||
min_iter = int(urls[0]["max_iter"]["value"])
|
||||
max_iter = int(variant_bucket["last_iter"]["value"])
|
||||
return variant, min_iter, max_iter
|
||||
|
||||
return [
|
||||
get_variant_data(variant_bucket)
|
||||
for variant_bucket in nested_get(
|
||||
es_res, ("aggregations", "variants", "buckets")
|
||||
)
|
||||
]
|
||||
@@ -6,45 +6,52 @@ from collections import defaultdict
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple, Optional, Dict
|
||||
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
|
||||
|
||||
import six
|
||||
from elasticsearch import helpers
|
||||
import elasticsearch
|
||||
from elasticsearch.helpers import BulkIndexError
|
||||
from mongoengine import Q
|
||||
from nested_dict import nested_dict
|
||||
|
||||
from apiserver.bll.event.debug_sample_history import DebugSampleHistory
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventType,
|
||||
EventSettings,
|
||||
get_index_name,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
delete_company_events,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
uncompress_plot,
|
||||
get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
|
||||
from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIterator
|
||||
from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
|
||||
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
|
||||
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
|
||||
from apiserver.bll.util import parallel_chunked_decorator
|
||||
from apiserver.database import utils as dbutils
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event.debug_images_iterator import DebugImagesIterator
|
||||
from apiserver.bll.event.event_metrics import EventMetrics
|
||||
from apiserver.bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import flatten_nested_items
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.utilities.json import loads
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
|
||||
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
|
||||
MAX_LONG = 2**63 - 1
|
||||
MIN_LONG = -2**63
|
||||
MAX_LONG = 2 ** 63 - 1
|
||||
MIN_LONG = -(2 ** 63)
|
||||
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class PlotFields:
|
||||
@@ -62,17 +69,32 @@ class EventBLL(object):
|
||||
r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
_task_event_query = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{"term": {"model_event": False}},
|
||||
{"bool": {"must_not": [{"exists": {"field": "model_event"}}]}},
|
||||
]
|
||||
}
|
||||
}
|
||||
_model_event_query = {"term": {"model_event": True}}
|
||||
|
||||
def __init__(self, events_es=None, redis=None):
|
||||
self.es = events_es or es_factory.connect("events")
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self._metrics = EventMetrics(self.es)
|
||||
self._skip_iteration_for_metric = set(
|
||||
config.get("services.events.ignore_iteration.metrics", [])
|
||||
)
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
|
||||
self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
|
||||
self.log_events_iterator = LogEventsIterator(es=self.es)
|
||||
self.debug_images_iterator = MetricDebugImagesIterator(
|
||||
es=self.es, redis=self.redis
|
||||
)
|
||||
self.debug_image_sample_history = HistoryDebugImageIterator(
|
||||
es=self.es, redis=self.redis
|
||||
)
|
||||
self.plots_iterator = MetricPlotsIterator(es=self.es, redis=self.redis)
|
||||
self.plot_sample_history = HistoryPlotsIterator(es=self.es, redis=self.redis)
|
||||
self.events_iterator = EventsIterator(es=self.es)
|
||||
|
||||
@property
|
||||
def metrics(self) -> EventMetrics:
|
||||
@@ -84,18 +106,42 @@ class EventBLL(object):
|
||||
if not task_ids:
|
||||
return set()
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
|
||||
with translate_errors_context():
|
||||
query = Q(id__in=task_ids, company=company_id)
|
||||
if not allow_locked_tasks:
|
||||
query &= Q(status__nin=LOCKED_TASK_STATUSES)
|
||||
res = Task.objects(query).only("id")
|
||||
return {r.id for r in res}
|
||||
|
||||
@staticmethod
|
||||
def _get_valid_models(company_id, model_ids: Set, allow_locked_models=False) -> Set:
|
||||
"""Verify that task exists and can be updated"""
|
||||
if not model_ids:
|
||||
return set()
|
||||
|
||||
with translate_errors_context():
|
||||
query = Q(id__in=model_ids, company=company_id)
|
||||
if not allow_locked_models:
|
||||
query &= Q(ready__ne=True)
|
||||
res = Model.objects(query).only("id")
|
||||
return {r.id for r in res}
|
||||
|
||||
def add_events(
|
||||
self, company_id, events, worker, allow_locked_tasks=False
|
||||
self, company_id, events, worker, allow_locked=False
|
||||
) -> Tuple[int, int, dict]:
|
||||
actions = []
|
||||
task_ids = set()
|
||||
model_events = events[0].get("model_event", False)
|
||||
for event in events:
|
||||
if event.get("model_event", model_events) != model_events:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Inconsistent model_event setting in the passed events"
|
||||
)
|
||||
if event.pop("allow_locked", allow_locked) != allow_locked:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Inconsistent allow_locked setting in the passed events"
|
||||
)
|
||||
|
||||
actions: List[dict] = []
|
||||
task_or_model_ids = set()
|
||||
task_iteration = defaultdict(lambda: 0)
|
||||
task_last_scalar_events = nested_dict(
|
||||
3, dict
|
||||
@@ -105,13 +151,28 @@ class EventBLL(object):
|
||||
) # task_id -> metric_hash -> event_type -> MetricEvent
|
||||
errors_per_type = defaultdict(int)
|
||||
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
|
||||
valid_tasks = self._get_valid_tasks(
|
||||
company_id,
|
||||
task_ids={
|
||||
event["task"] for event in events if event.get("task") is not None
|
||||
},
|
||||
allow_locked_tasks=allow_locked_tasks,
|
||||
)
|
||||
if model_events:
|
||||
for event in events:
|
||||
model = event.pop("model", None)
|
||||
if model is not None:
|
||||
event["task"] = model
|
||||
valid_entities = self._get_valid_models(
|
||||
company_id,
|
||||
model_ids={
|
||||
event["task"] for event in events if event.get("task") is not None
|
||||
},
|
||||
allow_locked_models=allow_locked,
|
||||
)
|
||||
entity_name = "model"
|
||||
else:
|
||||
valid_entities = self._get_valid_tasks(
|
||||
company_id,
|
||||
task_ids={
|
||||
event["task"] for event in events if event.get("task") is not None
|
||||
},
|
||||
allow_locked_tasks=allow_locked,
|
||||
)
|
||||
entity_name = "task"
|
||||
|
||||
for event in events:
|
||||
# remove spaces from event type
|
||||
@@ -125,13 +186,17 @@ class EventBLL(object):
|
||||
errors_per_type[f"Invalid event type {event_type}"] += 1
|
||||
continue
|
||||
|
||||
task_id = event.get("task")
|
||||
if task_id is None:
|
||||
if model_events and event_type == EventType.task_log.value:
|
||||
errors_per_type[f"Task log events are not supported for models"] += 1
|
||||
continue
|
||||
|
||||
task_or_model_id = event.get("task")
|
||||
if task_or_model_id is None:
|
||||
errors_per_type["Event must have a 'task' field"] += 1
|
||||
continue
|
||||
|
||||
if task_id not in valid_tasks:
|
||||
errors_per_type["Invalid task id"] += 1
|
||||
if task_or_model_id not in valid_entities:
|
||||
errors_per_type[f"Invalid {entity_name} id {task_or_model_id}"] += 1
|
||||
continue
|
||||
|
||||
event["type"] = event_type
|
||||
@@ -153,10 +218,13 @@ class EventBLL(object):
|
||||
# force iter to be a long int
|
||||
iter = event.get("iter")
|
||||
if iter is not None:
|
||||
iter = int(iter)
|
||||
if iter > MAX_LONG or iter < MIN_LONG:
|
||||
errors_per_type[invalid_iteration_error] += 1
|
||||
continue
|
||||
if model_events:
|
||||
iter = 0
|
||||
else:
|
||||
iter = int(iter)
|
||||
if iter > MAX_LONG or iter < MIN_LONG:
|
||||
errors_per_type[invalid_iteration_error] += 1
|
||||
continue
|
||||
event["iter"] = iter
|
||||
|
||||
# used to have "values" to indicate array. no need anymore
|
||||
@@ -166,6 +234,7 @@ class EventBLL(object):
|
||||
|
||||
event["metric"] = event.get("metric") or ""
|
||||
event["variant"] = event.get("variant") or ""
|
||||
event["model_event"] = model_events
|
||||
|
||||
index_name = get_index_name(company_id, event_type)
|
||||
es_action = {
|
||||
@@ -180,24 +249,28 @@ class EventBLL(object):
|
||||
else:
|
||||
es_action["_id"] = dbutils.id()
|
||||
|
||||
task_ids.add(task_id)
|
||||
task_or_model_ids.add(task_or_model_id)
|
||||
if (
|
||||
iter is not None
|
||||
and not model_events
|
||||
and event.get("metric") not in self._skip_iteration_for_metric
|
||||
):
|
||||
task_iteration[task_id] = max(iter, task_iteration[task_id])
|
||||
|
||||
self._update_last_metric_events_for_task(
|
||||
last_events=task_last_events[task_id], event=event,
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_id], event=event
|
||||
task_iteration[task_or_model_id] = max(
|
||||
iter, task_iteration[task_or_model_id]
|
||||
)
|
||||
|
||||
if not model_events:
|
||||
self._update_last_metric_events_for_task(
|
||||
last_events=task_last_events[task_or_model_id], event=event,
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_or_model_id],
|
||||
event=event,
|
||||
)
|
||||
|
||||
actions.append(es_action)
|
||||
|
||||
action: Dict[dict]
|
||||
plot_actions = [
|
||||
action["_source"]
|
||||
for action in actions
|
||||
@@ -216,39 +289,41 @@ class EventBLL(object):
|
||||
with translate_errors_context():
|
||||
if actions:
|
||||
chunk_size = 500
|
||||
with TimingContext("es", "events_add_batch"):
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
# thread_count=8,
|
||||
refresh=True,
|
||||
)
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += 1
|
||||
else:
|
||||
errors_per_type["Error when indexing events batch"] += 1
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
elasticsearch.helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
# thread_count=8,
|
||||
refresh=True,
|
||||
)
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += 1
|
||||
else:
|
||||
errors_per_type["Error when indexing events batch"] += 1
|
||||
|
||||
if not model_events:
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in task_ids:
|
||||
for task_or_model_id in task_or_model_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update
|
||||
# all of them and not only those who's events were successful
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
task_id=task_or_model_id,
|
||||
now=now,
|
||||
iter_max=task_iteration.get(task_id),
|
||||
last_scalar_events=task_last_scalar_events.get(task_id),
|
||||
last_events=task_last_events.get(task_id),
|
||||
iter_max=task_iteration.get(task_or_model_id),
|
||||
last_scalar_events=task_last_scalar_events.get(
|
||||
task_or_model_id
|
||||
),
|
||||
last_events=task_last_events.get(task_or_model_id),
|
||||
)
|
||||
|
||||
if not updated:
|
||||
remaining_tasks.add(task_id)
|
||||
remaining_tasks.add(task_or_model_id)
|
||||
continue
|
||||
|
||||
if remaining_tasks:
|
||||
@@ -260,7 +335,8 @@ class EventBLL(object):
|
||||
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
|
||||
if invalid_iterations_count:
|
||||
raise BulkIndexError(
|
||||
f"{invalid_iterations_count} document(s) failed to index.", [invalid_iteration_error]
|
||||
f"{invalid_iterations_count} document(s) failed to index.",
|
||||
[invalid_iteration_error],
|
||||
)
|
||||
|
||||
if not added:
|
||||
@@ -303,11 +379,7 @@ class EventBLL(object):
|
||||
@parallel_chunked_decorator(chunk_size=10)
|
||||
def uncompress_plots(self, plot_events: Sequence[dict]):
|
||||
for event in plot_events:
|
||||
plot_data = event.pop(PlotFields.plot_data, None)
|
||||
if plot_data and event.get(PlotFields.plot_str) is None:
|
||||
event[PlotFields.plot_str] = zlib.decompress(
|
||||
base64.b64decode(plot_data)
|
||||
).decode()
|
||||
uncompress_plot(event)
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_json(text: str) -> bool:
|
||||
@@ -387,33 +459,17 @@ class EventBLL(object):
|
||||
as the latest metric/variant scalar values reported (according to the report timestamp) and the task's last
|
||||
update time.
|
||||
"""
|
||||
fields = {}
|
||||
|
||||
if iter_max is not None:
|
||||
fields["last_iteration_max"] = iter_max
|
||||
|
||||
if last_scalar_events:
|
||||
fields["last_scalar_values"] = list(
|
||||
flatten_nested_items(
|
||||
last_scalar_events,
|
||||
nesting=2,
|
||||
include_leaves=[
|
||||
"value",
|
||||
"min_value",
|
||||
"max_value",
|
||||
"metric",
|
||||
"variant",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
if last_events:
|
||||
fields["last_events"] = last_events
|
||||
|
||||
if not fields:
|
||||
if iter_max is None and not last_events and not last_scalar_events:
|
||||
return False
|
||||
|
||||
return TaskBLL.update_statistics(task_id, company_id, last_update=now, **fields)
|
||||
return TaskBLL.update_statistics(
|
||||
task_id,
|
||||
company_id,
|
||||
last_update=now,
|
||||
last_iteration_max=iter_max,
|
||||
last_scalar_events=last_scalar_events,
|
||||
last_events=last_events,
|
||||
)
|
||||
|
||||
def _get_event_id(self, event):
|
||||
id_values = (str(event[field]) for field in self.id_fields if field in event)
|
||||
@@ -432,7 +488,7 @@ class EventBLL(object):
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "task_log_events"):
|
||||
with translate_errors_context():
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
size = min(batch_size, 10000)
|
||||
@@ -445,7 +501,7 @@ class EventBLL(object):
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
@@ -466,24 +522,35 @@ class EventBLL(object):
|
||||
task_id: str,
|
||||
num_last_iterations: int,
|
||||
event_type: EventType,
|
||||
metric_variants: MetricVariants = None,
|
||||
):
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return []
|
||||
|
||||
must = [{"term": {"task": task_id}}]
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
query = {"bool": {"must": must}}
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
max_variants = int(max_variants // num_last_iterations)
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
@@ -499,15 +566,11 @@ class EventBLL(object):
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "task_last_iter_metric_variant"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
@@ -527,12 +590,15 @@ class EventBLL(object):
|
||||
sort=None,
|
||||
size: int = 500,
|
||||
scroll_id: str = None,
|
||||
no_scroll: bool = False,
|
||||
metric_variants: MetricVariants = None,
|
||||
model_events: bool = False,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return TaskEventsResult()
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
with translate_errors_context():
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
event_type = EventType.metrics_plot
|
||||
@@ -553,8 +619,10 @@ class EventBLL(object):
|
||||
}
|
||||
must = [plot_valid_condition]
|
||||
|
||||
if last_iterations_per_plot is None:
|
||||
if last_iterations_per_plot is None or model_events:
|
||||
must.append({"terms": {"task": tasks}})
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
else:
|
||||
should = []
|
||||
for i, task_id in enumerate(tasks):
|
||||
@@ -563,6 +631,7 @@ class EventBLL(object):
|
||||
task_id=task_id,
|
||||
num_last_iterations=last_iterations_per_plot,
|
||||
event_type=event_type,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
if not last_iters:
|
||||
continue
|
||||
@@ -593,14 +662,14 @@ class EventBLL(object):
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_plots"):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
scroll="1h",
|
||||
**({} if no_scroll else {"scroll": "1h"}),
|
||||
)
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
@@ -618,11 +687,47 @@ class EventBLL(object):
|
||||
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
if next_scroll_id and not events:
|
||||
self.es.clear_scroll(scroll_id=next_scroll_id)
|
||||
self.clear_scroll(next_scroll_id)
|
||||
next_scroll_id = self.empty_scroll
|
||||
|
||||
return events, total_events, next_scroll_id
|
||||
|
||||
def get_debug_image_urls(
|
||||
self, company_id: str, task_id: str, after_key: dict = None
|
||||
) -> Tuple[Sequence[str], Optional[dict]]:
|
||||
if check_empty_data(self.es, company_id, EventType.metrics_image):
|
||||
return [], None
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"debug_images": {
|
||||
"composite": {
|
||||
"size": 1000,
|
||||
**({"after": after_key} if after_key else {}),
|
||||
"sources": [{"url": {"terms": {"field": "url"}}}],
|
||||
}
|
||||
}
|
||||
},
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [{"term": {"task": task_id}}, {"exists": {"field": "url"}}]
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
es_response = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=EventType.metrics_image,
|
||||
body=es_req,
|
||||
)
|
||||
res = nested_get(es_response, ("aggregations", "debug_images"))
|
||||
if not res:
|
||||
return [], None
|
||||
|
||||
return [bucket["key"]["url"] for bucket in res["buckets"]], res.get("after_key")
|
||||
|
||||
def get_plot_image_urls(
|
||||
self, company_id: str, task_id: str, scroll_id: Optional[str]
|
||||
) -> Tuple[Sequence[dict], Optional[str]]:
|
||||
@@ -669,48 +774,48 @@ class EventBLL(object):
|
||||
sort=None,
|
||||
size=500,
|
||||
scroll_id=None,
|
||||
):
|
||||
no_scroll=False,
|
||||
model_events=False,
|
||||
) -> TaskEventsResult:
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
return TaskEventsResult()
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
with translate_errors_context():
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return TaskEventsResult()
|
||||
|
||||
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
||||
|
||||
must = []
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
|
||||
if last_iter_count is None:
|
||||
if last_iter_count is None or model_events:
|
||||
must.append({"terms": {"task": task_ids}})
|
||||
else:
|
||||
should = []
|
||||
for i, task_id in enumerate(task_ids):
|
||||
last_iters = self.get_last_iters(
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
task_id=task_id,
|
||||
iters=last_iter_count,
|
||||
)
|
||||
if not last_iters:
|
||||
continue
|
||||
should.append(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"terms": {"iter": last_iters}},
|
||||
]
|
||||
}
|
||||
tasks_iters = self.get_last_iters(
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
task_id=task_ids,
|
||||
iters=last_iter_count,
|
||||
)
|
||||
should = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"terms": {"iter": last_iters}},
|
||||
]
|
||||
}
|
||||
)
|
||||
}
|
||||
for task, last_iters in tasks_iters.items()
|
||||
if last_iters
|
||||
]
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
@@ -724,14 +829,14 @@ class EventBLL(object):
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
scroll="1h",
|
||||
**({} if no_scroll else {"scroll": "1h"}),
|
||||
)
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
@@ -748,35 +853,36 @@ class EventBLL(object):
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return {}
|
||||
|
||||
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
metrics = {}
|
||||
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
|
||||
@@ -787,33 +893,40 @@ class EventBLL(object):
|
||||
|
||||
return metrics
|
||||
|
||||
def get_task_latest_scalar_values(self, company_id: str, task_id: str):
|
||||
def get_task_latest_scalar_values(
|
||||
self, company_id, task_id
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
event_type = EventType.metrics_scalar
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return {}
|
||||
return [], 0
|
||||
|
||||
query = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"query_string": {"query": "value:>0"}},
|
||||
{"term": {"task": task_id}},
|
||||
]
|
||||
}
|
||||
}
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"query_string": {"query": "value:>0"}},
|
||||
{"term": {"task": task_id}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
@@ -841,12 +954,8 @@ class EventBLL(object):
|
||||
},
|
||||
"_source": {"excludes": []},
|
||||
}
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
metrics = []
|
||||
max_timestamp = 0
|
||||
@@ -891,7 +1000,7 @@ class EventBLL(object):
|
||||
"_source": ["iter", "value"],
|
||||
"sort": ["iter"],
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
@@ -905,73 +1014,185 @@ class EventBLL(object):
|
||||
return iterations, vectors
|
||||
|
||||
def get_last_iters(
|
||||
self, company_id: str, event_type: EventType, task_id: str, iters: int
|
||||
):
|
||||
self,
|
||||
company_id: str,
|
||||
event_type: EventType,
|
||||
task_id: Union[str, Sequence[str]],
|
||||
iters: int,
|
||||
) -> Mapping[str, Sequence]:
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return []
|
||||
return {}
|
||||
|
||||
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iters,
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
"tasks": {
|
||||
"terms": {"field": "task"},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iters,
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
"query": {"bool": {"must": [{"terms": {"task": task_ids}}]}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_last_iter"):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
return {}
|
||||
|
||||
return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]]
|
||||
return {
|
||||
tb["key"]: [ib["key"] for ib in tb["iters"]["buckets"]]
|
||||
for tb in es_res["aggregations"]["tasks"]["buckets"]
|
||||
}
|
||||
|
||||
def delete_task_events(self, company_id, task_id, allow_locked=False):
|
||||
with translate_errors_context():
|
||||
extra_msg = None
|
||||
query = Q(id=task_id, company=company_id)
|
||||
if not allow_locked:
|
||||
query &= Q(status__nin=LOCKED_TASK_STATUSES)
|
||||
extra_msg = "or task published"
|
||||
res = Task.objects(query).only("id").first()
|
||||
if not res:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
extra_msg, company=company_id, id=task_id
|
||||
)
|
||||
@staticmethod
|
||||
def _validate_model_state(
|
||||
company_id: str, model_id: str, allow_locked: bool = False
|
||||
):
|
||||
extra_msg = None
|
||||
query = Q(id=model_id, company=company_id)
|
||||
if not allow_locked:
|
||||
query &= Q(ready__ne=True)
|
||||
extra_msg = "or model published"
|
||||
res = Model.objects(query).only("id").first()
|
||||
if not res:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
extra_msg, company=company_id, id=model_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_task_state(company_id: str, task_id: str, allow_locked: bool = False):
|
||||
extra_msg = None
|
||||
query = Q(id=task_id, company=company_id)
|
||||
if not allow_locked:
|
||||
query &= Q(status__nin=LOCKED_TASK_STATUSES)
|
||||
extra_msg = "or task published"
|
||||
res = Task.objects(query).only("id").first()
|
||||
if not res:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
extra_msg, company=company_id, id=task_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_events_deletion_params(async_delete: bool) -> dict:
|
||||
if async_delete:
|
||||
return {
|
||||
"wait_for_completion": False,
|
||||
"requests_per_second": config.get(
|
||||
"services.events.max_async_deleted_events_per_sec", 1000
|
||||
),
|
||||
}
|
||||
|
||||
return {"refresh": True}
|
||||
|
||||
def delete_task_events(
|
||||
self, company_id, task_id, allow_locked=False, model=False, async_delete=False,
|
||||
):
|
||||
if model:
|
||||
self._validate_model_state(
|
||||
company_id=company_id, model_id=task_id, allow_locked=allow_locked,
|
||||
)
|
||||
else:
|
||||
self._validate_task_state(
|
||||
company_id=company_id, task_id=task_id, allow_locked=allow_locked
|
||||
)
|
||||
|
||||
es_req = {"query": {"term": {"task": task_id}}}
|
||||
with translate_errors_context(), TimingContext("es", "delete_task_events"):
|
||||
with translate_errors_context():
|
||||
es_res = delete_company_events(
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=EventType.all,
|
||||
body=es_req,
|
||||
refresh=True,
|
||||
**self._get_events_deletion_params(async_delete),
|
||||
)
|
||||
|
||||
return es_res.get("deleted", 0)
|
||||
if not async_delete:
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
def delete_multi_task_events(self, company_id: str, task_ids: Sequence[str]):
|
||||
def clear_task_log(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
allow_locked: bool = False,
|
||||
threshold_sec: int = None,
|
||||
):
|
||||
self._validate_task_state(
|
||||
company_id=company_id, task_id=task_id, allow_locked=allow_locked
|
||||
)
|
||||
if check_empty_data(
|
||||
self.es, company_id=company_id, event_type=EventType.task_log
|
||||
):
|
||||
return 0
|
||||
|
||||
with translate_errors_context():
|
||||
must = [{"term": {"task": task_id}}]
|
||||
sort = None
|
||||
if threshold_sec:
|
||||
timestamp_ms = int(threshold_sec * 1000)
|
||||
must.append(
|
||||
{
|
||||
"range": {
|
||||
"timestamp": {
|
||||
"lt": (es_factory.get_timestamp_millis() - timestamp_ms)
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
sort = {"timestamp": {"order": "desc"}}
|
||||
es_req = {
|
||||
"query": {"bool": {"must": must}},
|
||||
**({"sort": sort} if sort else {}),
|
||||
}
|
||||
es_res = delete_company_events(
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=EventType.task_log,
|
||||
body=es_req,
|
||||
refresh=True,
|
||||
)
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
def delete_multi_task_events(
|
||||
self, company_id: str, task_ids: Sequence[str], async_delete=False
|
||||
):
|
||||
"""
|
||||
Delete mutliple task events. No check is done for tasks write access
|
||||
so it should be checked by the calling code
|
||||
"""
|
||||
es_req = {"query": {"terms": {"task": task_ids}}}
|
||||
with translate_errors_context(), TimingContext("es", "delete_multi_tasks_events"):
|
||||
with translate_errors_context():
|
||||
es_res = delete_company_events(
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=EventType.all,
|
||||
body=es_req,
|
||||
refresh=True,
|
||||
**self._get_events_deletion_params(async_delete),
|
||||
)
|
||||
|
||||
return es_res.get("deleted", 0)
|
||||
if not async_delete:
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
def clear_scroll(self, scroll_id: str):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self.es.clear_scroll(scroll_id=scroll_id)
|
||||
except elasticsearch.exceptions.NotFoundError:
|
||||
pass
|
||||
except elasticsearch.exceptions.RequestError:
|
||||
pass
|
||||
except Exception as ex:
|
||||
log.exception("Failed clearing scroll %s", scroll_id)
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import base64
|
||||
import zlib
|
||||
from enum import Enum
|
||||
from typing import Union, Sequence
|
||||
from typing import Union, Sequence, Mapping, Tuple
|
||||
|
||||
from boltons.typeutils import classproperty
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.tools import safe_get
|
||||
|
||||
|
||||
class EventType(Enum):
|
||||
@@ -16,7 +20,13 @@ class EventType(Enum):
|
||||
all = "*"
|
||||
|
||||
|
||||
SINGLE_SCALAR_ITERATION = -(2 ** 31)
|
||||
MetricVariants = Mapping[str, Sequence[str]]
|
||||
|
||||
|
||||
class EventSettings:
|
||||
_max_es_allowed_aggregation_buckets = 10000
|
||||
|
||||
@classproperty
|
||||
def max_workers(self):
|
||||
return config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
|
||||
@@ -28,17 +38,23 @@ class EventSettings:
|
||||
)
|
||||
|
||||
@classproperty
|
||||
def max_metrics_count(self):
|
||||
return config.get("services.events.events_retrieval.max_metrics_count", 100)
|
||||
|
||||
@classproperty
|
||||
def max_variants_count(self):
|
||||
return config.get("services.events.events_retrieval.max_variants_count", 100)
|
||||
def max_es_buckets(self):
|
||||
percentage = (
|
||||
min(
|
||||
100,
|
||||
config.get(
|
||||
"services.events.events_retrieval.dynamic_metrics_count_threshold",
|
||||
80,
|
||||
),
|
||||
)
|
||||
/ 100
|
||||
)
|
||||
return int(self._max_es_allowed_aggregation_buckets * percentage)
|
||||
|
||||
|
||||
def get_index_name(company_id: str, event_type: str):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
return f"events-{event_type}-{company_id}"
|
||||
return f"events-{event_type}-{company_id.lower()}"
|
||||
|
||||
|
||||
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
|
||||
@@ -63,4 +79,83 @@ def delete_company_events(
|
||||
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.delete_by_query(index=es_index, body=body, **kwargs)
|
||||
return es.delete_by_query(index=es_index, body=body, conflicts="proceed", **kwargs)
|
||||
|
||||
|
||||
def count_company_events(
|
||||
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.count(index=es_index, body=body, **kwargs)
|
||||
|
||||
|
||||
def get_max_metric_and_variant_counts(
|
||||
es: Elasticsearch,
|
||||
company_id: Union[str, Sequence[str]],
|
||||
event_type: EventType,
|
||||
query: dict,
|
||||
**kwargs,
|
||||
) -> Tuple[int, int]:
|
||||
dynamic = config.get(
|
||||
"services.events.events_retrieval.dynamic_metrics_count", False
|
||||
)
|
||||
max_metrics_count = config.get(
|
||||
"services.events.events_retrieval.max_metrics_count", 100
|
||||
)
|
||||
max_variants_count = config.get(
|
||||
"services.events.events_retrieval.max_variants_count", 100
|
||||
)
|
||||
if not dynamic:
|
||||
return max_metrics_count, max_variants_count
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {"metrics_count": {"cardinality": {"field": "metric"}}},
|
||||
}
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
|
||||
)
|
||||
|
||||
metrics_count = safe_get(
|
||||
es_res, "aggregations/metrics_count/value", max_metrics_count
|
||||
)
|
||||
if not metrics_count:
|
||||
return max_metrics_count, max_variants_count
|
||||
|
||||
return metrics_count, int(EventSettings.max_es_buckets / metrics_count)
|
||||
|
||||
|
||||
def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
|
||||
conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
{"terms": {"variant": variants}},
|
||||
]
|
||||
}
|
||||
}
|
||||
if variants
|
||||
else {"term": {"metric": metric}}
|
||||
for metric, variants in metric_variants.items()
|
||||
]
|
||||
|
||||
return {"bool": {"should": conditions}}
|
||||
|
||||
|
||||
class PlotFields:
|
||||
valid_plot = "valid_plot"
|
||||
plot_len = "plot_len"
|
||||
plot_str = "plot_str"
|
||||
plot_data = "plot_data"
|
||||
source_urls = "source_urls"
|
||||
|
||||
|
||||
def uncompress_plot(event: dict):
|
||||
plot_data = event.pop(PlotFields.plot_data, None)
|
||||
if plot_data and event.get(PlotFields.plot_str) is None:
|
||||
event[PlotFields.plot_str] = zlib.decompress(
|
||||
base64.b64decode(plot_data)
|
||||
).decode()
|
||||
|
||||
@@ -4,23 +4,25 @@ from collections import defaultdict
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple
|
||||
from typing import Sequence, Tuple, Mapping
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventType,
|
||||
EventSettings,
|
||||
search_company_events,
|
||||
check_empty_data,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
get_max_metric_and_variant_counts,
|
||||
SINGLE_SCALAR_ITERATION,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
@@ -34,7 +36,12 @@ class EventMetrics:
|
||||
self.es = es
|
||||
|
||||
def get_scalar_metrics_average_per_iter(
|
||||
self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
samples: int,
|
||||
key: ScalarKeyEnum,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Get scalar metric histogram per metric and variant
|
||||
@@ -46,7 +53,12 @@ class EventMetrics:
|
||||
return {}
|
||||
|
||||
return self._get_scalar_average_per_iter_core(
|
||||
task_id, company_id, event_type, samples, ScalarKey.resolve(key)
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
|
||||
def _get_scalar_average_per_iter_core(
|
||||
@@ -57,6 +69,7 @@ class EventMetrics:
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
run_parallel: bool = True,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> dict:
|
||||
intervals = self._get_task_metric_intervals(
|
||||
company_id=company_id,
|
||||
@@ -64,6 +77,7 @@ class EventMetrics:
|
||||
task_id=task_id,
|
||||
samples=samples,
|
||||
field=key.field,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
if not intervals:
|
||||
return {}
|
||||
@@ -95,40 +109,19 @@ class EventMetrics:
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self,
|
||||
company_id,
|
||||
task_ids: Sequence[str],
|
||||
tasks: Sequence[Task],
|
||||
samples,
|
||||
key: ScalarKeyEnum,
|
||||
allow_public=True,
|
||||
):
|
||||
"""
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id__in=task_ids),
|
||||
allow_public=allow_public,
|
||||
override_projection=("id", "name", "company", "company_origin"),
|
||||
return_dicts=False,
|
||||
)
|
||||
if len(task_objs) < len(task_ids):
|
||||
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
|
||||
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
|
||||
task_name_by_id = {t.id: t.name for t in task_objs}
|
||||
|
||||
companies = {t.get_index_company() for t in task_objs}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
event_type = EventType.metrics_scalar
|
||||
company_id = next(iter(companies))
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return {}
|
||||
|
||||
task_name_by_id = {t.id: t.name for t in tasks}
|
||||
get_scalar_average_per_iter = partial(
|
||||
self._get_scalar_average_per_iter_core,
|
||||
company_id=company_id,
|
||||
@@ -137,6 +130,7 @@ class EventMetrics:
|
||||
key=ScalarKey.resolve(key),
|
||||
run_parallel=False,
|
||||
)
|
||||
task_ids = [t.id for t in tasks]
|
||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||
task_metrics = zip(
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
|
||||
@@ -152,6 +146,57 @@ class EventMetrics:
|
||||
|
||||
return res
|
||||
|
||||
def get_task_single_value_metrics(
|
||||
self, company_id: str, tasks: Sequence[Task]
|
||||
) -> Mapping[str, dict]:
|
||||
"""
|
||||
For the requested tasks return all the events delivered for the single iteration (-2**31)
|
||||
"""
|
||||
if check_empty_data(
|
||||
self.es, company_id=company_id, event_type=EventType.metrics_scalar
|
||||
):
|
||||
return {}
|
||||
|
||||
task_ids = [t.id for t in tasks]
|
||||
task_events = self._get_task_single_value_metrics(company_id, task_ids)
|
||||
|
||||
def _get_value(event: dict):
|
||||
return {
|
||||
field: event.get(field)
|
||||
for field in ("metric", "variant", "value", "timestamp")
|
||||
}
|
||||
|
||||
return {
|
||||
task: [_get_value(e) for e in events]
|
||||
for task, events in bucketize(task_events, itemgetter("task")).items()
|
||||
}
|
||||
|
||||
def _get_task_single_value_metrics(
|
||||
self, company_id: str, task_ids: Sequence[str]
|
||||
) -> Sequence[dict]:
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"terms": {"task": task_ids}},
|
||||
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
body=es_req,
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=EventType.metrics_scalar,
|
||||
)
|
||||
if not es_res["hits"]["total"]["value"]:
|
||||
return []
|
||||
|
||||
return [hit["_source"] for hit in es_res["hits"]["hits"]]
|
||||
|
||||
MetricInterval = Tuple[str, str, int, int]
|
||||
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
|
||||
|
||||
@@ -197,6 +242,7 @@ class EventMetrics:
|
||||
task_id: str,
|
||||
samples: int,
|
||||
field: str = "iter",
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> Sequence[MetricInterval]:
|
||||
"""
|
||||
Calculate interval per task metric variant so that the resulting
|
||||
@@ -204,21 +250,30 @@ class EventMetrics:
|
||||
Return the list og metric variant intervals as the following tuple:
|
||||
(metric, variant, interval, samples)
|
||||
"""
|
||||
must = self._task_conditions(task_id)
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
query = {"bool": {"must": must}}
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"term": {"task": task_id}},
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
@@ -232,10 +287,7 @@ class EventMetrics:
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
@@ -287,33 +339,40 @@ class EventMetrics:
|
||||
"""
|
||||
interval, metrics = metrics_interval
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
aggs_result = self._query_aggregation_for_task_metrics(
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
aggs=aggs,
|
||||
task_id=task_id,
|
||||
metrics=metrics,
|
||||
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
@@ -340,19 +399,18 @@ class EventMetrics:
|
||||
for key, value in aggregation.items()
|
||||
}
|
||||
|
||||
def _query_aggregation_for_task_metrics(
|
||||
self,
|
||||
company_id: str,
|
||||
event_type: EventType,
|
||||
aggs: dict,
|
||||
task_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
) -> dict:
|
||||
"""
|
||||
Return the result of elastic search query for the given aggregation filtered
|
||||
by the given task_ids and metrics
|
||||
"""
|
||||
must = [{"term": {"task": task_id}}]
|
||||
@staticmethod
|
||||
def _task_conditions(task_id: str) -> list:
|
||||
return [
|
||||
{"term": {"task": task_id}},
|
||||
{"range": {"iter": {"gt": SINGLE_SCALAR_ITERATION}}},
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _get_task_metrics_query(
|
||||
cls, task_id: str, metrics: Sequence[Tuple[str, str]],
|
||||
):
|
||||
must = cls._task_conditions(task_id)
|
||||
if metrics:
|
||||
should = [
|
||||
{
|
||||
@@ -367,20 +425,9 @@ class EventMetrics:
|
||||
]
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must}},
|
||||
"aggs": aggs,
|
||||
}
|
||||
return {"bool": {"must": must}}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
|
||||
return es_res.get("aggregations")
|
||||
|
||||
def get_tasks_metrics(
|
||||
def get_task_metrics(
|
||||
self, company_id, task_ids: Sequence, event_type: EventType
|
||||
) -> Sequence:
|
||||
"""
|
||||
@@ -406,22 +453,21 @@ class EventMetrics:
|
||||
) -> Sequence:
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
"query": {"bool": {"must": self._task_conditions(task_id)}},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"size": EventSettings.max_es_buckets,
|
||||
"order": {"_key": "asc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
|
||||
return [
|
||||
metric["key"]
|
||||
|
||||
197
apiserver/bll/event/events_iterator.py
Normal file
197
apiserver/bll/event/events_iterator.py
Normal file
@@ -0,0 +1,197 @@
|
||||
from typing import Optional, Tuple, Sequence, Any
|
||||
|
||||
import attr
|
||||
import jsonmodels.models
|
||||
import jwt
|
||||
from elasticsearch import Elasticsearch
|
||||
from jwt.algorithms import get_default_algorithms
|
||||
|
||||
from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
count_company_events,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.Factory(list)
|
||||
|
||||
|
||||
class EventsIterator:
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
from_key_value: Optional[Any] = None,
|
||||
metric_variants: MetricVariants = None,
|
||||
key: ScalarKeyEnum = ScalarKeyEnum.timestamp,
|
||||
**kwargs,
|
||||
) -> TaskEventsResult:
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return TaskEventsResult()
|
||||
|
||||
from_key_value = kwargs.pop("from_timestamp", from_key_value)
|
||||
|
||||
res = TaskEventsResult()
|
||||
res.events, res.total_events = self._get_events(
|
||||
event_type=event_type,
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_key_value=from_key_value,
|
||||
metric_variants=metric_variants,
|
||||
key=ScalarKey.resolve(key),
|
||||
)
|
||||
return res
|
||||
|
||||
def count_task_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> int:
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return 0
|
||||
|
||||
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
|
||||
es_req = {
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_result = count_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
|
||||
return es_result["count"]
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
key: ScalarKey,
|
||||
from_key_value: Optional[Any],
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous key-field value (timestamp or iter) either in the
|
||||
direction of earlier events (navigate_earlier=True) or in the direction of later events.
|
||||
If from_key_field is not set then start either from latest or earliest.
|
||||
For the last key-field value all the events are brought (even if the resulting size exceeds batch_size)
|
||||
so that events with this value will not be lost between the calls.
|
||||
"""
|
||||
query, must = self._get_initial_query_and_must(task_id, metric_variants)
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": query,
|
||||
"sort": {key.field: "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if from_key_value:
|
||||
es_req["search_after"] = [from_key_value]
|
||||
|
||||
with translate_errors_context():
|
||||
es_result = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": must + [{"term": {key.field: events[-1][key.field]}}]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
last_second_hits = es_result["hits"]["hits"]
|
||||
if not last_second_hits or len(last_second_hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
already_present_ids = set(hit["_id"] for hit in hits)
|
||||
last_second_events = [
|
||||
hit["_source"]
|
||||
for hit in last_second_hits
|
||||
if hit["_id"] not in already_present_ids
|
||||
]
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[*events, *last_second_events],
|
||||
hits_total,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_initial_query_and_must(
|
||||
task_id: str, metric_variants: MetricVariants = None
|
||||
) -> Tuple[dict, list]:
|
||||
if not metric_variants:
|
||||
must = [{"term": {"task": task_id}}]
|
||||
query = {"term": {"task": task_id}}
|
||||
else:
|
||||
must = [
|
||||
{"term": {"task": task_id}},
|
||||
get_metric_variants_condition(metric_variants),
|
||||
]
|
||||
query = {"bool": {"must": must}}
|
||||
return query, must
|
||||
|
||||
|
||||
class Scroll(jsonmodels.models.Base):
|
||||
def get_scroll_id(self) -> str:
|
||||
return jwt.encode(
|
||||
self.to_struct(),
|
||||
key=config.get(
|
||||
"services.events.events_retrieval.scroll_id_key", "1234567890"
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_scroll_id(cls, scroll_id: str):
|
||||
try:
|
||||
return cls(
|
||||
**jwt.decode(
|
||||
scroll_id,
|
||||
key=config.get(
|
||||
"services.events.events_retrieval.scroll_id_key", "1234567890"
|
||||
),
|
||||
algorithms=get_default_algorithms(),
|
||||
)
|
||||
)
|
||||
except jwt.PyJWTError:
|
||||
raise ValueError("Invalid Scroll ID")
|
||||
455
apiserver/bll/event/history_debug_image_iterator.py
Normal file
455
apiserver/bll/event/history_debug_image_iterator.py
Normal file
@@ -0,0 +1,455 @@
|
||||
import operator
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first, bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, IntField, BoolField, ListField
|
||||
from jsonmodels.models import Base
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .event_common import (
|
||||
EventType,
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.apierrors import errors
|
||||
|
||||
|
||||
class VariantState(Base):
|
||||
name: str = StringField(required=True)
|
||||
metric: str = StringField(default=None)
|
||||
min_iteration: int = IntField()
|
||||
max_iteration: int = IntField()
|
||||
|
||||
|
||||
class DebugImageSampleState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
iteration: int = IntField()
|
||||
variant: str = StringField()
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
variant_states: Sequence[VariantState] = ListField([VariantState])
|
||||
warning: str = StringField()
|
||||
navigate_current_metric = BoolField(default=True)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class VariantSampleResult(object):
|
||||
scroll_id: str = None
|
||||
event: dict = None
|
||||
min_iteration: int = None
|
||||
max_iteration: int = None
|
||||
|
||||
|
||||
class HistoryDebugImageIterator:
|
||||
event_type = EventType.metrics_image
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugImageSampleState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_next_sample(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
state_id: str,
|
||||
navigate_earlier: bool,
|
||||
next_iteration: bool,
|
||||
) -> VariantSampleResult:
|
||||
"""
|
||||
Get the sample for next/prev variant on the current iteration
|
||||
If does not exist then try getting sample for the first/last variant from next/prev iteration
|
||||
"""
|
||||
res = VariantSampleResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
if next_iteration:
|
||||
event = self._get_next_for_another_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
)
|
||||
else:
|
||||
# noinspection PyArgumentList
|
||||
event = first(
|
||||
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
|
||||
for f in (
|
||||
self._get_next_for_current_iteration,
|
||||
self._get_next_for_another_iteration,
|
||||
)
|
||||
)
|
||||
if not event:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(event=event, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def _fill_res_and_update_state(
|
||||
event: dict, res: VariantSampleResult, state: DebugImageSampleState
|
||||
):
|
||||
state.variant = event["variant"]
|
||||
state.metric = event["metric"]
|
||||
state.iteration = event["iter"]
|
||||
res.event = event
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == state.variant and vs.metric == state.metric
|
||||
)
|
||||
if var_state:
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
@staticmethod
|
||||
def _get_metric_conditions(variants: Sequence[VariantState]) -> dict:
|
||||
metrics = bucketize(variants, key=attrgetter("metric"))
|
||||
|
||||
def _get_variants_conditions(metric_variants: Sequence[VariantState]) -> dict:
|
||||
variants_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"variant": v.name}},
|
||||
{"range": {"iter": {"gte": v.min_iteration}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for v in metric_variants
|
||||
]
|
||||
return {"bool": {"should": variants_conditions}}
|
||||
|
||||
metrics_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
_get_variants_conditions(metric_variants),
|
||||
]
|
||||
}
|
||||
}
|
||||
for metric, metric_variants in metrics.items()
|
||||
]
|
||||
return {"bool": {"should": metrics_conditions}}
|
||||
|
||||
def _get_next_for_current_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
|
||||
Only variants for which the iteration falls into their valid range are considered
|
||||
Return None if no such variant or sample is found
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
cmp = operator.lt if navigate_earlier else operator.gt
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
|
||||
and var_state.min_iteration <= state.iteration
|
||||
]
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"iter": state.iteration}},
|
||||
self._get_metric_conditions(variants),
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
order = "desc" if navigate_earlier else "asc"
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def _get_next_for_another_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
|
||||
or from the last variant for the previous iteration (otherwise)
|
||||
The variants for which the sample falls in invalid range are discarded
|
||||
If no suitable sample is found then None is returned
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
if navigate_earlier:
|
||||
range_operator = "lt"
|
||||
order = "desc"
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if var_state.min_iteration < state.iteration
|
||||
]
|
||||
else:
|
||||
range_operator = "gt"
|
||||
order = "asc"
|
||||
variants = variants
|
||||
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
self._get_metric_conditions(variants),
|
||||
{"range": {"iter": {range_operator: state.iteration}}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def get_sample_for_variant(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
variant: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
navigate_current_metric: bool = True,
|
||||
) -> VariantSampleResult:
|
||||
"""
|
||||
Get the sample for the requested iteration or the latest before it
|
||||
If the iteration is not passed then get the latest event
|
||||
"""
|
||||
res = VariantSampleResult()
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
def init_state(state_: DebugImageSampleState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
state_.navigate_current_metric = navigate_current_metric
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
def validate_state(state_: DebugImageSampleState):
|
||||
if (
|
||||
state_.task != task
|
||||
or state_.navigate_current_metric != navigate_current_metric
|
||||
or (state_.navigate_current_metric and state_.metric != metric)
|
||||
):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
# fix old variant states:
|
||||
for vs in state_.variant_states:
|
||||
if vs.metric is None:
|
||||
vs.metric = metric
|
||||
if refresh:
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
state: DebugImageSampleState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res.scroll_id = state.id
|
||||
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == variant and vs.metric == metric
|
||||
)
|
||||
if not var_state:
|
||||
return res
|
||||
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if iteration is not None:
|
||||
must_conditions.append(
|
||||
{
|
||||
"range": {
|
||||
"iter": {"lte": iteration, "gte": var_state.min_iteration}
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
must_conditions.append(
|
||||
{"range": {"iter": {"gte": var_state.min_iteration}}}
|
||||
)
|
||||
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": {"iter": "desc"},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(
|
||||
event=hits[0]["_source"], res=res, state=state
|
||||
)
|
||||
return res
|
||||
|
||||
def _reset_variant_states(self, company_id: str, state: DebugImageSampleState):
|
||||
metrics = self._get_metric_variant_iterations(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
metric=state.metric if state.navigate_current_metric else None,
|
||||
)
|
||||
state.variant_states = [
|
||||
VariantState(
|
||||
metric=metric,
|
||||
name=var_name,
|
||||
min_iteration=min_iter,
|
||||
max_iteration=max_iter,
|
||||
)
|
||||
for metric, variants in metrics.items()
|
||||
for var_name, min_iter, max_iter in variants
|
||||
]
|
||||
|
||||
def _get_metric_variant_iterations(
|
||||
self, company_id: str, task: str, metric: str,
|
||||
) -> Mapping[str, Sequence[Tuple[str, int, int]]]:
|
||||
"""
|
||||
Return valid min and max iterations that the task reported events of the required type
|
||||
"""
|
||||
must = [
|
||||
{"term": {"task": task}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if metric is not None:
|
||||
must.append({"term": {"metric": metric}})
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=self.event_type,
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"urls": {
|
||||
# group by urls and choose the minimal iteration
|
||||
# from all the maximal iterations per url
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "asc"},
|
||||
"size": 1,
|
||||
},
|
||||
"aggs": {
|
||||
# find max iteration for each url
|
||||
"max_iter": {"max": {"field": "iter"}}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
|
||||
variant = variant_bucket["key"]
|
||||
urls = nested_get(variant_bucket, ("urls", "buckets"))
|
||||
min_iter = int(urls[0]["max_iter"]["value"])
|
||||
max_iter = int(variant_bucket["last_iter"]["value"])
|
||||
return variant, min_iter, max_iter
|
||||
|
||||
return {
|
||||
metric_bucket["key"]: [
|
||||
get_variant_data(variant_bucket)
|
||||
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
|
||||
]
|
||||
for metric_bucket in nested_get(
|
||||
es_res, ("aggregations", "metrics", "buckets")
|
||||
)
|
||||
}
|
||||
316
apiserver/bll/event/history_plots_iterator.py
Normal file
316
apiserver/bll/event/history_plots_iterator.py
Normal file
@@ -0,0 +1,316 @@
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, IntField, ListField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from .event_common import (
|
||||
EventType,
|
||||
uncompress_plot,
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
)
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.apierrors import errors
|
||||
|
||||
|
||||
class MetricState(Base):
|
||||
name: str = StringField(default=None)
|
||||
min_iteration: int = IntField()
|
||||
max_iteration: int = IntField()
|
||||
|
||||
|
||||
class PlotsSampleState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
iteration: int = IntField()
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
metric_states: Sequence[MetricState] = ListField([MetricState])
|
||||
warning: str = StringField()
|
||||
navigate_current_metric = BoolField(default=True)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class MetricSamplesResult(object):
|
||||
scroll_id: str = None
|
||||
events: list = []
|
||||
min_iteration: int = None
|
||||
max_iteration: int = None
|
||||
|
||||
|
||||
class HistoryPlotsIterator:
|
||||
event_type = EventType.metrics_plot
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=PlotsSampleState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_next_sample(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
state_id: str,
|
||||
navigate_earlier: bool,
|
||||
next_iteration: bool,
|
||||
) -> MetricSamplesResult:
|
||||
"""
|
||||
Get the samples for next/prev metric on the current iteration
|
||||
If does not exist then try getting sample for the first/last metric from next/prev iteration
|
||||
"""
|
||||
res = MetricSamplesResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
if navigate_earlier:
|
||||
range_operator = "lt"
|
||||
order = "desc"
|
||||
else:
|
||||
range_operator = "gt"
|
||||
order = "asc"
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
]
|
||||
if state.navigate_current_metric:
|
||||
must_conditions.append({"term": {"metric": state.metric}})
|
||||
|
||||
next_iteration_condition = {
|
||||
"range": {"iter": {range_operator: state.iteration}}
|
||||
}
|
||||
if next_iteration or state.navigate_current_metric:
|
||||
must_conditions.append(next_iteration_condition)
|
||||
else:
|
||||
next_metric_condition = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"iter": state.iteration}},
|
||||
{"range": {"metric": {range_operator: state.metric}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
must_conditions.append(
|
||||
{"bool": {"should": [next_metric_condition, next_iteration_condition]}}
|
||||
)
|
||||
|
||||
events = self._get_metric_events_for_condition(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
order=order,
|
||||
must_conditions=must_conditions,
|
||||
)
|
||||
|
||||
if not events:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(events=events, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
def get_samples_for_metric(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
navigate_current_metric: bool = True,
|
||||
) -> MetricSamplesResult:
|
||||
"""
|
||||
Get the sample for the requested iteration or the latest before it
|
||||
If the iteration is not passed then get the latest event
|
||||
"""
|
||||
res = MetricSamplesResult()
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
def init_state(state_: PlotsSampleState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
state_.navigate_current_metric = navigate_current_metric
|
||||
self._reset_metric_states(company_id=company_id, state=state_)
|
||||
|
||||
def validate_state(state_: PlotsSampleState):
|
||||
if (
|
||||
state_.task != task
|
||||
or state_.navigate_current_metric != navigate_current_metric
|
||||
or (state_.navigate_current_metric and state_.metric != metric)
|
||||
):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reset_metric_states(company_id=company_id, state=state_)
|
||||
|
||||
state: PlotsSampleState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res.scroll_id = state.id
|
||||
|
||||
metric_state = first(ms for ms in state.metric_states if ms.name == metric)
|
||||
if not metric_state:
|
||||
return res
|
||||
|
||||
res.min_iteration = metric_state.min_iteration
|
||||
res.max_iteration = metric_state.max_iteration
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
]
|
||||
if iteration is not None:
|
||||
must_conditions.append({"range": {"iter": {"lte": iteration}}})
|
||||
|
||||
events = self._get_metric_events_for_condition(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
order="desc",
|
||||
must_conditions=must_conditions,
|
||||
)
|
||||
if not events:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(events=events, res=res, state=state)
|
||||
return res
|
||||
|
||||
def _reset_metric_states(self, company_id: str, state: PlotsSampleState):
|
||||
metrics = self._get_metric_iterations(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
metric=state.metric if state.navigate_current_metric else None,
|
||||
)
|
||||
state.metric_states = [
|
||||
MetricState(name=metric, min_iteration=min_iter, max_iteration=max_iter)
|
||||
for metric, (min_iter, max_iter) in metrics.items()
|
||||
]
|
||||
|
||||
def _get_metric_iterations(
|
||||
self, company_id: str, task: str, metric: str,
|
||||
) -> Mapping[str, Tuple[int, int]]:
|
||||
"""
|
||||
Return valid min and max iterations that the task reported events of the required type
|
||||
"""
|
||||
must = [
|
||||
{"term": {"task": task}},
|
||||
]
|
||||
if metric is not None:
|
||||
must.append({"term": {"metric": metric}})
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": 5000,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"first_iter": {"min": {"field": "iter"}},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
body=es_req,
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
)
|
||||
|
||||
return {
|
||||
metric_bucket["key"]: (
|
||||
int(metric_bucket["first_iter"]["value"]),
|
||||
int(metric_bucket["last_iter"]["value"]),
|
||||
)
|
||||
for metric_bucket in nested_get(
|
||||
es_res, ("aggregations", "metrics", "buckets")
|
||||
)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _fill_res_and_update_state(
|
||||
events: Sequence[dict], res: MetricSamplesResult, state: PlotsSampleState
|
||||
):
|
||||
for event in events:
|
||||
uncompress_plot(event)
|
||||
state.metric = events[0]["metric"]
|
||||
state.iteration = events[0]["iter"]
|
||||
res.events = events
|
||||
metric_state = first(
|
||||
ms for ms in state.metric_states if ms.name == state.metric
|
||||
)
|
||||
if metric_state:
|
||||
res.min_iteration = metric_state.min_iteration
|
||||
res.max_iteration = metric_state.max_iteration
|
||||
|
||||
def _get_metric_events_for_condition(
|
||||
self, company_id: str, task: str, order: str, must_conditions: Sequence
|
||||
) -> Sequence:
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {"field": "iter", "size": 1, "order": {"_key": order}},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": 1,
|
||||
"order": {"_key": order},
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {
|
||||
"sort": {"variant": {"order": "asc"}},
|
||||
"size": 100,
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return []
|
||||
|
||||
for level in ("iters", "metrics"):
|
||||
level_data = aggs_result[level]["buckets"]
|
||||
if not level_data:
|
||||
return []
|
||||
aggs_result = level_data[0]
|
||||
|
||||
return [
|
||||
hit["_source"]
|
||||
for hit in nested_get(aggs_result, ("events", "hits", "hits"))
|
||||
]
|
||||
@@ -1,127 +0,0 @@
|
||||
from typing import Optional, Tuple, Sequence
|
||||
|
||||
import attr
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
)
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.Factory(list)
|
||||
|
||||
|
||||
class LogEventsIterator:
|
||||
EVENT_TYPE = EventType.task_log
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
from_timestamp: Optional[int] = None,
|
||||
) -> TaskEventsResult:
|
||||
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
|
||||
return TaskEventsResult()
|
||||
|
||||
res = TaskEventsResult()
|
||||
res.events, res.total_events = self._get_events(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_timestamp=from_timestamp,
|
||||
)
|
||||
return res
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
from_timestamp: Optional[int],
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous timestamp either in the
|
||||
direction of earlier events (navigate_earlier=True) or in the direction of later events.
|
||||
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
|
||||
For the last timestamp all the events are brought (even if the resulting size
|
||||
exceeds batch_size) so that this timestamp events will not be lost between the calls.
|
||||
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
|
||||
"""
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": {"term": {"task": task_id}},
|
||||
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if from_timestamp:
|
||||
es_req["search_after"] = [from_timestamp]
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_result = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"timestamp": events[-1]["timestamp"]}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
last_second_hits = es_result["hits"]["hits"]
|
||||
if not last_second_hits or len(last_second_hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
already_present_ids = set(hit["_id"] for hit in hits)
|
||||
last_second_events = [
|
||||
hit["_source"]
|
||||
for hit in last_second_hits
|
||||
if hit["_id"] not in already_present_ids
|
||||
]
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[*events, *last_second_events],
|
||||
hits_total,
|
||||
)
|
||||
53
apiserver/bll/event/metric_debug_images_iterator.py
Normal file
53
apiserver/bll/event/metric_debug_images_iterator.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import Sequence, Tuple, Callable
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .event_common import EventType
|
||||
from .metric_events_iterator import MetricEventsIterator, VariantState
|
||||
|
||||
|
||||
class MetricDebugImagesIterator(MetricEventsIterator):
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
super().__init__(redis, es, EventType.metrics_image)
|
||||
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
return [{"exists": {"field": "url"}}]
|
||||
|
||||
def _get_variant_state_aggs(self) -> Tuple[dict, Callable[[dict, VariantState], None]]:
|
||||
aggs = {
|
||||
"urls": {
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "desc"},
|
||||
"size": 1, # we need only one url from the most recent iteration
|
||||
},
|
||||
"aggs": {
|
||||
"max_iter": {"max": {"field": "iter"}},
|
||||
"iters": {
|
||||
"top_hits": {
|
||||
"sort": {"iter": {"order": "desc"}},
|
||||
"size": 2, # need two last iterations so that we can take
|
||||
# the second one as invalid
|
||||
"_source": "iter",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
def fill_variant_state_data(variant_bucket: dict, state: VariantState):
|
||||
"""If the image urls get recycled then fill the last_invalid_iteration field"""
|
||||
top_iter_url = nested_get(variant_bucket, ("urls", "buckets"))[0]
|
||||
iters = nested_get(top_iter_url, ("iters", "hits", "hits"))
|
||||
if len(iters) > 1:
|
||||
state.last_invalid_iteration = nested_get(iters[1], ("_source", "iter"))
|
||||
|
||||
return aggs, fill_variant_state_data
|
||||
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
return event
|
||||
|
||||
def _get_same_variant_events_order(self) -> dict:
|
||||
return {"url": {"order": "desc"}}
|
||||
@@ -1,8 +1,9 @@
|
||||
import abc
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping, Set
|
||||
from typing import Sequence, Tuple, Optional, Mapping, Callable
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
@@ -18,12 +19,14 @@ from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
get_metric_variants_condition,
|
||||
get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.metrics import MetricEventStats
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
class VariantState(Base):
|
||||
@@ -48,25 +51,24 @@ class TaskScrollState(Base):
|
||||
self.last_min_iter = self.last_max_iter = None
|
||||
|
||||
|
||||
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
|
||||
class MetricEventsScrollState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
|
||||
warning: str = StringField()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class DebugImagesResult(object):
|
||||
class MetricEventsResult(object):
|
||||
metric_events: Sequence[tuple] = []
|
||||
next_scroll_id: str = None
|
||||
|
||||
|
||||
class DebugImagesIterator:
|
||||
EVENT_TYPE = EventType.metrics_image
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
class MetricEventsIterator:
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
|
||||
self.es = es
|
||||
self.event_type = event_type
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugImageEventsScrollState,
|
||||
state_class=MetricEventsScrollState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
@@ -74,19 +76,19 @@ class DebugImagesIterator:
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_metrics: Mapping[str, Set[str]],
|
||||
task_metrics: Mapping[str, dict],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> DebugImagesResult:
|
||||
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
|
||||
return DebugImagesResult()
|
||||
) -> MetricEventsResult:
|
||||
if check_empty_data(self.es, company_id, self.event_type):
|
||||
return MetricEventsResult()
|
||||
|
||||
def init_state(state_: DebugImageEventsScrollState):
|
||||
def init_state(state_: MetricEventsScrollState):
|
||||
state_.tasks = self._init_task_states(company_id, task_metrics)
|
||||
|
||||
def validate_state(state_: DebugImageEventsScrollState):
|
||||
def validate_state(state_: MetricEventsScrollState):
|
||||
"""
|
||||
Validate that the metrics stored in the state are the same
|
||||
as requested in the current call.
|
||||
@@ -98,7 +100,13 @@ class DebugImagesIterator:
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
) as state:
|
||||
res = DebugImagesResult(next_scroll_id=state.id)
|
||||
res = MetricEventsResult(next_scroll_id=state.id)
|
||||
specific_variants_requested = any(
|
||||
variants
|
||||
for t, metrics in task_metrics.items()
|
||||
if metrics
|
||||
for m, variants in metrics.items()
|
||||
)
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
res.metric_events = list(
|
||||
pool.map(
|
||||
@@ -107,6 +115,7 @@ class DebugImagesIterator:
|
||||
company_id=company_id,
|
||||
iter_count=iter_count,
|
||||
navigate_earlier=navigate_earlier,
|
||||
specific_variants_requested=specific_variants_requested,
|
||||
),
|
||||
state.tasks,
|
||||
)
|
||||
@@ -117,11 +126,11 @@ class DebugImagesIterator:
|
||||
def _reinit_outdated_task_states(
|
||||
self,
|
||||
company_id,
|
||||
state: DebugImageEventsScrollState,
|
||||
task_metrics: Mapping[str, Set[str]],
|
||||
state: MetricEventsScrollState,
|
||||
task_metrics: Mapping[str, dict],
|
||||
):
|
||||
"""
|
||||
Determine the metrics for which new debug image events were added
|
||||
Determine the metrics for which new event_type events were added
|
||||
since their states were initialized and re-init these states
|
||||
"""
|
||||
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
|
||||
@@ -131,7 +140,7 @@ class DebugImagesIterator:
|
||||
def get_last_update_times_for_task_metrics(
|
||||
task: Task,
|
||||
) -> Mapping[str, datetime]:
|
||||
"""For metrics that reported debug image events get mapping of the metric name to the last update times"""
|
||||
"""For metrics that reported event_type events get mapping of the metric name to the last update times"""
|
||||
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
|
||||
if not metric_stats:
|
||||
return {}
|
||||
@@ -139,10 +148,10 @@ class DebugImagesIterator:
|
||||
requested_metrics = task_metrics[task.id]
|
||||
return {
|
||||
stats.metric: stats.event_stats_by_type[
|
||||
self.EVENT_TYPE.value
|
||||
self.event_type.value
|
||||
].last_update
|
||||
for stats in metric_stats.values()
|
||||
if self.EVENT_TYPE.value in stats.event_stats_by_type
|
||||
if self.event_type.value in stats.event_stats_by_type
|
||||
and (not requested_metrics or stats.metric in requested_metrics)
|
||||
}
|
||||
|
||||
@@ -158,11 +167,11 @@ class DebugImagesIterator:
|
||||
task_metrics_to_recalc = {}
|
||||
for task, metrics_times in update_times.items():
|
||||
old_metric_states = task_metric_states[task]
|
||||
metrics_to_recalc = set(
|
||||
m
|
||||
metrics_to_recalc = {
|
||||
m: task_metrics[task].get(m)
|
||||
for m, t in metrics_times.items()
|
||||
if m not in old_metric_states or old_metric_states[m].timestamp < t
|
||||
)
|
||||
}
|
||||
if metrics_to_recalc:
|
||||
task_metrics_to_recalc[task] = metrics_to_recalc
|
||||
|
||||
@@ -196,7 +205,7 @@ class DebugImagesIterator:
|
||||
]
|
||||
|
||||
def _init_task_states(
|
||||
self, company_id: str, task_metrics: Mapping[str, Set[str]]
|
||||
self, company_id: str, task_metrics: Mapping[str, dict]
|
||||
) -> Sequence[TaskScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
@@ -212,25 +221,45 @@ class DebugImagesIterator:
|
||||
for task, metric_states in zip(task_metrics, task_metric_states)
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_variant_state_aggs(
|
||||
self,
|
||||
) -> Tuple[dict, Callable[[dict, VariantState], None]]:
|
||||
pass
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, Set[str]], company_id: str
|
||||
self, task_metrics: Tuple[str, dict], company_id: str
|
||||
) -> Sequence[MetricState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
for the variants that reported any debug images
|
||||
for the variants that reported any event_type events
|
||||
"""
|
||||
task, metrics = task_metrics
|
||||
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
|
||||
must = [{"term": {"task": task}}, *self._get_extra_conditions()]
|
||||
if metrics:
|
||||
must.append({"terms": {"metric": list(metrics)}})
|
||||
must.append(get_metric_variants_condition(metrics))
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=self.event_type
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
variant_state_aggs, fill_variant_state_data = self._get_variant_state_aggs()
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must}},
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
@@ -238,52 +267,33 @@ class DebugImagesIterator:
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"urls": {
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "desc"},
|
||||
"size": 1, # we need only one url from the most recent iteration
|
||||
},
|
||||
"aggs": {
|
||||
"max_iter": {"max": {"field": "iter"}},
|
||||
"iters": {
|
||||
"top_hits": {
|
||||
"sort": {"iter": {"order": "desc"}},
|
||||
"size": 2, # need two last iterations so that we can take
|
||||
# the second one as invalid
|
||||
"_source": "iter",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
**(
|
||||
{"aggs": variant_state_aggs}
|
||||
if variant_state_aggs
|
||||
else {}
|
||||
),
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
|
||||
)
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
def init_variant_state(variant: dict):
|
||||
"""
|
||||
Return new variant state for the passed variant bucket
|
||||
If the image urls get recycled then fill the last_invalid_iteration field
|
||||
"""
|
||||
state = VariantState(variant=variant["key"])
|
||||
top_iter_url = dpath.get(variant, "urls/buckets")[0]
|
||||
iters = dpath.get(top_iter_url, "iters/hits/hits")
|
||||
if len(iters) > 1:
|
||||
state.last_invalid_iteration = dpath.get(iters[1], "_source/iter")
|
||||
if fill_variant_state_data:
|
||||
fill_variant_state_data(variant, state)
|
||||
|
||||
return state
|
||||
|
||||
return [
|
||||
@@ -298,12 +308,21 @@ class DebugImagesIterator:
|
||||
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_same_variant_events_order(self) -> dict:
|
||||
pass
|
||||
|
||||
def _get_task_metric_events(
|
||||
self,
|
||||
task_state: TaskScrollState,
|
||||
company_id: str,
|
||||
iter_count: int,
|
||||
navigate_earlier: bool,
|
||||
specific_variants_requested: bool,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Return task metric events grouped by iterations
|
||||
@@ -319,7 +338,7 @@ class DebugImagesIterator:
|
||||
must_conditions = [
|
||||
{"term": {"task": task_state.task}},
|
||||
{"terms": {"metric": [m.metric for m in task_state.metrics]}},
|
||||
{"exists": {"field": "url"}},
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
|
||||
range_condition = None
|
||||
@@ -330,6 +349,8 @@ class DebugImagesIterator:
|
||||
if range_condition:
|
||||
must_conditions.append({"range": {"iter": range_condition}})
|
||||
|
||||
metrics_count = len(task_state.metrics)
|
||||
max_variants = int(EventSettings.max_es_buckets / (metrics_count * iter_count))
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
@@ -344,20 +365,20 @@ class DebugImagesIterator:
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"size": metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {
|
||||
"sort": {"url": {"order": "desc"}}
|
||||
"sort": self._get_same_variant_events_order()
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -368,9 +389,9 @@ class DebugImagesIterator:
|
||||
}
|
||||
},
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
|
||||
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return task_state.task, []
|
||||
@@ -380,18 +401,26 @@ class DebugImagesIterator:
|
||||
for m in task_state.metrics
|
||||
for v in m.variants
|
||||
}
|
||||
allow_uninitialized = (
|
||||
False
|
||||
if specific_variants_requested
|
||||
else config.get(
|
||||
"services.events.events_retrieval.debug_images.allow_uninitialized_variants",
|
||||
False,
|
||||
)
|
||||
)
|
||||
|
||||
def is_valid_event(event: dict) -> bool:
|
||||
key = event.get("metric"), event.get("variant")
|
||||
if key not in invalid_iterations:
|
||||
return False
|
||||
return allow_uninitialized
|
||||
|
||||
max_invalid = invalid_iterations[key]
|
||||
return max_invalid is None or event.get("iter") > max_invalid
|
||||
|
||||
def get_iteration_events(it_: dict) -> Sequence:
|
||||
return [
|
||||
ev["_source"]
|
||||
self._process_event(ev["_source"])
|
||||
for m in dpath.get(it_, "metrics/buckets")
|
||||
for v in dpath.get(m, "variants/buckets")
|
||||
for ev in dpath.get(v, "events/hits/hits")
|
||||
25
apiserver/bll/event/metric_plots_iterator.py
Normal file
25
apiserver/bll/event/metric_plots_iterator.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Sequence
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from .event_common import EventType, uncompress_plot
|
||||
from .metric_events_iterator import MetricEventsIterator
|
||||
|
||||
|
||||
class MetricPlotsIterator(MetricEventsIterator):
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
super().__init__(redis, es, EventType.metrics_plot)
|
||||
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
return []
|
||||
|
||||
def _get_variant_state_aggs(self):
|
||||
return None, None
|
||||
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
uncompress_plot(event)
|
||||
return event
|
||||
|
||||
def _get_same_variant_events_order(self) -> dict:
|
||||
return {"timestamp": {"order": "desc"}}
|
||||
@@ -4,6 +4,8 @@ Module for polymorphism over different types of X axes in scalar aggregations
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import auto
|
||||
|
||||
from typing import Any
|
||||
|
||||
from apiserver.utilities import extract_properties_to_lists
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
from apiserver.config_repo import config
|
||||
@@ -96,6 +98,10 @@ class ScalarKey(ABC):
|
||||
"""
|
||||
return int(iter_data[self.bucket_key_key]), iter_data["avg_val"]["value"]
|
||||
|
||||
def cast_value(self, value: Any) -> Any:
|
||||
"""Cast value to appropriate type"""
|
||||
return value
|
||||
|
||||
|
||||
class TimestampKey(ScalarKey):
|
||||
"""
|
||||
@@ -117,6 +123,9 @@ class TimestampKey(ScalarKey):
|
||||
}
|
||||
}
|
||||
|
||||
def cast_value(self, value: Any) -> int:
|
||||
return int(value)
|
||||
|
||||
|
||||
class IterKey(ScalarKey):
|
||||
"""
|
||||
@@ -134,6 +143,9 @@ class IterKey(ScalarKey):
|
||||
}
|
||||
}
|
||||
|
||||
def cast_value(self, value: Any) -> int:
|
||||
return int(value)
|
||||
|
||||
|
||||
class ISOTimeKey(ScalarKey):
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Callable, Tuple
|
||||
from typing import Callable, Tuple, Sequence, Dict, Optional
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.models import ModelTaskPublishResponse
|
||||
@@ -7,6 +9,7 @@ from apiserver.bll.task.utils import deleted_prefix
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from .metadata import Metadata
|
||||
|
||||
|
||||
class ModelBLL:
|
||||
@@ -23,6 +26,33 @@ class ModelBLL:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def assert_exists(
|
||||
company_id,
|
||||
model_ids,
|
||||
only=None,
|
||||
allow_public=False,
|
||||
return_models=True,
|
||||
) -> Optional[Sequence[Model]]:
|
||||
model_ids = [model_ids] if isinstance(model_ids, str) else model_ids
|
||||
ids = set(model_ids)
|
||||
query = Q(id__in=ids)
|
||||
|
||||
q = Model.get_many(
|
||||
company=company_id,
|
||||
query=query,
|
||||
allow_public=allow_public,
|
||||
return_dicts=False,
|
||||
)
|
||||
if only:
|
||||
q = q.only(*only)
|
||||
|
||||
if q.count() != len(ids):
|
||||
raise errors.bad_request.InvalidModelId(ids=model_ids)
|
||||
|
||||
if return_models:
|
||||
return list(q)
|
||||
|
||||
@classmethod
|
||||
def publish_model(
|
||||
cls,
|
||||
@@ -127,3 +157,33 @@ class ModelBLL:
|
||||
)
|
||||
|
||||
return unarchived
|
||||
|
||||
@classmethod
|
||||
def get_model_stats(
|
||||
cls, company: str, model_ids: Sequence[str],
|
||||
) -> Dict[str, dict]:
|
||||
if not model_ids:
|
||||
return {}
|
||||
|
||||
result = Model.aggregate(
|
||||
[
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"_id": {"$in": model_ids},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$addFields": {
|
||||
"labels_count": {"$size": {"$objectToArray": "$labels"}}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {"labels_count": 1},
|
||||
},
|
||||
]
|
||||
)
|
||||
return {
|
||||
r.pop("_id"): r
|
||||
for r in result
|
||||
}
|
||||
|
||||
108
apiserver/bll/model/metadata.py
Normal file
108
apiserver/bll/model/metadata.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from typing import Sequence, Union, Mapping
|
||||
|
||||
from mongoengine import Document
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.metadata import MetadataItem
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class Metadata:
|
||||
@staticmethod
|
||||
def metadata_from_api(
|
||||
api_data: Union[Mapping[str, MetadataItem], Sequence[MetadataItem]]
|
||||
) -> dict:
|
||||
if not api_data:
|
||||
return {}
|
||||
|
||||
if isinstance(api_data, dict):
|
||||
return {
|
||||
ParameterKeyEscaper.escape(k): v.to_struct()
|
||||
for k, v in api_data.items()
|
||||
}
|
||||
|
||||
return {
|
||||
ParameterKeyEscaper.escape(item.key): item.to_struct() for item in api_data
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_metadata(
|
||||
cls,
|
||||
obj: Document,
|
||||
items: Sequence[MetadataItem],
|
||||
replace_metadata: bool,
|
||||
**more_updates,
|
||||
) -> int:
|
||||
update_cmds = dict()
|
||||
metadata = cls.metadata_from_api(items)
|
||||
if replace_metadata:
|
||||
update_cmds["set__metadata"] = metadata
|
||||
else:
|
||||
for key, value in metadata.items():
|
||||
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
|
||||
|
||||
return obj.update(**update_cmds, **more_updates)
|
||||
|
||||
@classmethod
|
||||
def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
|
||||
return obj.update(
|
||||
**{
|
||||
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
|
||||
for key in set(keys)
|
||||
},
|
||||
**more_updates,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_path(path: str):
|
||||
"""
|
||||
Frontend does a partial escaping on the path so the all '.' in key names are escaped
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
raise errors.bad_request.ValidationError("invalid field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def escape_paths(cls, paths: Sequence[str]) -> Sequence[str]:
|
||||
for prefix in (
|
||||
"metadata.",
|
||||
"-metadata.",
|
||||
):
|
||||
paths = [
|
||||
cls._process_path(path) if path.startswith(prefix) else path
|
||||
for path in paths
|
||||
]
|
||||
return paths
|
||||
|
||||
@classmethod
|
||||
def escape_query_parameters(cls, call: APICall) -> dict:
|
||||
if not call.data:
|
||||
return call.data
|
||||
|
||||
keys = list(call.data)
|
||||
call_data = {
|
||||
safe_key: call.data[key]
|
||||
for key, safe_key in zip(keys, Metadata.escape_paths(keys))
|
||||
}
|
||||
|
||||
projection = GetMixin.get_projection(call_data)
|
||||
if projection:
|
||||
GetMixin.set_projection(call_data, Metadata.escape_paths(projection))
|
||||
|
||||
ordering = GetMixin.get_ordering(call_data)
|
||||
if ordering:
|
||||
GetMixin.set_ordering(call_data, Metadata.escape_paths(ordering))
|
||||
|
||||
return call_data
|
||||
@@ -6,6 +6,7 @@ from redis import Redis
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.bll.project import project_ids_with_children
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
@@ -42,6 +43,8 @@ class _TagsCache:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
if project:
|
||||
query &= Q(project__in=project_ids_with_children([project]))
|
||||
else:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
|
||||
|
||||
return self.db_cls.objects(query).distinct(field)
|
||||
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from .project_bll import ProjectBLL
|
||||
from .project_queries import ProjectQueries
|
||||
from .sub_projects import _ids_with_children as project_ids_with_children
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from functools import reduce
|
||||
from itertools import groupby
|
||||
from operator import itemgetter
|
||||
@@ -14,19 +14,21 @@ from typing import (
|
||||
TypeVar,
|
||||
Callable,
|
||||
Mapping,
|
||||
Any,
|
||||
)
|
||||
|
||||
from boltons.iterutils import partition
|
||||
from mongoengine import Q, Document
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility, AttributedDocument
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
|
||||
from apiserver.database.utils import get_options, get_company_or_none_constraint
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .sub_projects import (
|
||||
_reposition_project_with_children,
|
||||
@@ -54,46 +56,53 @@ class ProjectBLL:
|
||||
Remove the source project
|
||||
Return the amounts of moved entities and subprojects + set of all the affected project ids
|
||||
"""
|
||||
with TimingContext("mongo", "move_project"):
|
||||
if source_id == destination_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
parent=source_id
|
||||
)
|
||||
source = Project.get(company, source_id)
|
||||
if source_id == destination_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
source=source_id
|
||||
)
|
||||
source = Project.get(company, source_id)
|
||||
if destination_id:
|
||||
destination = Project.get(company, destination_id)
|
||||
if source_id in destination.path:
|
||||
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
|
||||
source=source_id, destination=destination_id
|
||||
)
|
||||
else:
|
||||
destination = None
|
||||
|
||||
children = _get_sub_projects(
|
||||
[source.id], _only=("id", "name", "parent", "path")
|
||||
)[source.id]
|
||||
children = _get_sub_projects(
|
||||
[source.id], _only=("id", "name", "parent", "path")
|
||||
)[source.id]
|
||||
if destination:
|
||||
cls.validate_projects_depth(
|
||||
projects=children,
|
||||
old_parent_depth=len(source.path) + 1,
|
||||
new_parent_depth=len(destination.path) + 1,
|
||||
)
|
||||
|
||||
moved_entities = 0
|
||||
for entity_type in (Task, Model):
|
||||
moved_entities += entity_type.objects(
|
||||
company=company,
|
||||
project=source_id,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).update(upsert=False, project=destination_id)
|
||||
moved_entities = 0
|
||||
for entity_type in (Task, Model):
|
||||
moved_entities += entity_type.objects(
|
||||
company=company,
|
||||
project=source_id,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).update(upsert=False, project=destination_id)
|
||||
|
||||
moved_sub_projects = 0
|
||||
for child in Project.objects(company=company, parent=source_id):
|
||||
_reposition_project_with_children(
|
||||
project=child,
|
||||
children=[c for c in children if c.parent == child.id],
|
||||
parent=destination,
|
||||
)
|
||||
moved_sub_projects += 1
|
||||
moved_sub_projects = 0
|
||||
for child in Project.objects(company=company, parent=source_id):
|
||||
_reposition_project_with_children(
|
||||
project=child,
|
||||
children=[c for c in children if c.parent == child.id],
|
||||
parent=destination,
|
||||
)
|
||||
moved_sub_projects += 1
|
||||
|
||||
affected = {source.id, *(source.path or [])}
|
||||
source.delete()
|
||||
affected = {source.id, *(source.path or [])}
|
||||
source.delete()
|
||||
|
||||
if destination:
|
||||
destination.update(last_update=datetime.utcnow())
|
||||
affected.update({destination.id, *(destination.path or [])})
|
||||
if destination:
|
||||
destination.update(last_update=datetime.utcnow())
|
||||
affected.update({destination.id, *(destination.path or [])})
|
||||
|
||||
return moved_entities, moved_sub_projects, affected
|
||||
|
||||
@@ -116,72 +125,76 @@ class ProjectBLL:
|
||||
it should be writable. The source location should be writable too.
|
||||
Return the number of moved projects + set of all the affected project ids
|
||||
"""
|
||||
with TimingContext("mongo", "move_project"):
|
||||
project = Project.get(company, project_id)
|
||||
old_parent_id = project.parent
|
||||
old_parent = (
|
||||
Project.get_for_writing(company=project.company, id=old_parent_id)
|
||||
if old_parent_id
|
||||
else None
|
||||
project = Project.get(company, project_id)
|
||||
old_parent_id = project.parent
|
||||
old_parent = (
|
||||
Project.get_for_writing(company=project.company, id=old_parent_id)
|
||||
if old_parent_id
|
||||
else None
|
||||
)
|
||||
|
||||
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
|
||||
project.id
|
||||
]
|
||||
cls.validate_projects_depth(
|
||||
projects=[project, *children],
|
||||
old_parent_depth=len(project.path),
|
||||
new_parent_depth=_get_project_depth(new_location),
|
||||
)
|
||||
|
||||
new_parent = _ensure_project(company=company, user=user, name=new_location)
|
||||
new_parent_id = new_parent.id if new_parent else None
|
||||
if old_parent_id == new_parent_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
location=new_parent.name if new_parent else ""
|
||||
)
|
||||
|
||||
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
|
||||
project.id
|
||||
]
|
||||
cls.validate_projects_depth(
|
||||
projects=[project, *children],
|
||||
old_parent_depth=len(project.path),
|
||||
new_parent_depth=_get_project_depth(new_location),
|
||||
if new_parent and (
|
||||
project_id == new_parent.id or project_id in new_parent.path
|
||||
):
|
||||
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
|
||||
project=project_id, parent=new_parent.id
|
||||
)
|
||||
moved = _reposition_project_with_children(
|
||||
project, children=children, parent=new_parent
|
||||
)
|
||||
|
||||
new_parent = _ensure_project(company=company, user=user, name=new_location)
|
||||
new_parent_id = new_parent.id if new_parent else None
|
||||
if old_parent_id == new_parent_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
location=new_parent.name if new_parent else ""
|
||||
)
|
||||
now = datetime.utcnow()
|
||||
affected = set()
|
||||
for p in filter(None, (old_parent, new_parent)):
|
||||
p.update(last_update=now)
|
||||
affected.update({p.id, *(p.path or [])})
|
||||
|
||||
moved = _reposition_project_with_children(
|
||||
project, children=children, parent=new_parent
|
||||
)
|
||||
|
||||
now = datetime.utcnow()
|
||||
affected = set()
|
||||
for p in filter(None, (old_parent, new_parent)):
|
||||
p.update(last_update=now)
|
||||
affected.update({p.id, *(p.path or [])})
|
||||
|
||||
return moved, affected
|
||||
return moved, affected
|
||||
|
||||
@classmethod
|
||||
def update(cls, company: str, project_id: str, **fields):
|
||||
with TimingContext("mongo", "projects_update"):
|
||||
project = Project.get_for_writing(company=company, id=project_id)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
project = Project.get_for_writing(company=company, id=project_id)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
new_name = fields.pop("name", None)
|
||||
if new_name:
|
||||
new_name, new_location = _validate_project_name(new_name)
|
||||
old_name, old_location = _validate_project_name(project.name)
|
||||
if new_location != old_location:
|
||||
raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
|
||||
fields["name"] = new_name
|
||||
new_name = fields.pop("name", None)
|
||||
if new_name:
|
||||
new_name, new_location = _validate_project_name(new_name)
|
||||
old_name, old_location = _validate_project_name(project.name)
|
||||
if new_location != old_location:
|
||||
raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
|
||||
fields["name"] = new_name
|
||||
fields["basename"] = new_name.split("/")[-1]
|
||||
|
||||
fields["last_update"] = datetime.utcnow()
|
||||
updated = project.update(upsert=False, **fields)
|
||||
fields["last_update"] = datetime.utcnow()
|
||||
updated = project.update(upsert=False, **fields)
|
||||
|
||||
if new_name:
|
||||
old_name = project.name
|
||||
project.name = new_name
|
||||
children = _get_sub_projects(
|
||||
[project.id], _only=("id", "name", "path")
|
||||
)[project.id]
|
||||
_update_subproject_names(
|
||||
project=project, children=children, old_name=old_name
|
||||
)
|
||||
if new_name:
|
||||
old_name = project.name
|
||||
project.name = new_name
|
||||
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
|
||||
project.id
|
||||
]
|
||||
_update_subproject_names(
|
||||
project=project, children=children, old_name=old_name
|
||||
)
|
||||
|
||||
return updated
|
||||
return updated
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -193,6 +206,7 @@ class ProjectBLL:
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
default_output_destination: str = None,
|
||||
parent_creation_params: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a new project.
|
||||
@@ -208,6 +222,7 @@ class ProjectBLL:
|
||||
user=user,
|
||||
company=company,
|
||||
name=name,
|
||||
basename=name.split("/")[-1],
|
||||
description=description,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
@@ -215,7 +230,12 @@ class ProjectBLL:
|
||||
created=now,
|
||||
last_update=now,
|
||||
)
|
||||
parent = _ensure_project(company=company, user=user, name=location)
|
||||
parent = _ensure_project(
|
||||
company=company,
|
||||
user=user,
|
||||
name=location,
|
||||
creation_params=parent_creation_params,
|
||||
)
|
||||
_save_under_parent(project=project, parent=parent)
|
||||
if parent:
|
||||
parent.update(last_update=now)
|
||||
@@ -233,13 +253,14 @@ class ProjectBLL:
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
default_output_destination: str = None,
|
||||
parent_creation_params: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Find a project named `project_name` or create a new one.
|
||||
Returns project ID
|
||||
"""
|
||||
if not project_id and not project_name:
|
||||
raise ValueError("project id or name required")
|
||||
raise errors.bad_request.ValidationError("project id or name required")
|
||||
|
||||
if project_id:
|
||||
project = Project.objects(company=company, id=project_id).only("id").first()
|
||||
@@ -260,6 +281,7 @@ class ProjectBLL:
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
default_output_destination=default_output_destination,
|
||||
parent_creation_params=parent_creation_params,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -275,26 +297,26 @@ class ProjectBLL:
|
||||
"""
|
||||
Move a batch of entities to `project` or a project named `project_name` (create if does not exist)
|
||||
"""
|
||||
with TimingContext("mongo", "move_under_project"):
|
||||
project = cls.find_or_create(
|
||||
user=user,
|
||||
company=company,
|
||||
project_id=project,
|
||||
project_name=project_name,
|
||||
description="",
|
||||
)
|
||||
extra = (
|
||||
{"set__last_change": datetime.utcnow()}
|
||||
if hasattr(entity_cls, "last_change")
|
||||
else {}
|
||||
)
|
||||
entity_cls.objects(company=company, id__in=ids).update(
|
||||
set__project=project, **extra
|
||||
)
|
||||
project = cls.find_or_create(
|
||||
user=user,
|
||||
company=company,
|
||||
project_id=project,
|
||||
project_name=project_name,
|
||||
description="",
|
||||
)
|
||||
extra = (
|
||||
{"set__last_change": datetime.utcnow()}
|
||||
if hasattr(entity_cls, "last_change")
|
||||
else {}
|
||||
)
|
||||
entity_cls.objects(company=company, id__in=ids).update(
|
||||
set__project=project, **extra
|
||||
)
|
||||
|
||||
return project
|
||||
return project
|
||||
|
||||
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
|
||||
visibility_states = [EntityVisibility.archived, EntityVisibility.active]
|
||||
|
||||
@classmethod
|
||||
def make_projects_get_all_pipelines(
|
||||
@@ -302,6 +324,8 @@ class ProjectBLL:
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
specific_state: Optional[EntityVisibility] = None,
|
||||
filter_: Mapping[str, Any] = None,
|
||||
users: Sequence[str] = None,
|
||||
) -> Tuple[Sequence, Sequence]:
|
||||
archived = EntityVisibility.archived.value
|
||||
|
||||
@@ -325,10 +349,12 @@ class ProjectBLL:
|
||||
status_count_pipeline = [
|
||||
# count tasks per project per status
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
"$match": cls.get_match_conditions(
|
||||
company=company_id,
|
||||
project_ids=project_ids,
|
||||
filter_=filter_,
|
||||
users=users,
|
||||
)
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
@@ -356,6 +382,37 @@ class ProjectBLL:
|
||||
},
|
||||
]
|
||||
|
||||
def completed_after_subquery(additional_cond, time_thresh: datetime):
|
||||
return {
|
||||
# the sum of
|
||||
"$sum": {
|
||||
# for each task
|
||||
"$cond": {
|
||||
# if completed after the time_thresh
|
||||
"if": {
|
||||
"$and": [
|
||||
"$completed",
|
||||
{"$gt": ["$completed", time_thresh]},
|
||||
additional_cond,
|
||||
]
|
||||
},
|
||||
"then": 1,
|
||||
"else": 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def max_started_subquery(condition):
|
||||
return {
|
||||
"$max": {
|
||||
"$cond": {
|
||||
"if": condition,
|
||||
"then": "$started",
|
||||
"else": datetime.min,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def runtime_subquery(additional_cond):
|
||||
return {
|
||||
# the sum of
|
||||
@@ -386,25 +443,54 @@ class ProjectBLL:
|
||||
}
|
||||
|
||||
group_step = {"_id": "$project"}
|
||||
|
||||
for state in EntityVisibility:
|
||||
time_thresh = datetime.utcnow() - timedelta(hours=24)
|
||||
for state in cls.visibility_states:
|
||||
if specific_state and state != specific_state:
|
||||
continue
|
||||
if state == EntityVisibility.active:
|
||||
group_step[state.value] = runtime_subquery(
|
||||
{"$not": cls.archived_tasks_cond}
|
||||
cond = (
|
||||
cls.archived_tasks_cond
|
||||
if state == EntityVisibility.archived
|
||||
else {"$not": cls.archived_tasks_cond}
|
||||
)
|
||||
group_step[state.value] = runtime_subquery(cond)
|
||||
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
|
||||
cond, time_thresh=time_thresh
|
||||
)
|
||||
group_step[f"{state.value}_max_task_started"] = max_started_subquery(cond)
|
||||
|
||||
def add_state_to_filter(f: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
if not specific_state:
|
||||
return f
|
||||
|
||||
f = f or {}
|
||||
new_f = {k: v for k, v in f.items() if k != "system_tags"}
|
||||
system_tags = [
|
||||
tag
|
||||
for tag in f.get("system_tags", [])
|
||||
if tag
|
||||
not in (
|
||||
EntityVisibility.archived.value,
|
||||
f"-{EntityVisibility.archived.value}",
|
||||
)
|
||||
elif state == EntityVisibility.archived:
|
||||
group_step[state.value] = runtime_subquery(cls.archived_tasks_cond)
|
||||
]
|
||||
|
||||
if specific_state == EntityVisibility.archived:
|
||||
system_tags.append(EntityVisibility.archived.value)
|
||||
else:
|
||||
system_tags.append(f"-{EntityVisibility.archived.value}")
|
||||
new_f["system_tags"] = system_tags
|
||||
|
||||
return new_f
|
||||
|
||||
runtime_pipeline = [
|
||||
# only count run time for these types of tasks
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"type": {"$in": ["training", "testing", "annotation"]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
"$match": cls.get_match_conditions(
|
||||
company=company_id,
|
||||
project_ids=project_ids,
|
||||
filter_=add_state_to_filter(filter_),
|
||||
users=users,
|
||||
)
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
@@ -439,17 +525,65 @@ class ProjectBLL:
|
||||
aggregated[pid] = reduce(func, relevant_data)
|
||||
return aggregated
|
||||
|
||||
@classmethod
|
||||
def get_dataset_stats(
|
||||
cls, company: str, project_ids: Sequence[str], users: Sequence[str] = None,
|
||||
) -> Dict[str, dict]:
|
||||
if not project_ids:
|
||||
return {}
|
||||
|
||||
task_runtime_pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls.get_match_conditions(
|
||||
company=company,
|
||||
project_ids=project_ids,
|
||||
users=users,
|
||||
filter_={
|
||||
"system_tags": [f"-{EntityVisibility.archived.value}"]
|
||||
},
|
||||
),
|
||||
"runtime": {"$exists": True, "$gt": {}},
|
||||
}
|
||||
},
|
||||
{"$project": {"project": 1, "runtime": 1, "last_update": 1}},
|
||||
{"$sort": {"project": 1, "last_update": 1}},
|
||||
{"$group": {"_id": "$project", "runtime": {"$last": "$runtime"}}},
|
||||
]
|
||||
|
||||
return {
|
||||
r["_id"]: {
|
||||
"file_count": r["runtime"].get("ds_file_count", 0),
|
||||
"total_size": r["runtime"].get("ds_total_size", 0),
|
||||
}
|
||||
for r in Task.aggregate(task_runtime_pipeline)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_project_stats(
|
||||
cls,
|
||||
company: str,
|
||||
project_ids: Sequence[str],
|
||||
specific_state: Optional[EntityVisibility] = None,
|
||||
include_children: bool = True,
|
||||
search_hidden: bool = False,
|
||||
filter_: Mapping[str, Any] = None,
|
||||
users: Sequence[str] = None,
|
||||
user_active_project_ids: Sequence[str] = None,
|
||||
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
|
||||
if not project_ids:
|
||||
return {}, {}
|
||||
|
||||
child_projects = _get_sub_projects(project_ids, _only=("id", "name"))
|
||||
child_projects = (
|
||||
_get_sub_projects(
|
||||
project_ids,
|
||||
_only=("id", "name"),
|
||||
search_hidden=search_hidden,
|
||||
allowed_ids=user_active_project_ids,
|
||||
)
|
||||
if include_children
|
||||
else {}
|
||||
)
|
||||
project_ids_with_children = set(project_ids) | {
|
||||
c.id for c in itertools.chain.from_iterable(child_projects.values())
|
||||
}
|
||||
@@ -457,6 +591,8 @@ class ProjectBLL:
|
||||
company,
|
||||
project_ids=list(project_ids_with_children),
|
||||
specific_state=specific_state,
|
||||
filter_=filter_,
|
||||
users=users,
|
||||
)
|
||||
|
||||
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
|
||||
@@ -483,8 +619,8 @@ class ProjectBLL:
|
||||
) -> Dict[str, dict]:
|
||||
return {
|
||||
section: {
|
||||
status: nested_get(a, (section, status), 0)
|
||||
+ nested_get(b, (section, status), 0)
|
||||
status: nested_get(a, (section, status), default=0)
|
||||
+ nested_get(b, (section, status), default=0)
|
||||
for status in set(a.get(section, {})) | set(b.get(section, {}))
|
||||
}
|
||||
for section in set(a) | set(b)
|
||||
@@ -507,6 +643,8 @@ class ProjectBLL:
|
||||
) -> Dict[str, dict]:
|
||||
return {
|
||||
section: a.get(section, 0) + b.get(section, 0)
|
||||
if not section.endswith("max_task_started")
|
||||
else max(a.get(section) or datetime.min, b.get(section) or datetime.min)
|
||||
for section in set(a) | set(b)
|
||||
}
|
||||
|
||||
@@ -518,15 +656,30 @@ class ProjectBLL:
|
||||
)
|
||||
|
||||
def get_status_counts(project_id, section):
|
||||
project_runtime = runtime.get(project_id, {})
|
||||
project_section_statuses = nested_get(
|
||||
status_count, (project_id, section), default=default_counts
|
||||
)
|
||||
|
||||
def get_time_or_none(value):
|
||||
return value if value != datetime.min else None
|
||||
|
||||
return {
|
||||
"total_runtime": nested_get(runtime, (project_id, section), 0),
|
||||
"status_count": nested_get(
|
||||
status_count, (project_id, section), default_counts
|
||||
"status_count": project_section_statuses,
|
||||
"total_tasks": sum(project_section_statuses.values()),
|
||||
"total_runtime": project_runtime.get(section, 0),
|
||||
"completed_tasks_24h": project_runtime.get(
|
||||
f"{section}_recently_completed", 0
|
||||
),
|
||||
"last_task_run": get_time_or_none(
|
||||
project_runtime.get(f"{section}_max_task_started", datetime.min)
|
||||
),
|
||||
}
|
||||
|
||||
report_for_states = [
|
||||
s for s in EntityVisibility if not specific_state or specific_state == s
|
||||
s
|
||||
for s in cls.visibility_states
|
||||
if not specific_state or specific_state == s
|
||||
]
|
||||
|
||||
stats = {
|
||||
@@ -554,26 +707,48 @@ class ProjectBLL:
|
||||
user_ids: Optional[Sequence[str]] = None,
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models/dataviews in the given projects
|
||||
Get the set of user ids that created tasks/models in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
"""
|
||||
with TimingContext("mongo", "active_users_in_projects"):
|
||||
query = Q(company=company)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
query = Q(company=company)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
|
||||
projects_query = query
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
projects_query &= Q(id__in=project_ids)
|
||||
projects_query = query
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
projects_query &= Q(id__in=project_ids)
|
||||
|
||||
res = set(Project.objects(projects_query).distinct(field="user"))
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
res = set(Project.objects(projects_query).distinct(field="user"))
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
|
||||
return res
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_project_tags(
|
||||
cls,
|
||||
company_id: str,
|
||||
include_system: bool,
|
||||
projects: Sequence[str] = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
query = Q(company=company_id)
|
||||
if filter_:
|
||||
for name, vals in filter_.items():
|
||||
if vals:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
|
||||
if projects:
|
||||
query &= Q(id__in=_ids_with_children(projects))
|
||||
|
||||
tags = Project.objects(query).distinct("tags")
|
||||
system_tags = (
|
||||
Project.objects(query).distinct("system_tags") if include_system else []
|
||||
)
|
||||
return tags, system_tags
|
||||
|
||||
@classmethod
|
||||
def get_projects_with_active_user(
|
||||
@@ -582,7 +757,7 @@ class ProjectBLL:
|
||||
users: Sequence[str],
|
||||
project_ids: Optional[Sequence[str]] = None,
|
||||
allow_public: bool = True,
|
||||
) -> Sequence[str]:
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
"""
|
||||
Get the projects ids where user created any tasks including all the parents of these projects
|
||||
If project ids are specified then filter the results by these project ids
|
||||
@@ -606,13 +781,16 @@ class ProjectBLL:
|
||||
|
||||
res = list(res)
|
||||
if not res:
|
||||
return res
|
||||
return res, res
|
||||
|
||||
ids_with_parents = _ids_with_parents(res)
|
||||
if project_ids:
|
||||
return [pid for pid in ids_with_parents if pid in project_ids]
|
||||
user_active_project_ids = _ids_with_parents(res)
|
||||
filtered_ids = (
|
||||
list(set(user_active_project_ids) & set(project_ids))
|
||||
if project_ids
|
||||
else list(user_active_project_ids)
|
||||
)
|
||||
|
||||
return ids_with_parents
|
||||
return filtered_ids, user_active_project_ids
|
||||
|
||||
@classmethod
|
||||
def get_task_parents(
|
||||
@@ -627,10 +805,14 @@ class ProjectBLL:
|
||||
If projects is None or empty then get parents for all the company tasks
|
||||
"""
|
||||
query = Q(company=company_id)
|
||||
|
||||
if projects:
|
||||
if include_subprojects:
|
||||
projects = _ids_with_children(projects)
|
||||
query &= Q(project__in=projects)
|
||||
else:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
|
||||
|
||||
if state == EntityVisibility.archived:
|
||||
query &= Q(system_tags__in=[EntityVisibility.archived.value])
|
||||
elif state == EntityVisibility.active:
|
||||
@@ -658,6 +840,8 @@ class ProjectBLL:
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
else:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
|
||||
res = Task.objects(query).distinct(field="type")
|
||||
return set(res).intersection(external_task_types)
|
||||
|
||||
@@ -673,11 +857,51 @@ class ProjectBLL:
|
||||
query &= Q(project__in=project_ids)
|
||||
return Model.objects(query).distinct(field="framework")
|
||||
|
||||
@staticmethod
|
||||
def get_match_conditions(
|
||||
company: str,
|
||||
project_ids: Sequence[str],
|
||||
filter_: Mapping[str, Any],
|
||||
users: Sequence[str],
|
||||
):
|
||||
conditions = {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
if users:
|
||||
conditions["user"] = {"$in": users}
|
||||
|
||||
if not filter_:
|
||||
return conditions
|
||||
|
||||
for field, field_filter in filter_.items():
|
||||
if not (
|
||||
field_filter
|
||||
and isinstance(field_filter, list)
|
||||
and all(isinstance(t, str) for t in field_filter)
|
||||
):
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"List of strings expected for the field: {field}"
|
||||
)
|
||||
exclude, include = partition(field_filter, lambda x: x.startswith("-"))
|
||||
conditions[field] = {
|
||||
**({"$in": include} if include else {}),
|
||||
**({"$nin": [e[1:] for e in exclude]} if exclude else {}),
|
||||
}
|
||||
|
||||
return conditions
|
||||
|
||||
@classmethod
|
||||
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
def calc_own_contents(
|
||||
cls,
|
||||
company: str,
|
||||
project_ids: Sequence[str],
|
||||
filter_: Mapping[str, Any] = None,
|
||||
users: Sequence[str] = None,
|
||||
) -> Dict[str, dict]:
|
||||
"""
|
||||
Returns the amount of task/dataviews/models per requested project
|
||||
Use separate aggregation calls on Task/Dataview/Model instead of lookup
|
||||
Returns the amount of task/models per requested project
|
||||
Use separate aggregation calls on Task/Model instead of lookup
|
||||
aggregation on projects in order not to hit memory limits on large tasks
|
||||
"""
|
||||
if not project_ids:
|
||||
@@ -685,35 +909,23 @@ class ProjectBLL:
|
||||
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
"$match": cls.get_match_conditions(
|
||||
company=company,
|
||||
project_ids=project_ids,
|
||||
filter_=filter_,
|
||||
users=users,
|
||||
)
|
||||
},
|
||||
{
|
||||
"$project": {"project": 1}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$project",
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
}
|
||||
{"$project": {"project": 1}},
|
||||
{"$group": {"_id": "$project", "count": {"$sum": 1}}},
|
||||
]
|
||||
|
||||
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
|
||||
return {
|
||||
data["_id"]: data["count"]
|
||||
for data in cls_.aggregate(pipeline)
|
||||
}
|
||||
return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
|
||||
|
||||
with TimingContext("mongo", "get_security_groups"):
|
||||
tasks = get_agrregate_res(Task)
|
||||
models = get_agrregate_res(Model)
|
||||
return {
|
||||
pid: {
|
||||
"own_tasks": tasks.get(pid, 0),
|
||||
"own_models": models.get(pid, 0),
|
||||
}
|
||||
for pid in project_ids
|
||||
}
|
||||
tasks = get_agrregate_res(Task)
|
||||
models = get_agrregate_res(Model)
|
||||
return {
|
||||
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
|
||||
for pid in project_ids
|
||||
}
|
||||
|
||||
@@ -8,17 +8,18 @@ from apiserver.bll.task.task_cleanup import (
|
||||
collect_debug_image_urls,
|
||||
collect_plot_image_urls,
|
||||
TaskUrls,
|
||||
_schedule_for_delete,
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, ArtifactModes
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType
|
||||
from .sub_projects import _ids_with_children
|
||||
|
||||
log = config.logger(__file__)
|
||||
event_bll = EventBLL()
|
||||
async_events_delete = config.get("services.tasks.async_events_delete", False)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
@@ -30,40 +31,94 @@ class DeleteProjectResult:
|
||||
urls: TaskUrls = None
|
||||
|
||||
|
||||
def validate_project_delete(company: str, project_id: str):
|
||||
project = Project.get_for_writing(
|
||||
company=company, id=project_id, _only=("id", "path", "system_tags")
|
||||
)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
is_pipeline = "pipeline" in (project.system_tags or [])
|
||||
project_ids = _ids_with_children([project_id])
|
||||
ret = {}
|
||||
for cls in (Task, Model):
|
||||
ret[f"{cls.__name__.lower()}s"] = cls.objects(project__in=project_ids).count()
|
||||
for cls in (Task, Model):
|
||||
query = dict(
|
||||
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
|
||||
)
|
||||
name = f"non_archived_{cls.__name__.lower()}s"
|
||||
if not is_pipeline:
|
||||
ret[name] = cls.objects(**query).count()
|
||||
else:
|
||||
ret[name] = (
|
||||
cls.objects(**query, type=TaskType.controller).count()
|
||||
if cls == Task
|
||||
else 0
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def delete_project(
|
||||
company: str, project_id: str, force: bool, delete_contents: bool
|
||||
company: str,
|
||||
user: str,
|
||||
project_id: str,
|
||||
force: bool,
|
||||
delete_contents: bool,
|
||||
delete_external_artifacts=True,
|
||||
) -> Tuple[DeleteProjectResult, Set[str]]:
|
||||
project = Project.get_for_writing(
|
||||
company=company, id=project_id, _only=("id", "path")
|
||||
company=company, id=project_id, _only=("id", "path", "system_tags")
|
||||
)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
delete_external_artifacts = delete_external_artifacts and config.get(
|
||||
"services.async_urls_delete.enabled", False
|
||||
)
|
||||
is_pipeline = "pipeline" in (project.system_tags or [])
|
||||
project_ids = _ids_with_children([project_id])
|
||||
if not force:
|
||||
for cls, error in (
|
||||
(Task, errors.bad_request.ProjectHasTasks),
|
||||
(Model, errors.bad_request.ProjectHasModels),
|
||||
):
|
||||
non_archived = cls.objects(
|
||||
project__in=project_ids,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).only("id")
|
||||
query = dict(
|
||||
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
|
||||
)
|
||||
if not is_pipeline:
|
||||
for cls, error in (
|
||||
(Task, errors.bad_request.ProjectHasTasks),
|
||||
(Model, errors.bad_request.ProjectHasModels),
|
||||
):
|
||||
non_archived = cls.objects(**query).only("id")
|
||||
if non_archived:
|
||||
raise error("use force=true to delete", id=project_id)
|
||||
else:
|
||||
non_archived = Task.objects(**query, type=TaskType.controller).only("id")
|
||||
if non_archived:
|
||||
raise error("use force=true to delete", id=project_id)
|
||||
raise errors.bad_request.ProjectHasTasks(
|
||||
"please archive all the runs inside the project", id=project_id
|
||||
)
|
||||
|
||||
if not delete_contents:
|
||||
with TimingContext("mongo", "update_children"):
|
||||
for cls in (Model, Task):
|
||||
updated_count = cls.objects(project__in=project_ids).update(
|
||||
project=None
|
||||
)
|
||||
for cls in (Model, Task):
|
||||
updated_count = cls.objects(project__in=project_ids).update(project=None)
|
||||
res = DeleteProjectResult(disassociated_tasks=updated_count)
|
||||
else:
|
||||
deleted_models, model_urls = _delete_models(projects=project_ids)
|
||||
deleted_tasks, event_urls, artifact_urls = _delete_tasks(
|
||||
deleted_models, model_event_urls, model_urls = _delete_models(
|
||||
company=company, projects=project_ids
|
||||
)
|
||||
deleted_tasks, task_event_urls, artifact_urls = _delete_tasks(
|
||||
company=company, projects=project_ids
|
||||
)
|
||||
event_urls = task_event_urls | model_event_urls
|
||||
if delete_external_artifacts:
|
||||
scheduled = _schedule_for_delete(
|
||||
task_id=project_id,
|
||||
company=company,
|
||||
user=user,
|
||||
urls=event_urls | model_urls | artifact_urls,
|
||||
can_delete_folders=True,
|
||||
)
|
||||
for urls in (event_urls, model_urls, artifact_urls):
|
||||
urls.difference_update(scheduled)
|
||||
res = DeleteProjectResult(
|
||||
deleted_tasks=deleted_tasks,
|
||||
deleted_models=deleted_models,
|
||||
@@ -92,9 +147,8 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
|
||||
return 0, set(), set()
|
||||
|
||||
task_ids = {t.id for t in tasks}
|
||||
with TimingContext("mongo", "delete_tasks_update_children"):
|
||||
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
|
||||
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
|
||||
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
|
||||
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
|
||||
|
||||
event_urls, artifact_urls = set(), set()
|
||||
for task in tasks:
|
||||
@@ -109,46 +163,58 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
|
||||
}
|
||||
)
|
||||
|
||||
event_bll.delete_multi_task_events(company, list(task_ids))
|
||||
event_bll.delete_multi_task_events(
|
||||
company, list(task_ids), async_delete=async_events_delete
|
||||
)
|
||||
deleted = tasks.delete()
|
||||
return deleted, event_urls, artifact_urls
|
||||
|
||||
|
||||
def _delete_models(projects: Sequence[str]) -> Tuple[int, Set[str]]:
|
||||
def _delete_models(
|
||||
company: str, projects: Sequence[str]
|
||||
) -> Tuple[int, Set[str], Set[str]]:
|
||||
"""
|
||||
Delete project models and update the tasks from other projects
|
||||
that reference them to reference None.
|
||||
"""
|
||||
with TimingContext("mongo", "delete_models"):
|
||||
models = Model.objects(project__in=projects).only("task", "id", "uri")
|
||||
if not models:
|
||||
return 0, set()
|
||||
models = Model.objects(project__in=projects).only("task", "id", "uri")
|
||||
if not models:
|
||||
return 0, set(), set()
|
||||
|
||||
model_ids = list({m.id for m in models})
|
||||
model_ids = list({m.id for m in models})
|
||||
|
||||
Task._get_collection().update_many(
|
||||
filter={
|
||||
"project": {"$nin": projects},
|
||||
"models.input.model": {"$in": model_ids},
|
||||
},
|
||||
update={"$set": {"models.input.$[elem].model": None}},
|
||||
array_filters=[{"elem.model": {"$in": model_ids}}],
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
model_tasks = list({m.task for m in models if m.task})
|
||||
if model_tasks:
|
||||
Task._get_collection().update_many(
|
||||
filter={
|
||||
"_id": {"$in": model_tasks},
|
||||
"project": {"$nin": projects},
|
||||
"models.input.model": {"$in": model_ids},
|
||||
"models.output.model": {"$in": model_ids},
|
||||
},
|
||||
update={"$set": {"models.input.$[elem].model": None}},
|
||||
update={"$set": {"models.output.$[elem].model": None}},
|
||||
array_filters=[{"elem.model": {"$in": model_ids}}],
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
model_tasks = list({m.task for m in models if m.task})
|
||||
if model_tasks:
|
||||
Task._get_collection().update_many(
|
||||
filter={
|
||||
"_id": {"$in": model_tasks},
|
||||
"project": {"$nin": projects},
|
||||
"models.output.model": {"$in": model_ids},
|
||||
},
|
||||
update={"$set": {"models.output.$[elem].model": None}},
|
||||
array_filters=[{"elem.model": {"$in": model_ids}}],
|
||||
upsert=False,
|
||||
)
|
||||
event_urls, model_urls = set(), set()
|
||||
for m in models:
|
||||
event_urls.update(collect_debug_image_urls(company, m.id))
|
||||
event_urls.update(collect_plot_image_urls(company, m.id))
|
||||
if m.uri:
|
||||
model_urls.add(m.uri)
|
||||
|
||||
urls = {m.uri for m in models if m.uri}
|
||||
deleted = models.delete()
|
||||
return deleted, urls
|
||||
event_bll.delete_multi_task_events(
|
||||
company, model_ids, async_delete=async_events_delete
|
||||
)
|
||||
deleted = models.delete()
|
||||
return deleted, event_urls, model_urls
|
||||
|
||||
370
apiserver/bll/project/project_queries.py
Normal file
370
apiserver/bll/project/project_queries.py
Normal file
@@ -0,0 +1,370 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Sequence,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .sub_projects import _ids_with_children
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ProjectQueries:
|
||||
def __init__(self, redis=None):
|
||||
self.redis: StrictRedis = redis or redman.connection("apiserver")
|
||||
|
||||
@staticmethod
|
||||
def _get_project_constraint(
|
||||
project_ids: Sequence[str], include_subprojects: bool
|
||||
) -> dict:
|
||||
"""
|
||||
If passed projects is None means top level projects
|
||||
If passed projects is empty means no project filtering
|
||||
"""
|
||||
if include_subprojects:
|
||||
if not project_ids:
|
||||
return {}
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
|
||||
if project_ids is None:
|
||||
project_ids = [None]
|
||||
if not project_ids:
|
||||
return {}
|
||||
|
||||
return {"project": {"$in": project_ids}}
|
||||
|
||||
@staticmethod
|
||||
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
|
||||
if allow_public:
|
||||
return {"company": {"$in": [None, "", company_id]}}
|
||||
|
||||
return {"company": company_id}
|
||||
|
||||
@classmethod
|
||||
def get_aggregated_project_parameters(
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
}
|
||||
},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
nested_get(r, ("_id", "section"))
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(
|
||||
nested_get(r, ("_id", "name"))
|
||||
),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
ParamValues = Tuple[int, Sequence[str]]
|
||||
|
||||
def _get_cached_param_values(
|
||||
self, key: str, last_update: datetime, allowed_delta_sec=0
|
||||
) -> Optional[ParamValues]:
|
||||
try:
|
||||
cached = self.redis.get(key)
|
||||
if not cached:
|
||||
return
|
||||
|
||||
data = json.loads(cached)
|
||||
cached_last_update = datetime.fromtimestamp(data["last_update"])
|
||||
if (last_update - cached_last_update).total_seconds() <= allowed_delta_sec:
|
||||
return data["total"], data["values"]
|
||||
except Exception as ex:
|
||||
log.error(f"Error retrieving params cached values: {str(ex)}")
|
||||
|
||||
def get_task_hyperparam_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
section: str,
|
||||
name: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> ParamValues:
|
||||
company_constraint = self._get_company_constraint(company_id, allow_public)
|
||||
project_constraint = self._get_project_constraint(
|
||||
project_ids, include_subprojects
|
||||
)
|
||||
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
|
||||
last_updated_task = (
|
||||
Task.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_task:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
|
||||
last_update = last_updated_task.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_param_values(
|
||||
key=redis_key,
|
||||
last_update=last_update,
|
||||
allowed_delta_sec=config.get(
|
||||
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
|
||||
),
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$limit": max_values},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
|
||||
@classmethod
|
||||
def get_unique_metric_variants(
|
||||
cls, company_id, project_ids: Sequence[str], include_subprojects: bool
|
||||
):
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
}
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
{"$unwind": "$metrics"},
|
||||
{
|
||||
"$project": {
|
||||
"metric": "$metrics.k",
|
||||
"variants": {"$objectToArray": "$metrics.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$variants"},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"metric": "$variants.v.metric",
|
||||
"variant": "$variants.v.variant",
|
||||
},
|
||||
"metrics": {
|
||||
"$addToSet": {
|
||||
"metric": "$variants.v.metric",
|
||||
"metric_hash": "$metric",
|
||||
"variant": "$variants.v.variant",
|
||||
"variant_hash": "$variants.k",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
|
||||
]
|
||||
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@classmethod
|
||||
def get_model_metadata_keys(
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
"metadata": {"$exists": True, "$gt": {}},
|
||||
}
|
||||
},
|
||||
{"$project": {"metadata": {"$objectToArray": "$metadata"}}},
|
||||
{"$unwind": "$metadata"},
|
||||
{"$group": {"_id": "$metadata.k"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Model.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
ParameterKeyEscaper.unescape(r.get("_id"))
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
def get_model_metadata_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
key: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> ParamValues:
|
||||
company_constraint = self._get_company_constraint(company_id, allow_public)
|
||||
project_constraint = self._get_project_constraint(
|
||||
project_ids, include_subprojects
|
||||
)
|
||||
key_path = f"metadata.{ParameterKeyEscaper.escape(key)}"
|
||||
last_updated_model = (
|
||||
Model.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_model:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}"
|
||||
last_update = last_updated_model.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_param_values(
|
||||
key=redis_key, last_update=last_update
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
max_values = config.get("services.models.metadata_values.max_count", 100)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$limit": max_values},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Model.aggregate(pipeline, collation=Model._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.models.metadata_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
@@ -4,6 +4,7 @@ from typing import Tuple, Optional, Sequence, Mapping
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.project import Project
|
||||
|
||||
name_separator = "/"
|
||||
@@ -25,7 +26,9 @@ def _validate_project_name(project_name: str) -> Tuple[str, str]:
|
||||
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
|
||||
|
||||
|
||||
def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
|
||||
def _ensure_project(
|
||||
company: str, user: str, name: str, creation_params: dict = None
|
||||
) -> Optional[Project]:
|
||||
"""
|
||||
Makes sure that the project with the given name exists
|
||||
If needed auto-create the project and all the missing projects in the path to it
|
||||
@@ -48,9 +51,10 @@ def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
|
||||
created=now,
|
||||
last_update=now,
|
||||
name=name,
|
||||
description="",
|
||||
basename=name.split("/")[-1],
|
||||
**(creation_params or dict(description="")),
|
||||
)
|
||||
parent = _ensure_project(company, user, location)
|
||||
parent = _ensure_project(company, user, location, creation_params=creation_params)
|
||||
_save_under_parent(project=project, parent=parent)
|
||||
if parent:
|
||||
parent.update(last_update=now)
|
||||
@@ -98,12 +102,21 @@ def _get_writable_project_from_name(
|
||||
|
||||
|
||||
def _get_sub_projects(
|
||||
project_ids: Sequence[str], _only: Sequence[str] = ("id", "path")
|
||||
project_ids: Sequence[str],
|
||||
_only: Sequence[str] = ("id", "path"),
|
||||
search_hidden=True,
|
||||
allowed_ids: Sequence[str] = None,
|
||||
) -> Mapping[str, Sequence[Project]]:
|
||||
"""
|
||||
Return the list of child projects of all the levels for the parent project ids
|
||||
"""
|
||||
qs = Project.objects(path__in=project_ids)
|
||||
query = dict(path__in=project_ids)
|
||||
if not search_hidden:
|
||||
query["system_tags__nin"] = [EntityVisibility.hidden.value]
|
||||
if allowed_ids:
|
||||
query["id__in"] = allowed_ids
|
||||
|
||||
qs = Project.objects(**query)
|
||||
if _only:
|
||||
_only = set(_only) | {"path"}
|
||||
qs = qs.only(*_only)
|
||||
|
||||
@@ -3,8 +3,10 @@ from datetime import datetime
|
||||
from typing import Callable, Sequence, Optional, Tuple
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.queue.queue_metrics import QueueMetrics
|
||||
@@ -32,7 +34,7 @@ class QueueBLL(object):
|
||||
name: str,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
metadata: Optional[Sequence[dict]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> Queue:
|
||||
"""Creates a queue"""
|
||||
with translate_errors_context():
|
||||
@@ -50,8 +52,25 @@ class QueueBLL(object):
|
||||
queue.save()
|
||||
return queue
|
||||
|
||||
def get_by_name(
|
||||
self, company_id: str, queue_name: str, only: Optional[Sequence[str]] = None,
|
||||
) -> Queue:
|
||||
qs = Queue.objects(name=queue_name, company=company_id)
|
||||
if only:
|
||||
qs = qs.only(*only)
|
||||
|
||||
return qs.first()
|
||||
|
||||
@staticmethod
|
||||
def _get_task_entries_projection(max_task_entries: int) -> dict:
|
||||
return dict(slice__entries=max_task_entries)
|
||||
|
||||
def get_by_id(
|
||||
self, company_id: str, queue_id: str, only: Optional[Sequence[str]] = None
|
||||
self,
|
||||
company_id: str,
|
||||
queue_id: str,
|
||||
only: Optional[Sequence[str]] = None,
|
||||
max_task_entries: int = None,
|
||||
) -> Queue:
|
||||
"""
|
||||
Get queue by id
|
||||
@@ -62,6 +81,8 @@ class QueueBLL(object):
|
||||
qs = Queue.objects(**query)
|
||||
if only:
|
||||
qs = qs.only(*only)
|
||||
if max_task_entries:
|
||||
qs = qs.fields(**self._get_task_entries_projection(max_task_entries))
|
||||
queue = qs.first()
|
||||
if not queue:
|
||||
raise errors.bad_request.InvalidQueueId(**query)
|
||||
@@ -120,20 +141,72 @@ class QueueBLL(object):
|
||||
"""
|
||||
with translate_errors_context():
|
||||
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
|
||||
if queue.entries and not force:
|
||||
raise errors.bad_request.QueueNotEmpty(
|
||||
"use force=true to delete", id=queue_id
|
||||
)
|
||||
if queue.entries:
|
||||
if not force:
|
||||
raise errors.bad_request.QueueNotEmpty(
|
||||
"use force=true to delete", id=queue_id
|
||||
)
|
||||
from apiserver.bll.task import ChangeStatusRequest
|
||||
|
||||
for item in queue.entries:
|
||||
try:
|
||||
task = Task.get_for_writing(
|
||||
company=company_id,
|
||||
id=item.task,
|
||||
_only=["id", "status", "enqueue_status", "project"],
|
||||
)
|
||||
if not task:
|
||||
continue
|
||||
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=task.enqueue_status or TaskStatus.created,
|
||||
status_reason="Queue was deleted",
|
||||
status_message="",
|
||||
).execute(enqueue_status=None)
|
||||
except Exception as ex:
|
||||
log.exception(
|
||||
f"Failed dequeuing task {item.task} from queue: {queue_id}"
|
||||
)
|
||||
|
||||
queue.delete()
|
||||
|
||||
def get_all(self, company_id: str, query_dict: dict) -> Sequence[dict]:
|
||||
def get_all(
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
query: Q = None,
|
||||
max_task_entries: int = None,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""Get all the queues according to the query"""
|
||||
with translate_errors_context():
|
||||
return Queue.get_many(
|
||||
company=company_id, parameters=query_dict, query_dict=query_dict
|
||||
company=company_id,
|
||||
parameters=query_dict,
|
||||
query_dict=query_dict,
|
||||
query=query,
|
||||
projection_fields=self._get_task_entries_projection(max_task_entries)
|
||||
if max_task_entries
|
||||
else None,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
def get_queue_infos(self, company_id: str, query_dict: dict) -> Sequence[dict]:
|
||||
def check_for_workers(self, company_id: str, queue_id: str) -> bool:
|
||||
for worker in self.worker_bll.get_all(company_id):
|
||||
if queue_id in worker.queues:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_queue_infos(
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
query: Q = None,
|
||||
max_task_entries: int = None,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Get infos on all the company queues, including queue tasks and workers
|
||||
"""
|
||||
@@ -142,7 +215,12 @@ class QueueBLL(object):
|
||||
res = Queue.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=query_dict,
|
||||
query=query,
|
||||
override_projection=projection,
|
||||
projection_fields=self._get_task_entries_projection(max_task_entries)
|
||||
if max_task_entries
|
||||
else None,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
queue_workers = defaultdict(list)
|
||||
@@ -173,13 +251,15 @@ class QueueBLL(object):
|
||||
if any(e.task == task_id for e in queue.entries):
|
||||
raise errors.bad_request.TaskAlreadyQueued(task=task_id)
|
||||
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
entry = Entry(added=datetime.utcnow(), task=task_id)
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
|
||||
push__entries=entry, last_update=datetime.utcnow(), upsert=False
|
||||
)
|
||||
|
||||
queue.reload()
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
if not res:
|
||||
raise errors.bad_request.InvalidQueueOrTaskNotQueued(
|
||||
task=task_id, **query
|
||||
@@ -187,16 +267,22 @@ class QueueBLL(object):
|
||||
|
||||
return res
|
||||
|
||||
def get_next_task(self, company_id: str, queue_id: str) -> Optional[Entry]:
|
||||
def get_next_task(
|
||||
self, company_id: str, queue_id: str, task_id: str = None
|
||||
) -> Optional[Entry]:
|
||||
"""
|
||||
Atomically pop and return the first task from the queue (or None)
|
||||
:raise errors.bad_request.InvalidQueueId: if the queue does not exist
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
queue = Queue.objects(**query).modify(pop__entries=-1, upsert=False)
|
||||
queue = Queue.objects(
|
||||
**query, **({"entries__0__task": task_id} if task_id else {})
|
||||
).modify(pop__entries=-1, upsert=False)
|
||||
if not queue:
|
||||
raise errors.bad_request.InvalidQueueId(**query)
|
||||
if not task_id or not Queue.objects(**query).first():
|
||||
raise errors.bad_request.InvalidQueueId(**query)
|
||||
return
|
||||
|
||||
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
|
||||
|
||||
@@ -219,7 +305,6 @@ class QueueBLL(object):
|
||||
queue = self.get_queue_with_task(
|
||||
company_id=company_id, queue_id=queue_id, task_id=task_id
|
||||
)
|
||||
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
|
||||
|
||||
entries_to_remove = [e for e in queue.entries if e.task == task_id]
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
@@ -227,6 +312,9 @@ class QueueBLL(object):
|
||||
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
|
||||
)
|
||||
|
||||
queue.reload()
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
return len(entries_to_remove) if res else 0
|
||||
|
||||
def reposition_task(
|
||||
@@ -270,3 +358,22 @@ class QueueBLL(object):
|
||||
)
|
||||
|
||||
return new_position
|
||||
|
||||
def count_entries(self, company: str, queue_id: str) -> Optional[int]:
|
||||
res = next(
|
||||
Queue.aggregate(
|
||||
[
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"_id": queue_id,
|
||||
}
|
||||
},
|
||||
{"$project": {"count": {"$size": "$entries"}}},
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if res is None:
|
||||
raise errors.bad_request.InvalidQueueId(queue_id=queue_id)
|
||||
return int(res.get("count"))
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from time import sleep
|
||||
from typing import Sequence
|
||||
|
||||
import elasticsearch.helpers
|
||||
from boltons.typeutils import classproperty
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.es_factory import es_factory
|
||||
@@ -11,25 +13,30 @@ from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.queue import Queue, Entry
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__file__)
|
||||
_conf = config.get("services.queues")
|
||||
_queue_metrics_key_pattern = "queue_metrics_{queue}"
|
||||
redis = redman.connection("apiserver")
|
||||
|
||||
|
||||
class EsKeys:
|
||||
WAITING_TIME_FIELD = "average_waiting_time"
|
||||
QUEUE_LENGTH_FIELD = "queue_length"
|
||||
TIMESTAMP_FIELD = "timestamp"
|
||||
QUEUE_FIELD = "queue"
|
||||
|
||||
|
||||
class QueueMetrics:
|
||||
class EsKeys:
|
||||
WAITING_TIME_FIELD = "average_waiting_time"
|
||||
QUEUE_LENGTH_FIELD = "queue_length"
|
||||
TIMESTAMP_FIELD = "timestamp"
|
||||
QUEUE_FIELD = "queue"
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
@staticmethod
|
||||
def _queue_metrics_prefix_for_company(company_id: str) -> str:
|
||||
"""Returns the es index prefix for the company"""
|
||||
return f"queue_metrics_{company_id}_"
|
||||
return f"queue_metrics_{company_id.lower()}_"
|
||||
|
||||
@staticmethod
|
||||
def _get_es_index_suffix():
|
||||
@@ -49,7 +56,7 @@ class QueueMetrics:
|
||||
total_waiting_in_secs = sum((now - e.added).total_seconds() for e in entries)
|
||||
return total_waiting_in_secs / len(entries)
|
||||
|
||||
def log_queue_metrics_to_es(self, company_id: str, queues: Sequence[Queue]) -> bool:
|
||||
def log_queue_metrics_to_es(self, company_id: str, queues: Sequence[Queue]) -> int:
|
||||
"""
|
||||
Calculate and write queue statistics (avg waiting time and queue length) to Elastic
|
||||
:return: True if the write to es was successful, false otherwise
|
||||
@@ -63,23 +70,22 @@ class QueueMetrics:
|
||||
|
||||
def make_doc(queue: Queue) -> dict:
|
||||
entries = [e for e in queue.entries if e.added]
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_source={
|
||||
self.EsKeys.TIMESTAMP_FIELD: timestamp,
|
||||
self.EsKeys.QUEUE_FIELD: queue.id,
|
||||
self.EsKeys.WAITING_TIME_FIELD: self._calc_avg_waiting_time(
|
||||
entries
|
||||
),
|
||||
self.EsKeys.QUEUE_LENGTH_FIELD: len(entries),
|
||||
},
|
||||
)
|
||||
return {
|
||||
EsKeys.TIMESTAMP_FIELD: timestamp,
|
||||
EsKeys.QUEUE_FIELD: queue.id,
|
||||
EsKeys.WAITING_TIME_FIELD: self._calc_avg_waiting_time(entries),
|
||||
EsKeys.QUEUE_LENGTH_FIELD: len(entries),
|
||||
}
|
||||
|
||||
actions = list(map(make_doc, queues))
|
||||
logged = 0
|
||||
for q in queues:
|
||||
queue_doc = make_doc(q)
|
||||
self.es.index(index=es_index, body=queue_doc)
|
||||
redis_key = _queue_metrics_key_pattern.format(queue=q.id)
|
||||
redis.set(redis_key, json.dumps(queue_doc))
|
||||
logged += 1
|
||||
|
||||
es_res = elasticsearch.helpers.bulk(self.es, actions)
|
||||
added, errors = es_res[:2]
|
||||
return (added == len(actions)) and not errors
|
||||
return logged
|
||||
|
||||
def _log_current_metrics(self, company_id: str, queue_ids=Sequence[str]):
|
||||
query = dict(company=company_id)
|
||||
@@ -90,8 +96,7 @@ class QueueMetrics:
|
||||
|
||||
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
|
||||
body=es_req,
|
||||
index=f"{self._queue_metrics_prefix_for_company(company_id)}*", body=es_req,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -105,13 +110,13 @@ class QueueMetrics:
|
||||
return {
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": cls.EsKeys.TIMESTAMP_FIELD,
|
||||
"field": EsKeys.TIMESTAMP_FIELD,
|
||||
"fixed_interval": f"{interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
"queues": {
|
||||
"terms": {"field": cls.EsKeys.QUEUE_FIELD},
|
||||
"terms": {"field": EsKeys.QUEUE_FIELD},
|
||||
"aggs": cls._get_top_waiting_agg(),
|
||||
}
|
||||
},
|
||||
@@ -128,13 +133,13 @@ class QueueMetrics:
|
||||
"top_avg_waiting": {
|
||||
"top_hits": {
|
||||
"sort": [
|
||||
{cls.EsKeys.WAITING_TIME_FIELD: {"order": "desc"}},
|
||||
{cls.EsKeys.QUEUE_LENGTH_FIELD: {"order": "desc"}},
|
||||
{EsKeys.WAITING_TIME_FIELD: {"order": "desc"}},
|
||||
{EsKeys.QUEUE_LENGTH_FIELD: {"order": "desc"}},
|
||||
],
|
||||
"_source": {
|
||||
"includes": [
|
||||
cls.EsKeys.WAITING_TIME_FIELD,
|
||||
cls.EsKeys.QUEUE_LENGTH_FIELD,
|
||||
EsKeys.WAITING_TIME_FIELD,
|
||||
EsKeys.QUEUE_LENGTH_FIELD,
|
||||
]
|
||||
},
|
||||
"size": 1,
|
||||
@@ -149,6 +154,7 @@ class QueueMetrics:
|
||||
to_date: float,
|
||||
interval: int,
|
||||
queue_ids: Sequence[str],
|
||||
refresh: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Get the company queue metrics in the specified time range.
|
||||
@@ -158,7 +164,8 @@ class QueueMetrics:
|
||||
In case no queue ids are specified the avg across all the
|
||||
company queues is calculated for each metric
|
||||
"""
|
||||
# self._log_current_metrics(company, queue_ids=queue_ids)
|
||||
if refresh:
|
||||
self._log_current_metrics(company_id, queue_ids=queue_ids)
|
||||
|
||||
if from_date >= to_date:
|
||||
raise bad_request.FieldsValueError("from_date must be less than to_date")
|
||||
@@ -174,7 +181,7 @@ class QueueMetrics:
|
||||
"aggs": self._get_dates_agg(interval),
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_queue_metrics"):
|
||||
with translate_errors_context():
|
||||
res = self._search_company_metrics(company_id, es_req)
|
||||
|
||||
if "aggregations" not in res:
|
||||
@@ -256,7 +263,52 @@ class QueueMetrics:
|
||||
continue
|
||||
res = queue_data["top_avg_waiting"]["hits"]["hits"][0]["_source"]
|
||||
queue_metrics[queue_data["key"]] = {
|
||||
"queue_length": res[cls.EsKeys.QUEUE_LENGTH_FIELD],
|
||||
"avg_waiting_time": res[cls.EsKeys.WAITING_TIME_FIELD],
|
||||
"queue_length": res[EsKeys.QUEUE_LENGTH_FIELD],
|
||||
"avg_waiting_time": res[EsKeys.WAITING_TIME_FIELD],
|
||||
}
|
||||
return queue_metrics
|
||||
|
||||
|
||||
class MetricsRefresher:
|
||||
threads = ThreadsManager()
|
||||
|
||||
@classproperty
|
||||
def watch_interval_sec(self):
|
||||
return _conf.get("metrics_refresh_interval_sec", 300)
|
||||
|
||||
@classmethod
|
||||
@threads.register("queue_metrics_refresh_watchdog", daemon=True)
|
||||
def start(cls, queue_metrics: QueueMetrics = None):
|
||||
if not cls.watch_interval_sec:
|
||||
return
|
||||
|
||||
if not queue_metrics:
|
||||
from .queue_bll import QueueBLL
|
||||
|
||||
queue_metrics = QueueBLL().metrics
|
||||
|
||||
sleep(10)
|
||||
while True:
|
||||
try:
|
||||
for queue in Queue.objects():
|
||||
timestamp = es_factory.get_timestamp_millis()
|
||||
doc_time = 0
|
||||
try:
|
||||
redis_key = _queue_metrics_key_pattern.format(queue=queue.id)
|
||||
data = redis.get(redis_key)
|
||||
if data:
|
||||
queue_doc = json.loads(data)
|
||||
doc_time = int(queue_doc.get(EsKeys.TIMESTAMP_FIELD))
|
||||
except Exception as ex:
|
||||
log.exception(
|
||||
f"Error reading queue metrics data for queue {queue.id}: {str(ex)}"
|
||||
)
|
||||
|
||||
if (
|
||||
not doc_time
|
||||
or (timestamp - doc_time) > cls.watch_interval_sec * 1000
|
||||
):
|
||||
queue_metrics.log_queue_metrics_to_es(queue.company, [queue])
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed collecting queue metrics: {str(ex)}")
|
||||
sleep(60)
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Optional, TypeVar, Generic, Type, Callable
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -31,24 +30,36 @@ class RedisCacheManager(Generic[T]):
|
||||
|
||||
def set_state(self, state: T) -> None:
|
||||
redis_key = self._get_redis_key(state.id)
|
||||
with TimingContext("redis", "cache_set_state"):
|
||||
self.redis.set(redis_key, state.to_json())
|
||||
self.redis.expire(redis_key, self.expiration_interval)
|
||||
self.redis.set(redis_key, state.to_json())
|
||||
self.redis.expire(redis_key, self.expiration_interval)
|
||||
|
||||
def get_state(self, state_id) -> Optional[T]:
|
||||
redis_key = self._get_redis_key(state_id)
|
||||
with TimingContext("redis", "cache_get_state"):
|
||||
response = self.redis.get(redis_key)
|
||||
response = self.redis.get(redis_key)
|
||||
if response:
|
||||
return self.state_class.from_json(response)
|
||||
|
||||
def delete_state(self, state_id) -> None:
|
||||
with TimingContext("redis", "cache_delete_state"):
|
||||
self.redis.delete(self._get_redis_key(state_id))
|
||||
self.redis.delete(self._get_redis_key(state_id))
|
||||
|
||||
def _get_redis_key(self, state_id):
|
||||
return f"{self.state_class}/{state_id}"
|
||||
|
||||
def get_or_create_state_core(
|
||||
self,
|
||||
state_id=None,
|
||||
init_state: Callable[[T], None] = _do_nothing,
|
||||
validate_state: Callable[[T], None] = _do_nothing,
|
||||
) -> T:
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
|
||||
return state
|
||||
|
||||
@contextmanager
|
||||
def get_or_create_state(
|
||||
self,
|
||||
@@ -66,12 +77,9 @@ class RedisCacheManager(Generic[T]):
|
||||
:param validate_state: user callback to validate the state if retrieved from cache
|
||||
Should throw an exception if the state is not valid. If not passed then no validation is done
|
||||
"""
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
state = self.get_or_create_state_core(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
)
|
||||
|
||||
try:
|
||||
yield state
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
import operator
|
||||
from threading import Thread, Lock
|
||||
from threading import Lock
|
||||
from time import sleep
|
||||
|
||||
import attr
|
||||
@@ -9,76 +9,83 @@ import psutil
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
|
||||
class ResourceMonitor(Thread):
|
||||
@attr.s(auto_attribs=True)
|
||||
class Sample:
|
||||
cpu_usage: float = 0.0
|
||||
mem_used_gb: float = 0
|
||||
mem_free_gb: float = 0
|
||||
stat_threads = ThreadsManager("Statistics")
|
||||
|
||||
@classmethod
|
||||
def _apply(cls, op, *samples):
|
||||
return cls(
|
||||
**{
|
||||
field: op(*(getattr(sample, field) for sample in samples))
|
||||
for field in attr.fields_dict(cls)
|
||||
}
|
||||
)
|
||||
|
||||
def min(self, sample):
|
||||
return self._apply(min, self, sample)
|
||||
|
||||
def max(self, sample):
|
||||
return self._apply(max, self, sample)
|
||||
|
||||
def avg(self, sample, count):
|
||||
res = self._apply(lambda x: x * count, self)
|
||||
res = self._apply(operator.add, res, sample)
|
||||
res = self._apply(lambda x: x / (count + 1), res)
|
||||
return res
|
||||
|
||||
def __init__(self, sample_interval_sec=5):
|
||||
super(ResourceMonitor, self).__init__(daemon=True)
|
||||
self.sample_interval_sec = sample_interval_sec
|
||||
self._lock = Lock()
|
||||
self._clear()
|
||||
|
||||
def _clear(self):
|
||||
sample = self._get_sample()
|
||||
self._avg = sample
|
||||
self._min = sample
|
||||
self._max = sample
|
||||
self._clear_time = datetime.utcnow()
|
||||
self._count = 1
|
||||
@attr.s(auto_attribs=True)
|
||||
class Sample:
|
||||
cpu_usage: float = 0.0
|
||||
mem_used_gb: float = 0
|
||||
mem_free_gb: float = 0
|
||||
|
||||
@classmethod
|
||||
def _get_sample(cls) -> Sample:
|
||||
return cls.Sample(
|
||||
def _apply(cls, op, *samples):
|
||||
return cls(
|
||||
**{
|
||||
field: op(*(getattr(sample, field) for sample in samples))
|
||||
for field in attr.fields_dict(cls)
|
||||
}
|
||||
)
|
||||
|
||||
def min(self, sample):
|
||||
return self._apply(min, self, sample)
|
||||
|
||||
def max(self, sample):
|
||||
return self._apply(max, self, sample)
|
||||
|
||||
def avg(self, sample, count):
|
||||
res = self._apply(lambda x: x * count, self)
|
||||
res = self._apply(operator.add, res, sample)
|
||||
res = self._apply(lambda x: x / (count + 1), res)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_current_sample(cls) -> "Sample":
|
||||
return cls(
|
||||
cpu_usage=psutil.cpu_percent(),
|
||||
mem_used_gb=psutil.virtual_memory().used / (1024 ** 3),
|
||||
mem_free_gb=psutil.virtual_memory().free / (1024 ** 3),
|
||||
)
|
||||
|
||||
def run(self):
|
||||
while not ThreadsManager.terminating:
|
||||
sleep(self.sample_interval_sec)
|
||||
|
||||
sample = self._get_sample()
|
||||
class ResourceMonitor:
|
||||
class Accumulator:
|
||||
def __init__(self):
|
||||
sample = Sample.get_current_sample()
|
||||
self.avg = sample
|
||||
self.min = sample
|
||||
self.max = sample
|
||||
self.time = datetime.utcnow()
|
||||
self.count = 1
|
||||
|
||||
with self._lock:
|
||||
self._min = self._min.min(sample)
|
||||
self._max = self._max.max(sample)
|
||||
self._avg = self._avg.avg(sample, self._count)
|
||||
self._count += 1
|
||||
def add_sample(self, sample: Sample):
|
||||
self.min = self.min.min(sample)
|
||||
self.max = self.max.max(sample)
|
||||
self.avg = self.avg.avg(sample, self.count)
|
||||
self.count += 1
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
sample_interval_sec = 5
|
||||
_lock = Lock()
|
||||
accumulator = Accumulator()
|
||||
|
||||
@classmethod
|
||||
@stat_threads.register("resource_monitor", daemon=True)
|
||||
def start(cls):
|
||||
while True:
|
||||
sleep(cls.sample_interval_sec)
|
||||
sample = Sample.get_current_sample()
|
||||
with cls._lock:
|
||||
cls.accumulator.add_sample(sample)
|
||||
|
||||
@classmethod
|
||||
def get_stats(cls) -> dict:
|
||||
""" Returns current resource statistics and clears internal resource statistics """
|
||||
with self._lock:
|
||||
min_ = attr.asdict(self._min)
|
||||
max_ = attr.asdict(self._max)
|
||||
avg = attr.asdict(self._avg)
|
||||
interval = datetime.utcnow() - self._clear_time
|
||||
self._clear()
|
||||
with cls._lock:
|
||||
min_ = attr.asdict(cls.accumulator.min)
|
||||
max_ = attr.asdict(cls.accumulator.max)
|
||||
avg = attr.asdict(cls.accumulator.avg)
|
||||
interval = datetime.utcnow() - cls.accumulator.time
|
||||
cls.accumulator = cls.Accumulator()
|
||||
|
||||
return {
|
||||
"interval_sec": interval.total_seconds(),
|
||||
|
||||
@@ -21,9 +21,8 @@ from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.json import dumps
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
from apiserver.version import __version__ as current_version
|
||||
from .resource_monitor import ResourceMonitor
|
||||
from .resource_monitor import ResourceMonitor, stat_threads
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -31,17 +30,19 @@ worker_bll = WorkerBLL()
|
||||
|
||||
|
||||
class StatisticsReporter:
|
||||
threads = ThreadsManager("Statistics", resource_monitor=ResourceMonitor)
|
||||
send_queue = queue.Queue()
|
||||
supported = config.get("apiserver.statistics.supported", True)
|
||||
|
||||
@classmethod
|
||||
def start(cls):
|
||||
if not cls.supported:
|
||||
return
|
||||
ResourceMonitor.start()
|
||||
cls.start_sender()
|
||||
cls.start_reporter()
|
||||
|
||||
@classmethod
|
||||
@threads.register("reporter", daemon=True)
|
||||
@stat_threads.register("reporter", daemon=True)
|
||||
def start_reporter(cls):
|
||||
"""
|
||||
Periodically send statistics reports for companies who have opted in.
|
||||
@@ -54,7 +55,7 @@ class StatisticsReporter:
|
||||
hours=config.get("apiserver.statistics.report_interval_hours", 24)
|
||||
)
|
||||
sleep(report_interval.total_seconds())
|
||||
while not ThreadsManager.terminating:
|
||||
while True:
|
||||
try:
|
||||
for company in Company.objects(
|
||||
defaults__stats_option__enabled=True
|
||||
@@ -68,7 +69,7 @@ class StatisticsReporter:
|
||||
sleep(report_interval.total_seconds())
|
||||
|
||||
@classmethod
|
||||
@threads.register("sender", daemon=True)
|
||||
@stat_threads.register("sender", daemon=True)
|
||||
def start_sender(cls):
|
||||
if not cls.supported:
|
||||
return
|
||||
@@ -85,7 +86,7 @@ class StatisticsReporter:
|
||||
|
||||
WarningFilter.attach()
|
||||
|
||||
while not ThreadsManager.terminating:
|
||||
while True:
|
||||
try:
|
||||
report = cls.send_queue.get()
|
||||
|
||||
@@ -111,7 +112,7 @@ class StatisticsReporter:
|
||||
"uuid": get_server_uuid(),
|
||||
"queues": {"count": Queue.objects(company=company_id).count()},
|
||||
"users": {"count": User.objects(company=company_id).count()},
|
||||
"resources": cls.threads.resource_monitor.get_stats(),
|
||||
"resources": ResourceMonitor.get_stats(),
|
||||
"experiments": next(
|
||||
iter(cls._get_experiments_stats(company_id).values()), {}
|
||||
),
|
||||
|
||||
@@ -5,7 +5,6 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
|
||||
from apiserver.database.utils import hash_field_name
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
|
||||
|
||||
@@ -53,23 +52,18 @@ class Artifacts:
|
||||
artifacts: Sequence[ApiArtifact],
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "update_artifacts"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
force=force,
|
||||
)
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
|
||||
|
||||
artifacts = {
|
||||
get_artifact_id(a): Artifact(**a)
|
||||
for a in (api_artifact.to_struct() for api_artifact in artifacts)
|
||||
}
|
||||
artifacts = {
|
||||
get_artifact_id(a): Artifact(**a)
|
||||
for a in (api_artifact.to_struct() for api_artifact in artifacts)
|
||||
}
|
||||
|
||||
update_cmds = {
|
||||
f"set__execution__artifacts__{mongoengine_safe(name)}": value
|
||||
for name, value in artifacts.items()
|
||||
}
|
||||
return update_task(task, update_cmds=update_cmds)
|
||||
update_cmds = {
|
||||
f"set__execution__artifacts__{mongoengine_safe(name)}": value
|
||||
for name, value in artifacts.items()
|
||||
}
|
||||
return update_task(task, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_artifacts(
|
||||
@@ -79,19 +73,14 @@ class Artifacts:
|
||||
artifact_ids: Sequence[ArtifactId],
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "delete_artifacts"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
force=force,
|
||||
)
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
|
||||
|
||||
artifact_ids = [
|
||||
get_artifact_id(a)
|
||||
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
|
||||
]
|
||||
delete_cmds = {
|
||||
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
|
||||
}
|
||||
artifact_ids = [
|
||||
get_artifact_id(a)
|
||||
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
|
||||
]
|
||||
delete_cmds = {
|
||||
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
|
||||
}
|
||||
|
||||
return update_task(task, update_cmds=delete_cmds)
|
||||
return update_task(task, update_cmds=delete_cmds)
|
||||
|
||||
@@ -15,7 +15,6 @@ from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
@@ -68,36 +67,35 @@ class HyperParams:
|
||||
hyperparams: Sequence[HyperParamKey],
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "delete_hyperparams"):
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
)
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
)
|
||||
|
||||
with_param, without_param = iterutils.partition(
|
||||
hyperparams, key=lambda p: bool(p.name)
|
||||
)
|
||||
sections_to_delete = {p.section for p in without_param}
|
||||
delete_cmds = {
|
||||
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
|
||||
for section in sections_to_delete
|
||||
}
|
||||
with_param, without_param = iterutils.partition(
|
||||
hyperparams, key=lambda p: bool(p.name)
|
||||
)
|
||||
sections_to_delete = {p.section for p in without_param}
|
||||
delete_cmds = {
|
||||
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
|
||||
for section in sections_to_delete
|
||||
}
|
||||
|
||||
for item in with_param:
|
||||
section = ParameterKeyEscaper.escape(item.section)
|
||||
if item.section in sections_to_delete:
|
||||
raise errors.bad_request.FieldsConflict(
|
||||
"Cannot delete section field if the whole section was scheduled for deletion"
|
||||
)
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
for item in with_param:
|
||||
section = ParameterKeyEscaper.escape(item.section)
|
||||
if item.section in sections_to_delete:
|
||||
raise errors.bad_request.FieldsConflict(
|
||||
"Cannot delete section field if the whole section was scheduled for deletion"
|
||||
)
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
|
||||
return update_task(
|
||||
task, update_cmds=delete_cmds, set_last_update=not properties_only
|
||||
)
|
||||
return update_task(
|
||||
task, update_cmds=delete_cmds, set_last_update=not properties_only
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def edit_params(
|
||||
@@ -108,34 +106,31 @@ class HyperParams:
|
||||
replace_hyperparams: str,
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "edit_hyperparams"):
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
)
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
hyperparams = cls._db_dicts_from_list(hyperparams)
|
||||
if replace_hyperparams == ReplaceHyperparams.all:
|
||||
update_cmds["set__hyperparams"] = hyperparams
|
||||
elif replace_hyperparams == ReplaceHyperparams.section:
|
||||
for section, value in hyperparams.items():
|
||||
update_cmds = dict()
|
||||
hyperparams = cls._db_dicts_from_list(hyperparams)
|
||||
if replace_hyperparams == ReplaceHyperparams.all:
|
||||
update_cmds["set__hyperparams"] = hyperparams
|
||||
elif replace_hyperparams == ReplaceHyperparams.section:
|
||||
for section, value in hyperparams.items():
|
||||
update_cmds[f"set__hyperparams__{mongoengine_safe(section)}"] = value
|
||||
else:
|
||||
for section, section_params in hyperparams.items():
|
||||
for name, value in section_params.items():
|
||||
update_cmds[
|
||||
f"set__hyperparams__{mongoengine_safe(section)}"
|
||||
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
|
||||
] = value
|
||||
else:
|
||||
for section, section_params in hyperparams.items():
|
||||
for name, value in section_params.items():
|
||||
update_cmds[
|
||||
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
|
||||
] = value
|
||||
|
||||
return update_task(
|
||||
task, update_cmds=update_cmds, set_last_update=not properties_only
|
||||
)
|
||||
return update_task(
|
||||
task, update_cmds=update_cmds, set_last_update=not properties_only
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
|
||||
@@ -191,17 +186,16 @@ class HyperParams:
|
||||
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
|
||||
]
|
||||
|
||||
with TimingContext("mongo", "get_configuration_names"):
|
||||
tasks = Task.aggregate(pipeline)
|
||||
tasks = Task.aggregate(pipeline)
|
||||
|
||||
return {
|
||||
task["_id"]: {
|
||||
"names": sorted(
|
||||
ParameterKeyEscaper.unescape(name) for name in task["names"]
|
||||
)
|
||||
}
|
||||
for task in tasks
|
||||
return {
|
||||
task["_id"]: {
|
||||
"names": sorted(
|
||||
ParameterKeyEscaper.unescape(name) for name in task["names"]
|
||||
)
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_configuration(
|
||||
@@ -212,36 +206,30 @@ class HyperParams:
|
||||
replace_configuration: bool,
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "edit_configuration"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force
|
||||
)
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
|
||||
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
|
||||
for c in configuration
|
||||
}
|
||||
if replace_configuration:
|
||||
update_cmds["set__configuration"] = configuration
|
||||
else:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
|
||||
for c in configuration
|
||||
}
|
||||
if replace_configuration:
|
||||
update_cmds["set__configuration"] = configuration
|
||||
else:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
|
||||
|
||||
return update_task(task, update_cmds=update_cmds)
|
||||
return update_task(task, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_configuration(
|
||||
cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
|
||||
) -> int:
|
||||
with TimingContext("mongo", "delete_configuration"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force
|
||||
)
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
|
||||
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
|
||||
return update_task(task, update_cmds=delete_cmds)
|
||||
return update_task(task, update_cmds=delete_cmds)
|
||||
|
||||
@@ -39,7 +39,7 @@ class NonResponsiveTasksWatchdog:
|
||||
@threads.register("non_responsive_tasks_watchdog", daemon=True)
|
||||
def start(cls):
|
||||
sleep(cls.settings.watch_interval_sec)
|
||||
while not ThreadsManager.terminating:
|
||||
while True:
|
||||
watch_interval = cls.settings.watch_interval_sec
|
||||
if cls.settings.enabled:
|
||||
try:
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import itertools
|
||||
from typing import Sequence, Tuple
|
||||
from typing import Sequence, Tuple, Optional
|
||||
|
||||
import dpath
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import nested_get, nested_delete, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
|
||||
@@ -14,7 +13,7 @@ hyperparams_legacy_type = "legacy"
|
||||
tf_define_section = "TF_DEFINE"
|
||||
|
||||
|
||||
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
|
||||
def split_param_name(full_name: str, default_section: str) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Return parameter section and name. The section is either TF_DEFINE or the default one
|
||||
"""
|
||||
@@ -62,7 +61,7 @@ def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
|
||||
return removed
|
||||
|
||||
|
||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
|
||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[dict]:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
@@ -71,8 +70,10 @@ def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]
|
||||
return []
|
||||
|
||||
if with_sections:
|
||||
return itertools.chain.from_iterable(
|
||||
_get_legacy_params(section_data) for section_data in data.values()
|
||||
return list(
|
||||
itertools.chain.from_iterable(
|
||||
_get_legacy_params(section_data) for section_data in data.values()
|
||||
)
|
||||
)
|
||||
|
||||
return [
|
||||
@@ -86,15 +87,15 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
Escape all the section and param names for hyper params and configuration to make it mongo sage
|
||||
"""
|
||||
for old_params_field, new_params_field, default_section in (
|
||||
("execution/parameters", "hyperparams", hyperparams_default_section),
|
||||
("execution/model_desc", "configuration", None),
|
||||
(("execution", "parameters"), "hyperparams", hyperparams_default_section),
|
||||
(("execution", "model_desc"), "configuration", None),
|
||||
):
|
||||
legacy_params = safe_get(fields, old_params_field)
|
||||
legacy_params = nested_get(fields, old_params_field)
|
||||
if legacy_params is None:
|
||||
continue
|
||||
|
||||
if (
|
||||
not safe_get(fields, new_params_field)
|
||||
not fields.get(new_params_field)
|
||||
and previous_task
|
||||
and previous_task[new_params_field]
|
||||
):
|
||||
@@ -117,21 +118,34 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
dpath.new(fields, new_path, new_param)
|
||||
dpath.delete(fields, old_params_field)
|
||||
nested_set(fields, new_path, new_param)
|
||||
nested_delete(fields, old_params_field)
|
||||
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(key): {
|
||||
ParameterKeyEscaper.escape(k): v for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
for key, value in params.items()
|
||||
def ensure_non_empty(k: str, desc: str) -> str:
|
||||
if not k:
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"Empty {desc} name is not allowed"
|
||||
)
|
||||
return k
|
||||
|
||||
params = fields.get("hyperparams")
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(ensure_non_empty(key, "section")): {
|
||||
ParameterKeyEscaper.escape(ensure_non_empty(k, "parameter")): v
|
||||
for k, v in value.items()
|
||||
}
|
||||
dpath.set(fields, param_field, escaped_params)
|
||||
for key, value in params.items()
|
||||
}
|
||||
fields["hyperparams"] = escaped_params
|
||||
|
||||
params = fields.get("configuration")
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(ensure_non_empty(key, "configuration")): value
|
||||
for key, value in params.items()
|
||||
}
|
||||
fields["configuration"] = escaped_params
|
||||
|
||||
|
||||
def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
@@ -140,7 +154,7 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
|
||||
"""
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
params = fields.get(param_field)
|
||||
if params:
|
||||
unescaped_params = {
|
||||
ParameterKeyEscaper.unescape(key): {
|
||||
@@ -150,18 +164,18 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, unescaped_params)
|
||||
fields[param_field] = unescaped_params
|
||||
|
||||
if copy_to_legacy:
|
||||
for new_params_field, old_params_field, use_sections in (
|
||||
(f"hyperparams", "execution/parameters", True),
|
||||
(f"configuration", "execution/model_desc", False),
|
||||
("hyperparams", ("execution", "parameters"), True),
|
||||
("configuration", ("execution", "model_desc"), False),
|
||||
):
|
||||
legacy_params = _get_legacy_params(
|
||||
safe_get(fields, new_params_field), with_sections=use_sections
|
||||
fields.get(new_params_field), with_sections=use_sections
|
||||
)
|
||||
if legacy_params:
|
||||
dpath.new(
|
||||
nested_set(
|
||||
fields,
|
||||
old_params_field,
|
||||
{_get_full_param_name(p): p["value"] for p in legacy_params},
|
||||
@@ -174,7 +188,7 @@ def _process_path(path: str):
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
if len(parts) < 2 or len(parts) > 4:
|
||||
raise errors.bad_request.ValidationError("invalid task field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
@@ -184,8 +198,8 @@ def _process_path(path: str):
|
||||
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
|
||||
for old_prefix, new_prefix in (
|
||||
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
|
||||
("execution.model_desc", f"configuration"),
|
||||
("execution.docker_cmd", "container")
|
||||
("execution.model_desc", "configuration"),
|
||||
("execution.docker_cmd", "container"),
|
||||
):
|
||||
path: str
|
||||
paths = [path.replace(old_prefix, new_prefix) for path in paths]
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
|
||||
from datetime import datetime
|
||||
from typing import Collection, Sequence, Tuple, Optional, Dict
|
||||
|
||||
import dpath
|
||||
import six
|
||||
from mongoengine import Q
|
||||
from redis import StrictRedis
|
||||
@@ -14,7 +11,7 @@ from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.tasks import TaskInputModel
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
@@ -38,8 +35,6 @@ from apiserver.es_factory import es_factory
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import (
|
||||
@@ -70,11 +65,10 @@ class TaskBLL:
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
with TimingContext("mongo", "task_with_access"):
|
||||
if requires_write_access:
|
||||
task = Task.get_for_writing(_only=only, **query)
|
||||
else:
|
||||
task = Task.get(_only=only, **query, include_public=allow_public)
|
||||
if requires_write_access:
|
||||
task = Task.get_for_writing(_only=only, **query)
|
||||
else:
|
||||
task = Task.get(_only=only, **query, include_public=allow_public)
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
@@ -92,15 +86,14 @@ class TaskBLL:
|
||||
only_fields = list(only_fields)
|
||||
only_fields = only_fields + ["status"]
|
||||
|
||||
with TimingContext("mongo", "task_by_id_all"):
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id=task_id),
|
||||
allow_public=allow_public,
|
||||
override_projection=only_fields,
|
||||
return_dicts=False,
|
||||
)
|
||||
task = None if not tasks else tasks[0]
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id=task_id),
|
||||
allow_public=allow_public,
|
||||
override_projection=only_fields,
|
||||
return_dicts=False,
|
||||
)
|
||||
task = None if not tasks else tasks[0]
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
@@ -115,7 +108,7 @@ class TaskBLL:
|
||||
company_id, task_ids, only=None, allow_public=False, return_tasks=True
|
||||
) -> Optional[Sequence[Task]]:
|
||||
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
|
||||
with translate_errors_context(), TimingContext("mongo", "task_exists"):
|
||||
with translate_errors_context():
|
||||
ids = set(task_ids)
|
||||
q = Task.get_many(
|
||||
company=company_id,
|
||||
@@ -264,58 +257,55 @@ class TaskBLL:
|
||||
not in [TaskSystemTags.development, EntityVisibility.archived.value]
|
||||
]
|
||||
|
||||
with TimingContext("mongo", "clone task"):
|
||||
parent_task = (
|
||||
task.parent
|
||||
if task.parent and not task.parent.startswith(deleted_prefix)
|
||||
else None
|
||||
)
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or parent_task,
|
||||
project=project or task.project,
|
||||
tags=tags or task.tags,
|
||||
system_tags=system_tags or clean_system_tags(task.system_tags),
|
||||
type=task.type,
|
||||
script=task.script,
|
||||
output=Output(destination=task.output.destination)
|
||||
if task.output
|
||||
else None,
|
||||
models=Models(input=input_models or task.models.input),
|
||||
container=escape_dict(container) or task.container,
|
||||
execution=execution_dict,
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_models=validate_references or input_models,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
new_task.save()
|
||||
parent_task = (
|
||||
task.parent
|
||||
if task.parent and not task.parent.startswith(deleted_prefix)
|
||||
else task.id
|
||||
)
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or parent_task,
|
||||
project=project or task.project,
|
||||
tags=tags or task.tags,
|
||||
system_tags=system_tags or clean_system_tags(task.system_tags),
|
||||
type=task.type,
|
||||
script=task.script,
|
||||
output=Output(destination=task.output.destination) if task.output else None,
|
||||
models=Models(input=input_models or task.models.input),
|
||||
container=escape_dict(container) or task.container,
|
||||
execution=execution_dict,
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_models=validate_references or input_models,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
new_task.save()
|
||||
|
||||
if task.project == new_task.project:
|
||||
updated_tags = tags
|
||||
updated_system_tags = system_tags
|
||||
else:
|
||||
updated_tags = new_task.tags
|
||||
updated_system_tags = new_task.system_tags
|
||||
org_bll.update_tags(
|
||||
company_id,
|
||||
Tags.Task,
|
||||
project=new_task.project,
|
||||
tags=updated_tags,
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
update_project_time(new_task.project)
|
||||
if task.project == new_task.project:
|
||||
updated_tags = tags
|
||||
updated_system_tags = system_tags
|
||||
else:
|
||||
updated_tags = new_task.tags
|
||||
updated_system_tags = new_task.system_tags
|
||||
org_bll.update_tags(
|
||||
company_id,
|
||||
Tags.Task,
|
||||
project=new_task.project,
|
||||
tags=updated_tags,
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
update_project_time(new_task.project)
|
||||
|
||||
return new_task, new_project_data
|
||||
|
||||
@@ -350,54 +340,6 @@ class TaskBLL:
|
||||
if validate_models:
|
||||
cls.validate_input_models(task)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(
|
||||
company_id, project_ids: Sequence[str], include_subprojects: bool
|
||||
):
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
pipeline = [
|
||||
{
|
||||
"$match": dict(
|
||||
company={"$in": [None, "", company_id]}, **project_constraint,
|
||||
)
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
{"$unwind": "$metrics"},
|
||||
{
|
||||
"$project": {
|
||||
"metric": "$metrics.k",
|
||||
"variants": {"$objectToArray": "$metrics.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$variants"},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"metric": "$variants.v.metric",
|
||||
"variant": "$variants.v.variant",
|
||||
},
|
||||
"metrics": {
|
||||
"$addToSet": {
|
||||
"metric": "$variants.v.metric",
|
||||
"metric_hash": "$metric",
|
||||
"variant": "$variants.v.variant",
|
||||
"variant_hash": "$variants.k",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@staticmethod
|
||||
def set_last_update(
|
||||
task_ids: Collection[str],
|
||||
@@ -433,7 +375,7 @@ class TaskBLL:
|
||||
last_update: datetime = None,
|
||||
last_iteration: int = None,
|
||||
last_iteration_max: int = None,
|
||||
last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
|
||||
last_scalar_events: Dict[str, Dict[str, dict]] = None,
|
||||
last_events: Dict[str, Dict[str, dict]] = None,
|
||||
**extra_updates,
|
||||
):
|
||||
@@ -458,18 +400,43 @@ class TaskBLL:
|
||||
elif last_iteration_max is not None:
|
||||
extra_updates.update(max__last_iteration=last_iteration_max)
|
||||
|
||||
if last_scalar_values is not None:
|
||||
if last_scalar_events is not None:
|
||||
max_values = config.get("services.tasks.max_last_metrics", 2000)
|
||||
total_metrics = set()
|
||||
if max_values:
|
||||
query = dict(id=task_id)
|
||||
to_add = sum(len(v) for m, v in last_scalar_events.items())
|
||||
if to_add <= max_values:
|
||||
query[f"unique_metrics__{max_values-to_add}__exists"] = True
|
||||
task = Task.objects(**query).only("unique_metrics").first()
|
||||
if task and task.unique_metrics:
|
||||
total_metrics = set(task.unique_metrics)
|
||||
|
||||
def op_path(op, *path):
|
||||
return "__".join((op, "last_metrics") + path)
|
||||
new_metrics = []
|
||||
for metric_key, metric_data in last_scalar_events.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
metric = (
|
||||
f"{variant_data.get('metric')}/{variant_data.get('variant')}"
|
||||
)
|
||||
if max_values:
|
||||
if (
|
||||
len(total_metrics) >= max_values
|
||||
and metric not in total_metrics
|
||||
):
|
||||
continue
|
||||
total_metrics.add(metric)
|
||||
|
||||
for path, value in last_scalar_values:
|
||||
if path[-1] == "min_value":
|
||||
extra_updates[op_path("min", *path[:-1], "min_value")] = value
|
||||
elif path[-1] == "max_value":
|
||||
extra_updates[op_path("max", *path[:-1], "max_value")] = value
|
||||
else:
|
||||
extra_updates[op_path("set", *path)] = value
|
||||
new_metrics.append(metric)
|
||||
path = f"last_metrics__{metric_key}__{variant_key}"
|
||||
for key, value in variant_data.items():
|
||||
if key == "min_value":
|
||||
extra_updates[f"min__{path}__min_value"] = value
|
||||
elif key == "max_value":
|
||||
extra_updates[f"max__{path}__max_value"] = value
|
||||
elif key in ("metric", "variant", "value"):
|
||||
extra_updates[f"set__{path}__{key}"] = value
|
||||
if new_metrics:
|
||||
extra_updates["add_to_set__unique_metrics"] = new_metrics
|
||||
|
||||
if last_events is not None:
|
||||
|
||||
@@ -494,178 +461,15 @@ class TaskBLL:
|
||||
**extra_updates,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
**project_constraint,
|
||||
}
|
||||
},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
dpath.get(r, "_id/section")
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
HyperParamValues = Tuple[int, Sequence[str]]
|
||||
|
||||
def _get_cached_hyperparam_values(
|
||||
self, key: str, last_update: datetime
|
||||
) -> Optional[HyperParamValues]:
|
||||
allowed_delta = timedelta(
|
||||
seconds=config.get(
|
||||
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
|
||||
)
|
||||
)
|
||||
try:
|
||||
cached = self.redis.get(key)
|
||||
if not cached:
|
||||
return
|
||||
|
||||
data = json.loads(cached)
|
||||
cached_last_update = datetime.fromtimestamp(data["last_update"])
|
||||
if (last_update - cached_last_update) < allowed_delta:
|
||||
return data["total"], data["values"]
|
||||
except Exception as ex:
|
||||
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
|
||||
|
||||
def get_hyperparam_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
section: str,
|
||||
name: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> HyperParamValues:
|
||||
if allow_public:
|
||||
company_constraint = {"company": {"$in": [None, "", company_id]}}
|
||||
else:
|
||||
company_constraint = {"company": company_id}
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
|
||||
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
|
||||
last_updated_task = (
|
||||
Task.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_task:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
|
||||
last_update = last_updated_task.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_hyperparam_values(
|
||||
key=redis_key, last_update=last_update
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$limit": max_values},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
|
||||
@classmethod
|
||||
def dequeue_and_change_status(
|
||||
cls, task: Task, company_id: str, status_message: str, status_reason: str,
|
||||
):
|
||||
cls.dequeue(task, company_id)
|
||||
try:
|
||||
cls.dequeue(task, company_id)
|
||||
except errors.bad_request.InvalidQueueOrTaskNotQueued:
|
||||
# dequeue may fail if the queue was deleted
|
||||
pass
|
||||
|
||||
return ChangeStatusRequest(
|
||||
task=task,
|
||||
|
||||
@@ -1,64 +1,31 @@
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
|
||||
from typing import Sequence, Set, Tuple
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import partition
|
||||
from mongoengine import QuerySet, Document
|
||||
from boltons.iterutils import partition, bucketize, first
|
||||
from mongoengine import NotUniqueError
|
||||
from pymongo.errors import DuplicateKeyError
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_bll import PlotFields
|
||||
from apiserver.bll.event.event_common import EventType
|
||||
from apiserver.bll.task.utils import deleted_prefix
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.database.model.url_to_delete import (
|
||||
StorageType,
|
||||
UrlToDelete,
|
||||
FileType,
|
||||
DeletionStatus,
|
||||
)
|
||||
from apiserver.database.utils import id as db_id
|
||||
|
||||
log = config.logger(__file__)
|
||||
event_bll = EventBLL()
|
||||
T = TypeVar("T", bound=Document)
|
||||
|
||||
|
||||
class DocumentGroup(List[T]):
|
||||
"""
|
||||
Operate on a list of documents as if they were a query result
|
||||
"""
|
||||
|
||||
def __init__(self, document_type: Type[T], documents: Iterable[T]):
|
||||
super(DocumentGroup, self).__init__(documents)
|
||||
self.type = document_type
|
||||
|
||||
@property
|
||||
def ids(self) -> Set[str]:
|
||||
return {obj.id for obj in self}
|
||||
|
||||
def objects(self, *args, **kwargs) -> QuerySet:
|
||||
return self.type.objects(id__in=self.ids, *args, **kwargs)
|
||||
|
||||
|
||||
class TaskOutputs(Generic[T]):
|
||||
"""
|
||||
Split task outputs of the same type by the ready state
|
||||
"""
|
||||
|
||||
published: DocumentGroup[T]
|
||||
draft: DocumentGroup[T]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_published: Callable[[T], bool],
|
||||
document_type: Type[T],
|
||||
children: Iterable[T],
|
||||
):
|
||||
"""
|
||||
:param is_published: predicate returning whether items is considered published
|
||||
:param document_type: type of output
|
||||
:param children: output documents
|
||||
"""
|
||||
self.published, self.draft = map(
|
||||
lambda x: DocumentGroup(document_type, x),
|
||||
partition(children, key=is_published),
|
||||
)
|
||||
async_events_delete = config.get("services.tasks.async_events_delete", False)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
@@ -101,64 +68,112 @@ class CleanupResult:
|
||||
)
|
||||
|
||||
|
||||
def collect_plot_image_urls(company: str, task: str) -> Set[str]:
|
||||
def collect_plot_image_urls(company: str, task_or_model: str) -> Set[str]:
|
||||
urls = set()
|
||||
next_scroll_id = None
|
||||
with TimingContext("es", "collect_plot_image_urls"):
|
||||
while True:
|
||||
events, next_scroll_id = event_bll.get_plot_image_urls(
|
||||
company_id=company, task_id=task, scroll_id=next_scroll_id
|
||||
)
|
||||
if not events:
|
||||
break
|
||||
for event in events:
|
||||
event_urls = event.get(PlotFields.source_urls)
|
||||
if event_urls:
|
||||
urls.update(set(event_urls))
|
||||
while True:
|
||||
events, next_scroll_id = event_bll.get_plot_image_urls(
|
||||
company_id=company, task_id=task_or_model, scroll_id=next_scroll_id
|
||||
)
|
||||
if not events:
|
||||
break
|
||||
for event in events:
|
||||
event_urls = event.get(PlotFields.source_urls)
|
||||
if event_urls:
|
||||
urls.update(set(event_urls))
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def collect_debug_image_urls(company: str, task: str) -> Set[str]:
|
||||
def collect_debug_image_urls(company: str, task_or_model: str) -> Set[str]:
|
||||
"""
|
||||
Return the set of unique image urls
|
||||
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
|
||||
"""
|
||||
metrics = event_bll.get_metrics_and_variants(
|
||||
company_id=company, task_id=task, event_type=EventType.metrics_image
|
||||
)
|
||||
if not metrics:
|
||||
return set()
|
||||
|
||||
task_metrics = {task: set(metrics)}
|
||||
scroll_id = None
|
||||
after_key = None
|
||||
urls = set()
|
||||
while True:
|
||||
res = event_bll.debug_images_iterator.get_task_events(
|
||||
company_id=company,
|
||||
task_metrics=task_metrics,
|
||||
iter_count=100,
|
||||
state_id=scroll_id,
|
||||
res, after_key = event_bll.get_debug_image_urls(
|
||||
company_id=company, task_id=task_or_model, after_key=after_key,
|
||||
)
|
||||
if not res.metric_events or not any(
|
||||
iterations for _, iterations in res.metric_events
|
||||
):
|
||||
urls.update(res)
|
||||
if not after_key:
|
||||
break
|
||||
|
||||
scroll_id = res.next_scroll_id
|
||||
for task, iterations in res.metric_events:
|
||||
urls.update(ev.get("url") for it in iterations for ev in it["events"])
|
||||
|
||||
urls.discard({None})
|
||||
return urls
|
||||
|
||||
|
||||
supported_storage_types = {
|
||||
"https://": StorageType.fileserver,
|
||||
"http://": StorageType.fileserver,
|
||||
}
|
||||
|
||||
|
||||
def _schedule_for_delete(
|
||||
company: str, user: str, task_id: str, urls: Set[str], can_delete_folders: bool,
|
||||
) -> Set[str]:
|
||||
urls_per_storage = bucketize(
|
||||
urls,
|
||||
key=lambda u: first(
|
||||
type_
|
||||
for prefix, type_ in supported_storage_types.items()
|
||||
if u.startswith(prefix)
|
||||
),
|
||||
)
|
||||
urls_per_storage.pop(None, None)
|
||||
|
||||
processed_urls = set()
|
||||
for storage_type, storage_urls in urls_per_storage.items():
|
||||
delete_folders = (storage_type == StorageType.fileserver) and can_delete_folders
|
||||
scheduled_to_delete = set()
|
||||
for url in storage_urls:
|
||||
folder = None
|
||||
if delete_folders:
|
||||
folder, _, _ = url.rpartition("/")
|
||||
|
||||
to_delete = folder or url
|
||||
if to_delete in scheduled_to_delete:
|
||||
processed_urls.add(url)
|
||||
continue
|
||||
|
||||
try:
|
||||
UrlToDelete(
|
||||
id=db_id(),
|
||||
company=company,
|
||||
user=user,
|
||||
url=to_delete,
|
||||
task=task_id,
|
||||
created=datetime.utcnow(),
|
||||
storage_type=storage_type,
|
||||
type=FileType.folder if folder else FileType.file,
|
||||
).save()
|
||||
except (DuplicateKeyError, NotUniqueError):
|
||||
existing = UrlToDelete.objects(company=company, url=to_delete).first()
|
||||
if existing:
|
||||
existing.update(
|
||||
user=user,
|
||||
task=task_id,
|
||||
created=datetime.utcnow(),
|
||||
retry_count=0,
|
||||
unset__last_failure_time=1,
|
||||
unset__last_failure_reason=1,
|
||||
status=DeletionStatus.created,
|
||||
)
|
||||
processed_urls.add(url)
|
||||
scheduled_to_delete.add(to_delete)
|
||||
|
||||
return processed_urls
|
||||
|
||||
|
||||
def cleanup_task(
|
||||
company: str,
|
||||
user: str,
|
||||
task: Task,
|
||||
force: bool = False,
|
||||
update_children=True,
|
||||
return_file_urls=False,
|
||||
delete_output_models=True,
|
||||
delete_external_artifacts=True,
|
||||
) -> CleanupResult:
|
||||
"""
|
||||
Validate task deletion and delete/modify all its output.
|
||||
@@ -166,10 +181,14 @@ def cleanup_task(
|
||||
:param force: whether to delete task with published outputs
|
||||
:return: count of delete and modified items
|
||||
"""
|
||||
models = verify_task_children_and_ouptuts(task, force)
|
||||
|
||||
published_models, draft_models, in_use_model_ids = verify_task_children_and_ouptuts(
|
||||
task, force
|
||||
)
|
||||
delete_external_artifacts = delete_external_artifacts and config.get(
|
||||
"services.async_urls_delete.enabled", False
|
||||
)
|
||||
event_urls, artifact_urls, model_urls = set(), set(), set()
|
||||
if return_file_urls:
|
||||
if return_file_urls or delete_external_artifacts:
|
||||
event_urls = collect_debug_image_urls(task.company, task.id)
|
||||
event_urls.update(collect_plot_image_urls(task.company, task.id))
|
||||
if task.execution and task.execution.artifacts:
|
||||
@@ -178,30 +197,63 @@ def cleanup_task(
|
||||
for a in task.execution.artifacts.values()
|
||||
if a.mode == ArtifactModes.output and a.uri
|
||||
}
|
||||
model_urls = {m.uri for m in models.draft.objects().only("uri") if m.uri}
|
||||
model_urls = {
|
||||
m.uri for m in draft_models if m.uri and m.id not in in_use_model_ids
|
||||
}
|
||||
|
||||
deleted_task_id = f"{deleted_prefix}{task.id}"
|
||||
updated_children = 0
|
||||
if update_children:
|
||||
with TimingContext("mongo", "update_task_children"):
|
||||
updated_children = Task.objects(parent=task.id).update(
|
||||
parent=deleted_task_id
|
||||
updated_children = Task.objects(parent=task.id).update(parent=deleted_task_id)
|
||||
|
||||
deleted_models = 0
|
||||
updated_models = 0
|
||||
for models, allow_delete in ((draft_models, True), (published_models, False)):
|
||||
if not models:
|
||||
continue
|
||||
if delete_output_models and allow_delete:
|
||||
model_ids = set(m.id for m in models if m.id not in in_use_model_ids)
|
||||
for m_id in model_ids:
|
||||
if return_file_urls or delete_external_artifacts:
|
||||
event_urls.update(collect_debug_image_urls(task.company, m_id))
|
||||
event_urls.update(collect_plot_image_urls(task.company, m_id))
|
||||
try:
|
||||
event_bll.delete_task_events(
|
||||
task.company,
|
||||
m_id,
|
||||
allow_locked=True,
|
||||
model=True,
|
||||
async_delete=async_events_delete,
|
||||
)
|
||||
except errors.bad_request.InvalidModelId as ex:
|
||||
log.info(f"Error deleting events for the model {m_id}: {str(ex)}")
|
||||
|
||||
deleted_models += Model.objects(id__in=list(model_ids)).delete()
|
||||
if in_use_model_ids:
|
||||
Model.objects(id__in=list(in_use_model_ids)).update(unset__task=1)
|
||||
continue
|
||||
|
||||
if update_children:
|
||||
updated_models += Model.objects(id__in=[m.id for m in models]).update(
|
||||
task=deleted_task_id
|
||||
)
|
||||
else:
|
||||
updated_children = 0
|
||||
else:
|
||||
Model.objects(id__in=[m.id for m in models]).update(unset__task=1)
|
||||
|
||||
if models.draft and delete_output_models:
|
||||
with TimingContext("mongo", "delete_models"):
|
||||
deleted_models = models.draft.objects().delete()
|
||||
else:
|
||||
deleted_models = 0
|
||||
event_bll.delete_task_events(
|
||||
task.company, task.id, allow_locked=force, async_delete=async_events_delete
|
||||
)
|
||||
|
||||
if models.published and update_children:
|
||||
with TimingContext("mongo", "update_task_models"):
|
||||
updated_models = models.published.objects().update(task=deleted_task_id)
|
||||
else:
|
||||
updated_models = 0
|
||||
|
||||
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
|
||||
if delete_external_artifacts:
|
||||
scheduled = _schedule_for_delete(
|
||||
task_id=task.id,
|
||||
company=company,
|
||||
user=user,
|
||||
urls=event_urls | model_urls | artifact_urls,
|
||||
can_delete_folders=not in_use_model_ids and not published_models,
|
||||
)
|
||||
for urls in (event_urls, model_urls, artifact_urls):
|
||||
urls.difference_update(scheduled)
|
||||
|
||||
return CleanupResult(
|
||||
deleted_models=deleted_models,
|
||||
@@ -217,62 +269,56 @@ def cleanup_task(
|
||||
)
|
||||
|
||||
|
||||
def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Model]:
|
||||
def verify_task_children_and_ouptuts(
|
||||
task, force: bool
|
||||
) -> Tuple[Sequence[Model], Sequence[Model], Set[str]]:
|
||||
if not force:
|
||||
with TimingContext("mongo", "count_published_children"):
|
||||
published_children_count = Task.objects(
|
||||
parent=task.id, status=TaskStatus.published
|
||||
).count()
|
||||
if published_children_count:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has children, use force=True",
|
||||
task=task.id,
|
||||
children=published_children_count,
|
||||
)
|
||||
|
||||
with TimingContext("mongo", "get_task_models"):
|
||||
models = TaskOutputs(
|
||||
attrgetter("ready"),
|
||||
Model,
|
||||
Model.objects(task=task.id).only("id", "task", "ready"),
|
||||
)
|
||||
if not force and models.published:
|
||||
published_children_count = Task.objects(
|
||||
parent=task.id, status=TaskStatus.published
|
||||
).count()
|
||||
if published_children_count:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output models, use force=True",
|
||||
"has children, use force=True",
|
||||
task=task.id,
|
||||
models=len(models.published),
|
||||
children=published_children_count,
|
||||
)
|
||||
|
||||
model_fields = ["id", "ready", "uri"]
|
||||
published_models, draft_models = partition(
|
||||
Model.objects(task=task.id).only(*model_fields), key=attrgetter("ready"),
|
||||
)
|
||||
if not force and published_models:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output models, use force=True",
|
||||
task=task.id,
|
||||
models=len(published_models),
|
||||
)
|
||||
|
||||
if task.models and task.models.output:
|
||||
with TimingContext("mongo", "get_task_output_model"):
|
||||
model_ids = [m.model for m in task.models.output]
|
||||
for output_model in Model.objects(id__in=model_ids):
|
||||
if output_model.ready:
|
||||
if not force:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output model, use force=True",
|
||||
task=task.id,
|
||||
model=output_model.id,
|
||||
)
|
||||
models.published.append(output_model)
|
||||
else:
|
||||
models.draft.append(output_model)
|
||||
model_ids = [m.model for m in task.models.output]
|
||||
for output_model in Model.objects(id__in=model_ids).only(*model_fields):
|
||||
if output_model.ready:
|
||||
if not force:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output model, use force=True",
|
||||
task=task.id,
|
||||
model=output_model.id,
|
||||
)
|
||||
published_models.append(output_model)
|
||||
else:
|
||||
draft_models.append(output_model)
|
||||
|
||||
if models.draft:
|
||||
with TimingContext("mongo", "get_execution_models"):
|
||||
model_ids = models.draft.ids
|
||||
dependent_tasks = Task.objects(models__input__model__in=model_ids).only(
|
||||
"id", "models"
|
||||
in_use_model_ids = {}
|
||||
if draft_models:
|
||||
model_ids = {m.id for m in draft_models}
|
||||
dependent_tasks = Task.objects(models__input__model__in=list(model_ids)).only(
|
||||
"id", "models"
|
||||
)
|
||||
in_use_model_ids = model_ids & {
|
||||
m.model
|
||||
for m in chain.from_iterable(
|
||||
t.models.input for t in dependent_tasks if t.models
|
||||
)
|
||||
input_models = {
|
||||
m.model
|
||||
for m in chain.from_iterable(
|
||||
t.models.input for t in dependent_tasks if t.models
|
||||
)
|
||||
}
|
||||
if input_models:
|
||||
models.draft = DocumentGroup(
|
||||
Model, (m for m in models.draft if m.id not in input_models)
|
||||
)
|
||||
}
|
||||
|
||||
return models
|
||||
return published_models, draft_models, in_use_model_ids
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Callable, Any, Tuple, Union
|
||||
from typing import Callable, Any, Tuple, Union, Sequence
|
||||
|
||||
from apiserver.apierrors import errors, APIError
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
@@ -25,6 +25,7 @@ from apiserver.database.model.task.task import (
|
||||
)
|
||||
from apiserver.utilities.dicts import nested_set
|
||||
|
||||
log = config.logger(__file__)
|
||||
queue_bll = QueueBLL()
|
||||
|
||||
|
||||
@@ -83,10 +84,7 @@ def unarchive_task(
|
||||
|
||||
|
||||
def dequeue_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
task_id: str, company_id: str, status_message: str, status_reason: str,
|
||||
) -> Tuple[int, dict]:
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(**query)
|
||||
@@ -94,10 +92,7 @@ def dequeue_task(
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
res = TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
company_id,
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
task, company_id, status_message=status_message, status_reason=status_reason,
|
||||
)
|
||||
return 1, res
|
||||
|
||||
@@ -108,8 +103,23 @@ def enqueue_task(
|
||||
queue_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
queue_name: str = None,
|
||||
validate: bool = False,
|
||||
force: bool = False,
|
||||
) -> Tuple[int, dict]:
|
||||
if queue_id and queue_name:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Either queue id or queue name should be provided"
|
||||
)
|
||||
|
||||
if queue_name:
|
||||
queue = queue_bll.get_by_name(
|
||||
company_id=company_id, queue_name=queue_name, only=("id",)
|
||||
)
|
||||
if not queue:
|
||||
queue = queue_bll.create(company_id=company_id, name=queue_name)
|
||||
queue_id = queue.id
|
||||
|
||||
if not queue_id:
|
||||
# try to get default queue
|
||||
queue_id = queue_bll.get_default(company_id).id
|
||||
@@ -128,6 +138,7 @@ def enqueue_task(
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
allow_same_state_transition=False,
|
||||
force=force,
|
||||
).execute(enqueue_status=task.status)
|
||||
|
||||
try:
|
||||
@@ -153,13 +164,41 @@ def enqueue_task(
|
||||
return 1, res
|
||||
|
||||
|
||||
def move_tasks_to_trash(tasks: Sequence[str]) -> int:
|
||||
try:
|
||||
collection_name = Task._get_collection_name()
|
||||
trash_collection_name = f"{collection_name}__trash"
|
||||
Task.aggregate(
|
||||
[
|
||||
{"$match": {"_id": {"$in": tasks}}},
|
||||
{
|
||||
"$merge": {
|
||||
"into": trash_collection_name,
|
||||
"on": "_id",
|
||||
"whenMatched": "replace",
|
||||
"whenNotMatched": "insert",
|
||||
}
|
||||
},
|
||||
],
|
||||
allow_disk_use=True,
|
||||
)
|
||||
except Exception as ex:
|
||||
log.error(f"Error copying tasks to trash {str(ex)}")
|
||||
|
||||
return Task.objects(id__in=tasks).delete()
|
||||
|
||||
|
||||
def delete_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
move_to_trash: bool,
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
delete_output_models: bool,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[int, Task, CleanupResult]:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
@@ -177,26 +216,34 @@ def delete_task(
|
||||
current=task.status,
|
||||
)
|
||||
|
||||
try:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
company_id=company_id,
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
|
||||
cleanup_res = cleanup_task(
|
||||
task,
|
||||
company=company_id,
|
||||
user=user_id,
|
||||
task=task,
|
||||
force=force,
|
||||
return_file_urls=return_file_urls,
|
||||
delete_output_models=delete_output_models,
|
||||
delete_external_artifacts=delete_external_artifacts,
|
||||
)
|
||||
|
||||
if move_to_trash:
|
||||
collection_name = task._get_collection_name()
|
||||
archived_collection = "{}__trash".format(collection_name)
|
||||
task.switch_collection(archived_collection)
|
||||
try:
|
||||
# A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force
|
||||
# an insert. However, if for some reason such an ID exists, let's make sure we'll keep going.
|
||||
task.save(force_insert=True)
|
||||
except Exception:
|
||||
pass
|
||||
task.switch_collection(collection_name)
|
||||
# make sure that whatever changes were done to the task are saved
|
||||
# the task itself will be deleted later in the move_tasks_to_trash operation
|
||||
task.save()
|
||||
else:
|
||||
task.delete()
|
||||
|
||||
task.delete()
|
||||
update_project_time(task.project)
|
||||
return 1, task, cleanup_res
|
||||
|
||||
@@ -204,10 +251,12 @@ def delete_task(
|
||||
def reset_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
delete_output_models: bool,
|
||||
clear_all: bool,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[dict, CleanupResult, dict]:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
@@ -226,18 +275,23 @@ def reset_task(
|
||||
pass
|
||||
|
||||
cleaned_up = cleanup_task(
|
||||
task,
|
||||
company=company_id,
|
||||
user=user_id,
|
||||
task=task,
|
||||
force=force,
|
||||
update_children=False,
|
||||
return_file_urls=return_file_urls,
|
||||
delete_output_models=delete_output_models,
|
||||
delete_external_artifacts=delete_external_artifacts,
|
||||
)
|
||||
|
||||
updates.update(
|
||||
set__last_iteration=DEFAULT_LAST_ITERATION,
|
||||
set__last_metrics={},
|
||||
set__unique_metrics=[],
|
||||
set__metric_stats={},
|
||||
set__models__output=[],
|
||||
set__runtime={},
|
||||
unset__output__result=1,
|
||||
unset__output__error=1,
|
||||
unset__last_worker=1,
|
||||
@@ -351,6 +405,7 @@ def stop_task(
|
||||
"system_tags",
|
||||
"last_worker",
|
||||
"last_update",
|
||||
"execution.queue",
|
||||
),
|
||||
requires_write_access=True,
|
||||
)
|
||||
@@ -364,7 +419,21 @@ def stop_task(
|
||||
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
|
||||
)
|
||||
|
||||
if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
|
||||
is_queued = task.status == TaskStatus.queued
|
||||
set_stopped = (
|
||||
is_queued
|
||||
or TaskSystemTags.development in task.system_tags
|
||||
or not is_run_by_worker(task)
|
||||
)
|
||||
|
||||
if set_stopped:
|
||||
if is_queued:
|
||||
try:
|
||||
TaskBLL.dequeue(task, company_id=company_id, silent_fail=True)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
|
||||
new_status = TaskStatus.stopped
|
||||
status_message = f"Stopped by {user_name}"
|
||||
else:
|
||||
|
||||
@@ -9,7 +9,6 @@ from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.attrs import typed_attrs
|
||||
|
||||
valid_statuses = get_options(TaskStatus)
|
||||
@@ -55,7 +54,7 @@ class ChangeStatusRequest(object):
|
||||
|
||||
fields.update({safe_mongoengine_key(k): v for k, v in kwargs.items()})
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "task_status"):
|
||||
with translate_errors_context():
|
||||
# atomic change of task status by querying the task with the EXPECTED status before modifying it
|
||||
params = fields.copy()
|
||||
params.update(control)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from datetime import datetime
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.users import CreateRequest
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
@@ -12,7 +14,7 @@ class UserBLL:
|
||||
if user_id and User.objects(id=user_id).only("id"):
|
||||
raise errors.bad_request.UserIdExists(id=user_id)
|
||||
|
||||
user = User(**request.to_struct())
|
||||
user = User(**request.to_struct(), created=datetime.utcnow())
|
||||
user.save(force_insert=True)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import itertools
|
||||
from datetime import datetime, timedelta
|
||||
from time import time
|
||||
from typing import Sequence, Set, Optional
|
||||
|
||||
import attr
|
||||
import elasticsearch.helpers
|
||||
from boltons.iterutils import partition
|
||||
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import APIError
|
||||
@@ -25,7 +27,6 @@ from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
from .stats import WorkerStats
|
||||
|
||||
@@ -51,6 +52,7 @@ class WorkerBLL:
|
||||
queues: Sequence[str] = None,
|
||||
timeout: int = 0,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> WorkerEntry:
|
||||
"""
|
||||
Register a worker
|
||||
@@ -76,7 +78,7 @@ class WorkerBLL:
|
||||
raise bad_request.InvalidUserId(**query)
|
||||
company = Company.objects(id=company_id).only("id", "name").first()
|
||||
if not company:
|
||||
raise server_error.InternalError("invalid company", company=company_id)
|
||||
raise bad_request.InvalidId("invalid company", company=company_id)
|
||||
|
||||
queue_objs = Queue.objects(company=company_id, id__in=queues).only("id")
|
||||
if len(queue_objs) < len(queues):
|
||||
@@ -95,9 +97,10 @@ class WorkerBLL:
|
||||
register_timeout=timeout,
|
||||
last_activity_time=now,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
)
|
||||
|
||||
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
|
||||
self._save_worker_data(entry)
|
||||
|
||||
return entry
|
||||
|
||||
@@ -109,15 +112,20 @@ class WorkerBLL:
|
||||
:param worker: worker ID
|
||||
:raise bad_request.WorkerNotRegistered: the worker was not previously registered
|
||||
"""
|
||||
with TimingContext("redis", "workers_unregister"):
|
||||
res = self.redis.delete(
|
||||
company_id, self._get_worker_key(company_id, user_id, worker)
|
||||
)
|
||||
if not res:
|
||||
res = self.redis.delete(
|
||||
company_id, self._get_worker_key(company_id, user_id, worker)
|
||||
)
|
||||
if not res and not config.get("apiserver.workers.auto_unregister", False):
|
||||
raise bad_request.WorkerNotRegistered(worker=worker)
|
||||
|
||||
def status_report(
|
||||
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
|
||||
self,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
ip: str,
|
||||
report: StatusReportRequest,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Write worker status report
|
||||
@@ -138,12 +146,14 @@ class WorkerBLL:
|
||||
|
||||
if tags is not None:
|
||||
entry.tags = tags
|
||||
if system_tags is not None:
|
||||
entry.system_tags = system_tags
|
||||
|
||||
if report.machine_stats:
|
||||
self._log_stats_to_es(
|
||||
company_id=company_id,
|
||||
company_name=entry.company.name,
|
||||
worker=report.worker,
|
||||
worker=entry.key,
|
||||
timestamp=report.timestamp,
|
||||
task=report.task,
|
||||
machine_stats=report.machine_stats,
|
||||
@@ -176,7 +186,9 @@ class WorkerBLL:
|
||||
if task.project:
|
||||
project = Project.objects(id=task.project).only("name").first()
|
||||
if project:
|
||||
entry.project = IdNameEntry(id=project.id, name=project.name)
|
||||
entry.project = IdNameEntry(
|
||||
id=project.id, name=project.name
|
||||
)
|
||||
|
||||
entry.last_report_time = now
|
||||
except APIError:
|
||||
@@ -189,7 +201,11 @@ class WorkerBLL:
|
||||
self._save_worker(entry)
|
||||
|
||||
def get_all(
|
||||
self, company_id: str, last_seen: Optional[int] = None
|
||||
self,
|
||||
company_id: str,
|
||||
last_seen: Optional[int] = None,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerEntry]:
|
||||
"""
|
||||
Get all the company workers that were active during the last_seen period
|
||||
@@ -198,7 +214,7 @@ class WorkerBLL:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
workers = self._get(company_id)
|
||||
workers = self._get(company_id, user_tags=tags, system_tags=system_tags)
|
||||
except Exception as e:
|
||||
raise server_error.DataError("failed loading worker entries", err=e.args[0])
|
||||
|
||||
@@ -213,13 +229,22 @@ class WorkerBLL:
|
||||
return workers
|
||||
|
||||
def get_all_with_projection(
|
||||
self, company_id: str, last_seen: int
|
||||
self,
|
||||
company_id: str,
|
||||
last_seen: int,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerResponseEntry]:
|
||||
|
||||
helpers = list(
|
||||
map(
|
||||
WorkerConversionHelper.from_worker_entry,
|
||||
self.get_all(company_id=company_id, last_seen=last_seen),
|
||||
self.get_all(
|
||||
company_id=company_id,
|
||||
last_seen=last_seen,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -258,7 +283,7 @@ class WorkerBLL:
|
||||
tasks_info = {
|
||||
task.id: task
|
||||
for task in Task.objects(id__in=task_ids).only(
|
||||
"name", "started", "last_iteration"
|
||||
"name", "started", "last_iteration", "active_duration"
|
||||
)
|
||||
}
|
||||
|
||||
@@ -283,11 +308,7 @@ class WorkerBLL:
|
||||
if helper.task_id:
|
||||
task = tasks_info.get(helper.task_id, None)
|
||||
if task:
|
||||
worker.task.running_time = (
|
||||
int((datetime.utcnow() - task.started).total_seconds() * 1000)
|
||||
if task.started
|
||||
else 0
|
||||
)
|
||||
worker.task.running_time = (task.active_duration or 0) * 1000
|
||||
worker.task.last_iteration = task.last_iteration
|
||||
|
||||
update_queue_entries(worker.queue)
|
||||
@@ -314,8 +335,7 @@ class WorkerBLL:
|
||||
"""
|
||||
key = self._get_worker_key(company_id, user_id, worker)
|
||||
|
||||
with TimingContext("redis", "get_worker"):
|
||||
data = self.redis.get(key)
|
||||
data = self.redis.get(key)
|
||||
|
||||
if data:
|
||||
try:
|
||||
@@ -342,24 +362,119 @@ class WorkerBLL:
|
||||
|
||||
raise bad_request.InvalidWorkerId(worker=worker)
|
||||
|
||||
@staticmethod
|
||||
def _get_tagged_workers_key(company: str, tags_field: str, tag: str) -> str:
|
||||
"""Build redis key from company, user and worker_id"""
|
||||
return f"workers.{tags_field}_{company}_{tag}"
|
||||
|
||||
@staticmethod
|
||||
def _get_all_workers_key(company: str) -> str:
|
||||
"""Build redis key from company, user and worker_id"""
|
||||
return f"workers_{company}"
|
||||
|
||||
def _save_worker_data(self, entry: WorkerEntry):
|
||||
self.redis.setex(
|
||||
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
|
||||
)
|
||||
company_id = entry.company.id
|
||||
expiration = int(time()) + entry.register_timeout
|
||||
worker_item = {entry.key: expiration}
|
||||
self.redis.zadd(self._get_all_workers_key(company_id), worker_item)
|
||||
for tags, tags_field in (
|
||||
(entry.tags, "tags"),
|
||||
(entry.system_tags, "systemtags"),
|
||||
):
|
||||
for tag in tags:
|
||||
name = self._get_tagged_workers_key(company_id, tags_field, tag)
|
||||
self.redis.zadd(name, worker_item)
|
||||
|
||||
def _save_worker(self, entry: WorkerEntry) -> None:
|
||||
"""Save worker entry in Redis"""
|
||||
try:
|
||||
self.redis.setex(
|
||||
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
|
||||
)
|
||||
self._save_worker_data(entry)
|
||||
except Exception:
|
||||
msg = "Failed saving worker entry"
|
||||
log.exception(msg)
|
||||
|
||||
def _get(
|
||||
self, company: str, user: str = "*", worker_id: str = "*"
|
||||
self,
|
||||
company: str,
|
||||
user: str = "*",
|
||||
worker_id: str = "*",
|
||||
user_tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerEntry]:
|
||||
"""Get worker entries matching the company and user, worker patterns"""
|
||||
match = self._get_worker_key(company, user, worker_id)
|
||||
with TimingContext("redis", "workers_get_all"):
|
||||
res = self.redis.scan_iter(match)
|
||||
return [WorkerEntry.from_json(self.redis.get(r)) for r in res]
|
||||
|
||||
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
|
||||
if user == "*":
|
||||
return in_keys
|
||||
user_bytes = user.encode()
|
||||
return {k for k in in_keys if user_bytes in k}
|
||||
|
||||
if user_tags or system_tags:
|
||||
worker_keys = set()
|
||||
for tags, tags_field in (
|
||||
(user_tags, "tags"),
|
||||
(system_tags, "systemtags"),
|
||||
):
|
||||
if not tags:
|
||||
continue
|
||||
timestamp = int(time())
|
||||
include, exclude = partition(tags, key=lambda x: x[0] != "-")
|
||||
if include:
|
||||
tagged_workers = set()
|
||||
for tag in include:
|
||||
tagged_workers_key = self._get_tagged_workers_key(
|
||||
company, tags_field, tag
|
||||
)
|
||||
self.redis.zremrangebyscore(
|
||||
tagged_workers_key, min=0, max=timestamp
|
||||
)
|
||||
tagged_workers.update(
|
||||
self.redis.zrange(tagged_workers_key, 0, -1)
|
||||
)
|
||||
tagged_workers = filter_by_user(tagged_workers)
|
||||
worker_keys = (
|
||||
worker_keys.intersection(tagged_workers)
|
||||
if worker_keys
|
||||
else tagged_workers
|
||||
)
|
||||
if not worker_keys:
|
||||
return []
|
||||
if exclude:
|
||||
if not worker_keys:
|
||||
all_workers_key = self._get_all_workers_key(company)
|
||||
self.redis.zremrangebyscore(
|
||||
all_workers_key, min=0, max=timestamp
|
||||
)
|
||||
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
|
||||
worker_keys = filter_by_user(worker_keys)
|
||||
if not worker_keys:
|
||||
return []
|
||||
for tag in exclude:
|
||||
tagged_workers_key = self._get_tagged_workers_key(
|
||||
company, tags_field, tag[1:]
|
||||
)
|
||||
self.redis.zremrangebyscore(
|
||||
tagged_workers_key, min=0, max=timestamp
|
||||
)
|
||||
worker_keys.difference_update(
|
||||
self.redis.zrange(tagged_workers_key, 0, -1)
|
||||
)
|
||||
if not worker_keys:
|
||||
return []
|
||||
else:
|
||||
match = self._get_worker_key(company, user, "*")
|
||||
worker_keys = self.redis.scan_iter(match)
|
||||
|
||||
entries = []
|
||||
for key in worker_keys:
|
||||
data = self.redis.get(key)
|
||||
if data:
|
||||
entries.append(WorkerEntry.from_json(data))
|
||||
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
def _get_es_index_suffix():
|
||||
|
||||
@@ -8,7 +8,6 @@ from apiserver.apimodels.workers import AggregationType, GetStatsRequest, StatIt
|
||||
from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -20,7 +19,7 @@ class WorkerStats:
|
||||
@staticmethod
|
||||
def worker_stats_prefix_for_company(company_id: str) -> str:
|
||||
"""Returns the es index prefix for the company"""
|
||||
return f"worker_stats_{company_id}_"
|
||||
return f"worker_stats_{company_id.lower()}_"
|
||||
|
||||
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
@@ -126,7 +125,7 @@ class WorkerStats:
|
||||
query_terms.append(QueryBuilder.terms("worker", request.worker_ids))
|
||||
es_req["query"] = {"bool": {"must": query_terms}}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_worker_stats"):
|
||||
with translate_errors_context():
|
||||
data = self._search_company_stats(company_id, es_req)
|
||||
|
||||
return self._extract_results(data, request.items, request.split_by_variant)
|
||||
@@ -223,9 +222,7 @@ class WorkerStats:
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_worker_activity_report"
|
||||
):
|
||||
with translate_errors_context():
|
||||
data = self._search_company_stats(company_id, es_req)
|
||||
|
||||
if "aggregations" not in data:
|
||||
|
||||
@@ -79,6 +79,8 @@ class BasicConfig:
|
||||
def logger(self, name: str) -> logging.Logger:
|
||||
if Path(name).is_file():
|
||||
name = Path(name).stem
|
||||
if name == "__init__" and Path(name).parent.stem:
|
||||
name = Path(name).parent.stem
|
||||
path = ".".join((self.prefix, name))
|
||||
return logging.getLogger(path)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
debug: false # Debug mode
|
||||
pretty_json: false # prettify json response
|
||||
return_stack: true # return stack trace on error
|
||||
log_calls: true # Log API Calls
|
||||
return_stack_to_caller: true # top-level control on whether to return stack trace in an API response
|
||||
|
||||
# if 'return_stack' is true and error contains a status code, return stack trace only for these status codes
|
||||
# valid values are:
|
||||
@@ -79,6 +79,11 @@
|
||||
max_age: 99999999999
|
||||
}
|
||||
|
||||
# provide a cookie domain override per company
|
||||
# cookies_domain_override {
|
||||
# <company-id>: <domain>
|
||||
# }
|
||||
|
||||
# # A list of fixed users
|
||||
# # Note: password may be bcrypt-hashed (generate using `python -c 'import bcrypt; print(bcrypt.hashpw("password", bcrypt.gensalt()))'`)
|
||||
# fixed_users {
|
||||
@@ -107,6 +112,8 @@
|
||||
workers {
|
||||
# Auto-register unknown workers on status reports and other calls
|
||||
auto_register: true
|
||||
# Assume unknow workers have unregistered (i.e. do not raise unregistered error)
|
||||
auto_unregister: true
|
||||
# Timeout in seconds on task status update. If exceeded
|
||||
# then task can be stopped without communicating to the worker
|
||||
task_update_timeout: 600
|
||||
@@ -139,4 +146,11 @@
|
||||
max_backoff_sec: 5
|
||||
}
|
||||
|
||||
getting_started_info {
|
||||
"agentName": "clearml",
|
||||
"configure": "clearml-init",
|
||||
"install": "pip install clearml",
|
||||
"packageName": "clearml"
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
fileserver = "http://localhost:8081"
|
||||
|
||||
elastic {
|
||||
events {
|
||||
hosts: [{host: "127.0.0.1", port: 9200}]
|
||||
|
||||
4
apiserver/config/default/services/_mongo.conf
Normal file
4
apiserver/config/default/services/_mongo.conf
Normal file
@@ -0,0 +1,4 @@
|
||||
max_page_size: 500
|
||||
|
||||
# expiration time in seconds for the redis scroll states in get_many family of apis
|
||||
scroll_state_expiration_seconds: 600
|
||||
12
apiserver/config/default/services/async_urls_delete.conf
Normal file
12
apiserver/config/default/services/async_urls_delete.conf
Normal file
@@ -0,0 +1,12 @@
|
||||
# if set to True then on task delete/reset external file urls for know storage types are scheduled for async delete
|
||||
# otherwise they are returned to a client for the client side delete
|
||||
enabled: false
|
||||
max_retries: 3
|
||||
retry_timeout_sec: 60
|
||||
|
||||
fileserver {
|
||||
# fileserver url prefixes. Evaluated in the order of priority
|
||||
# Can be in the form <schema>://host:port/path or /path
|
||||
url_prefixes: ["https://files.community-master.hosted.allegro.ai/"]
|
||||
timeout_sec: 300
|
||||
}
|
||||
@@ -12,11 +12,26 @@ events_retrieval {
|
||||
# should not exceed the amount of concurrent connections set in the ES driver
|
||||
max_metrics_concurrency: 4
|
||||
|
||||
# If set then max_metrics_count and max_variants_count are calculated dynamically on user data
|
||||
dynamic_metrics_count: true
|
||||
|
||||
# The percentage from the ES aggs limit (10000) to use for the max_metrics and max_variants calculation
|
||||
dynamic_metrics_count_threshold: 80
|
||||
|
||||
# the max amount of metrics to aggregate on
|
||||
max_metrics_count: 100
|
||||
|
||||
# the max amount of variants to aggregate on
|
||||
max_variants_count: 100
|
||||
|
||||
debug_images {
|
||||
# Allow to return the debug images for the variants with uninitialized valid iterations border
|
||||
allow_uninitialized_variants: true
|
||||
}
|
||||
|
||||
max_raw_scalars_size: 200000
|
||||
|
||||
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
|
||||
}
|
||||
|
||||
# if set then plot str will be checked for the valid json on plot add
|
||||
@@ -24,4 +39,7 @@ events_retrieval {
|
||||
validate_plot_str: false
|
||||
|
||||
# If not 0 then the plots equal or greater to the size will be stored compressed in the DB
|
||||
plot_compression_threshold: 100000
|
||||
plot_compression_threshold: 100000
|
||||
|
||||
# async events delete threshold
|
||||
max_async_deleted_events_per_sec: 1000
|
||||
7
apiserver/config/default/services/models.conf
Normal file
7
apiserver/config/default/services/models.conf
Normal file
@@ -0,0 +1,7 @@
|
||||
metadata_values {
|
||||
# maximal amount of distinct model values to retrieve
|
||||
max_count: 100
|
||||
|
||||
# cache ttl sec
|
||||
cache_ttl_sec: 86400
|
||||
}
|
||||
8
apiserver/config/default/services/queues.conf
Normal file
8
apiserver/config/default/services/queues.conf
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
metrics_before_from_date: 3600
|
||||
# interval in seconds to update queue metrics. Put 0 to disable
|
||||
metrics_refresh_interval_sec: 300
|
||||
# the queues with these tags will not be returned from get_all/get_all_ex unless id or name specified
|
||||
# or search_hidden is set
|
||||
hidden_tags: [k8s-glue]
|
||||
}
|
||||
@@ -19,4 +19,11 @@ hyperparam_values {
|
||||
|
||||
# cache ttl sec
|
||||
cache_ttl_sec: 86400
|
||||
}
|
||||
}
|
||||
|
||||
# the maximum amount of unique last metrics/variants combinations
|
||||
# for which the last values are stored in a task
|
||||
max_last_metrics: 2000
|
||||
|
||||
# if set then call to tasks.delete/cleanup does not wait for ES events deletion
|
||||
async_events_delete: false
|
||||
|
||||
@@ -28,6 +28,11 @@ OVERRIDE_PORT_ENV_KEY = (
|
||||
"MONGODB_SERVICE_PORT",
|
||||
)
|
||||
|
||||
OVERRIDE_CONNECTION_STRING_ENV_KEY = "CLEARML_MONGODB_SERVICE_CONNECTION_STRING"
|
||||
OVERRIDE_USERNAME_ENV_KEY = "CLEARML_MONGODB_SERVICE_USERNAME"
|
||||
OVERRIDE_PASSWORD_ENV_KEY = "CLEARML_MONGODB_SERVICE_PASSWORD"
|
||||
OVERRIDE_QUERY_ENV_KEY = "CLEARML_MONGODB_SERVICE_QUERY"
|
||||
|
||||
|
||||
class DatabaseEntry(models.Base):
|
||||
host = StringField(required=True)
|
||||
@@ -47,13 +52,26 @@ class DatabaseFactory:
|
||||
missing = []
|
||||
log.info("Initializing database connections")
|
||||
|
||||
override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY)
|
||||
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
|
||||
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
override_username = getenv(OVERRIDE_USERNAME_ENV_KEY)
|
||||
override_password = getenv(OVERRIDE_PASSWORD_ENV_KEY)
|
||||
override_query = getenv(OVERRIDE_QUERY_ENV_KEY)
|
||||
|
||||
if override_connection_string:
|
||||
log.info(f"Using override mongodb connection string template {override_connection_string}")
|
||||
else:
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
if override_username:
|
||||
log.info(f"Using override mongodb username {override_username}")
|
||||
if override_password:
|
||||
log.info(f"Using override mongodb password ******")
|
||||
if override_query:
|
||||
log.info(f"Using override mongodb query {override_query}")
|
||||
|
||||
for key, alias in get_items(Database).items():
|
||||
if key not in db_entries:
|
||||
@@ -62,11 +80,21 @@ class DatabaseFactory:
|
||||
|
||||
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
|
||||
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
if override_connection_string:
|
||||
con_str = f"{override_connection_string.rstrip('/')}/{key}"
|
||||
log.info(f"Using override mongodb connection string for {alias}: {con_str}")
|
||||
entry.host = con_str
|
||||
else:
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
if override_username:
|
||||
entry.host = furl(entry.host).set(username=override_username).url
|
||||
if override_password:
|
||||
entry.host = furl(entry.host).set(password=override_password).url
|
||||
if override_query:
|
||||
entry.host = furl(entry.host).set(query=override_query).url
|
||||
|
||||
try:
|
||||
entry.validate()
|
||||
|
||||
@@ -166,7 +166,10 @@ class MongoEngineErrorsHandler(object):
|
||||
@classmethod
|
||||
@throws_default_error(errors.server_error.InternalError)
|
||||
def invalid_query_error(cls, e, message, **_):
|
||||
pass
|
||||
if e.args:
|
||||
inner = e.args[0]
|
||||
if isinstance(inner, LookUpError):
|
||||
cls.lookup_error(inner, message)
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
@@ -176,6 +176,13 @@ class SafeMapField(MapField, DictValidationMixin):
|
||||
self.error("Empty keys are not allowed in a MapField")
|
||||
|
||||
|
||||
class NullableStringField(StringField):
|
||||
def validate(self, value):
|
||||
if value is None:
|
||||
return
|
||||
super(NullableStringField, self).validate(value)
|
||||
|
||||
|
||||
class SafeDictField(DictField, DictValidationMixin):
|
||||
def validate(self, value):
|
||||
self._safe_validate(value)
|
||||
|
||||
@@ -60,3 +60,4 @@ def validate_id(cls, company, **kwargs):
|
||||
class EntityVisibility(Enum):
|
||||
active = "active"
|
||||
archived = "archived"
|
||||
hidden = "hidden"
|
||||
|
||||
@@ -48,7 +48,9 @@ class Credentials(EmbeddedDocument):
|
||||
meta = {"strict": False}
|
||||
key = StringField(required=True)
|
||||
secret = StringField(required=True)
|
||||
label = StringField()
|
||||
last_used = DateTimeField()
|
||||
last_used_from = StringField()
|
||||
|
||||
|
||||
class User(DbModelMixin, AuthDocument):
|
||||
|
||||
@@ -1,26 +1,42 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any
|
||||
from functools import reduce, partial
|
||||
from typing import (
|
||||
Collection,
|
||||
Sequence,
|
||||
Union,
|
||||
Optional,
|
||||
Type,
|
||||
Tuple,
|
||||
Mapping,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
)
|
||||
|
||||
from boltons.iterutils import first, bucketize, partition
|
||||
from boltons.iterutils import first, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField
|
||||
from mongoengine import Q, Document, ListField, StringField, IntField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database import Database
|
||||
from apiserver.database.errors import MakeGetAllQueryError
|
||||
from apiserver.database.projection import project_dict, ProjectionHelper
|
||||
from apiserver.database.projection import ProjectionHelper
|
||||
from apiserver.database.props import PropsMixin
|
||||
from apiserver.database.query import RegexQ, RegexWrapper
|
||||
from apiserver.database.query import RegexQ, RegexWrapper, RegexQCombination
|
||||
from apiserver.database.utils import (
|
||||
get_company_or_none_constraint,
|
||||
get_fields_choices,
|
||||
field_does_not_exist,
|
||||
field_exists,
|
||||
)
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.dicts import project_dict, exclude_fields_from_dict
|
||||
|
||||
log = config.logger("dbmodel")
|
||||
|
||||
@@ -40,17 +56,25 @@ class ProperDictMixin(object):
|
||||
strip_private=True,
|
||||
only=None,
|
||||
extra_dict=None,
|
||||
exclude=None,
|
||||
) -> dict:
|
||||
return self.properize_dict(
|
||||
self.to_mongo(use_db_field=False).to_dict(),
|
||||
strip_private=strip_private,
|
||||
only=only,
|
||||
extra_dict=extra_dict,
|
||||
exclude=exclude,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def properize_dict(
|
||||
cls, d, strip_private=True, only=None, extra_dict=None, normalize_id=True
|
||||
cls,
|
||||
d,
|
||||
strip_private=True,
|
||||
only=None,
|
||||
extra_dict=None,
|
||||
exclude=None,
|
||||
normalize_id=True,
|
||||
):
|
||||
res = d
|
||||
if normalize_id and "_id" in res:
|
||||
@@ -61,6 +85,9 @@ class ProperDictMixin(object):
|
||||
res = project_dict(res, only)
|
||||
if extra_dict:
|
||||
res.update(extra_dict)
|
||||
if exclude:
|
||||
exclude_fields_from_dict(res, exclude)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@@ -70,6 +97,9 @@ class GetMixin(PropsMixin):
|
||||
_ordering_key = "order_by"
|
||||
_search_text_key = "search_text"
|
||||
|
||||
_start_key = "start"
|
||||
_size_key = "size"
|
||||
|
||||
_multi_field_param_sep = "__"
|
||||
_multi_field_param_prefix = {
|
||||
("_any_", "_or_"): lambda a, b: a | b,
|
||||
@@ -77,6 +107,7 @@ class GetMixin(PropsMixin):
|
||||
}
|
||||
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
|
||||
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {}
|
||||
|
||||
class QueryParameterOptions(object):
|
||||
@@ -103,45 +134,106 @@ class GetMixin(PropsMixin):
|
||||
|
||||
class ListFieldBucketHelper:
|
||||
op_prefix = "__$"
|
||||
legacy_exclude_prefix = "-"
|
||||
_legacy_exclude_prefix = "-"
|
||||
_legacy_exclude_mongo_op = "nin"
|
||||
|
||||
_default = "in"
|
||||
default_mongo_op = "in"
|
||||
_ops = {
|
||||
# op -> (mongo_op, sticky)
|
||||
"not": ("nin", False),
|
||||
"nop": (default_mongo_op, False),
|
||||
"all": ("all", True),
|
||||
"and": ("all", True),
|
||||
"any": (default_mongo_op, True),
|
||||
"or": (default_mongo_op, True),
|
||||
}
|
||||
_next = _default
|
||||
_sticky = False
|
||||
|
||||
def __init__(self, legacy=False):
|
||||
self._legacy = legacy
|
||||
self._current_op = None
|
||||
self._sticky = False
|
||||
self._support_legacy = legacy
|
||||
self.allow_empty = False
|
||||
|
||||
def key(self, v):
|
||||
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
|
||||
op = (
|
||||
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
|
||||
)
|
||||
if translate:
|
||||
tup = self._ops.get(op, None)
|
||||
return tup[0] if tup else None
|
||||
return op
|
||||
|
||||
def _key(self, v) -> Optional[Union[str, bool]]:
|
||||
if v is None:
|
||||
self._next = self._default
|
||||
return self._default
|
||||
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
|
||||
self._next = self._default
|
||||
return self._ops["not"][0]
|
||||
elif v.startswith(self.op_prefix):
|
||||
self._next, self._sticky = self._ops.get(
|
||||
v[len(self.op_prefix) :], (self._default, self._sticky)
|
||||
)
|
||||
self.allow_empty = True
|
||||
return None
|
||||
|
||||
next_ = self._next
|
||||
if not self._sticky:
|
||||
self._next = self._default
|
||||
return next_
|
||||
op = self._get_op(v)
|
||||
if op is not None:
|
||||
# operator - set state and return None
|
||||
self._current_op, self._sticky = self._ops.get(
|
||||
op, (self.default_mongo_op, self._sticky)
|
||||
)
|
||||
return None
|
||||
elif self._current_op:
|
||||
current_op = self._current_op
|
||||
if not self._sticky:
|
||||
self._current_op = None
|
||||
return current_op
|
||||
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
|
||||
self._current_op = None
|
||||
return False
|
||||
|
||||
def value_transform(self, v):
|
||||
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
|
||||
return v[len(self.legacy_exclude_prefix) :]
|
||||
return v
|
||||
return self.default_mongo_op
|
||||
|
||||
def get_global_op(self, data: Sequence[str]) -> int:
|
||||
op_to_res = {
|
||||
"in": Q.OR,
|
||||
"all": Q.AND,
|
||||
}
|
||||
data = (x for x in data if x is not None)
|
||||
first_op = (
|
||||
self._get_op(next(data, ""), translate=True) or self.default_mongo_op
|
||||
)
|
||||
return op_to_res.get(first_op, self.default_mongo_op)
|
||||
|
||||
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
|
||||
actions = {}
|
||||
|
||||
for val in data:
|
||||
key = self._key(val)
|
||||
if key is None:
|
||||
continue
|
||||
elif self._support_legacy and key is False:
|
||||
key = self._legacy_exclude_mongo_op
|
||||
val = val[len(self._legacy_exclude_prefix) :]
|
||||
actions.setdefault(key, []).append(val)
|
||||
|
||||
return actions
|
||||
|
||||
get_all_query_options = QueryParameterOptions()
|
||||
|
||||
class GetManyScrollState(ProperDictMixin, Document):
|
||||
meta = {"db_alias": Database.backend, "strict": False}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
position = IntField(default=0)
|
||||
|
||||
_cache_manager = None
|
||||
|
||||
@classmethod
|
||||
def get_cache_manager(cls):
|
||||
if not cls._cache_manager:
|
||||
cls._cache_manager = RedisCacheManager(
|
||||
state_class=cls.GetManyScrollState,
|
||||
redis=redman.connection("apiserver"),
|
||||
expiration_interval=config.get(
|
||||
"services._mongo.scroll_state_expiration_seconds", 600
|
||||
),
|
||||
)
|
||||
|
||||
return cls._cache_manager
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
cls: Union["GetMixin", Document],
|
||||
@@ -240,7 +332,9 @@ class GetMixin(PropsMixin):
|
||||
Prepare a query object based on the provided query dictionary and various fields.
|
||||
|
||||
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
|
||||
|
||||
IMPLEMENTATION NOTE: Make sure that inside this function or the functions it depends on RegexQ is always
|
||||
used instead of Q. Otherwise we can and up with some combination that is not processed according to
|
||||
RegexQ rules
|
||||
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
|
||||
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
|
||||
Supported parameters:
|
||||
@@ -273,28 +367,44 @@ class GetMixin(PropsMixin):
|
||||
).items():
|
||||
query &= cls.get_range_field_query(field, data)
|
||||
|
||||
for field in opts.fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
if data is not None:
|
||||
dict_query[field] = data
|
||||
for field, data in cls._pop_matching_params(
|
||||
patterns=opts.fields or [], parameters=parameters
|
||||
).items():
|
||||
if "._" in field or "_." in field:
|
||||
query &= RegexQ(__raw__={field: data})
|
||||
else:
|
||||
dict_query[field.replace(".", "__")] = data
|
||||
|
||||
for field in opts.datetime_fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
if data is not None:
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
for d in data: # type: str
|
||||
m = ACCESS_REGEX.match(d)
|
||||
if not m:
|
||||
continue
|
||||
try:
|
||||
value = parse_datetime(m.group("value"))
|
||||
prefix = m.group("prefix")
|
||||
modifier = ACCESS_MODIFIER.get(prefix)
|
||||
f = field if not modifier else "__".join((field, modifier))
|
||||
dict_query[f] = value
|
||||
except (ValueError, OverflowError):
|
||||
pass
|
||||
# date time fields also support simplified range queries. Check if this is the case
|
||||
if len(data) == 2 and not any(
|
||||
d.startswith(mod)
|
||||
for d in data
|
||||
if d is not None
|
||||
for mod in ACCESS_MODIFIER
|
||||
):
|
||||
query &= cls.get_range_field_query(field, data)
|
||||
else:
|
||||
for d in data: # type: str
|
||||
m = ACCESS_REGEX.match(d)
|
||||
if not m:
|
||||
continue
|
||||
try:
|
||||
value = parse_datetime(m.group("value"))
|
||||
prefix = m.group("prefix")
|
||||
modifier = ACCESS_MODIFIER.get(prefix)
|
||||
f = (
|
||||
field
|
||||
if not modifier
|
||||
else "__".join((field, modifier))
|
||||
)
|
||||
dict_query[f] = value
|
||||
except (ValueError, OverflowError):
|
||||
pass
|
||||
|
||||
for field, value in parameters.items():
|
||||
for keys, func in cls._multi_field_param_prefix.items():
|
||||
@@ -308,22 +418,31 @@ class GetMixin(PropsMixin):
|
||||
break
|
||||
if any("._" in f for f in data.fields):
|
||||
q = reduce(
|
||||
lambda a, x: func(a, Q(__raw__={x: {"$regex": data.pattern, "$options": "i"}})),
|
||||
lambda a, x: func(
|
||||
a,
|
||||
RegexQ(
|
||||
__raw__={
|
||||
x: {"$regex": data.pattern, "$options": "i"}
|
||||
}
|
||||
),
|
||||
),
|
||||
data.fields,
|
||||
Q()
|
||||
RegexQ(),
|
||||
)
|
||||
else:
|
||||
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
|
||||
sep_fields = [f.replace(".", "__") for f in data.fields]
|
||||
q = reduce(
|
||||
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
|
||||
lambda a, x: func(a, RegexQ(**{x: regex})),
|
||||
sep_fields,
|
||||
RegexQ(),
|
||||
)
|
||||
query = query & q
|
||||
|
||||
return query & RegexQ(**dict_query)
|
||||
|
||||
@classmethod
|
||||
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
||||
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
|
||||
"""
|
||||
Return a range query for the provided field. The data should contain min and max values
|
||||
Both intervals are included. For open range queries either min or max can be None
|
||||
@@ -347,14 +466,14 @@ class GetMixin(PropsMixin):
|
||||
if max_val is not None:
|
||||
query[f"{mongoengine_field}__lte"] = max_val
|
||||
|
||||
q = Q(**query)
|
||||
q = RegexQ(**query)
|
||||
if min_val is None:
|
||||
q |= Q(**{mongoengine_field: None})
|
||||
q |= RegexQ(**{mongoengine_field: None})
|
||||
|
||||
return q
|
||||
|
||||
@classmethod
|
||||
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
||||
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
|
||||
"""
|
||||
Get a proper mongoengine Q object that represents an "or" query for the provided values
|
||||
with respect to the given list field, with support for "none of empty" in case a None value
|
||||
@@ -366,30 +485,31 @@ class GetMixin(PropsMixin):
|
||||
"""
|
||||
if not isinstance(data, (list, tuple)):
|
||||
data = [data]
|
||||
# raise MakeGetAllQueryError("expected list", field)
|
||||
|
||||
# TODO: backwards compatibility only for older API versions
|
||||
helper = cls.ListFieldBucketHelper(legacy=True)
|
||||
actions = bucketize(
|
||||
data, key=helper.key, value_transform=helper.value_transform
|
||||
)
|
||||
global_op = helper.get_global_op(data)
|
||||
actions = helper.get_actions(data)
|
||||
|
||||
allow_empty = None in actions.get("in", {})
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
|
||||
q = RegexQ()
|
||||
for action in filter(None, actions):
|
||||
q &= RegexQ(
|
||||
**{f"{mongoengine_field}__{action}": list(set(actions[action]))}
|
||||
)
|
||||
queries = [
|
||||
RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))})
|
||||
for action in filter(None, actions)
|
||||
]
|
||||
|
||||
if not allow_empty:
|
||||
if not queries:
|
||||
q = RegexQ()
|
||||
else:
|
||||
q = RegexQCombination(operation=global_op, children=queries)
|
||||
|
||||
if not helper.allow_empty:
|
||||
return q
|
||||
|
||||
return (
|
||||
q
|
||||
| Q(**{f"{mongoengine_field}__exists": False})
|
||||
| Q(**{mongoengine_field: []})
|
||||
| RegexQ(**{f"{mongoengine_field}__exists": False})
|
||||
| RegexQ(**{mongoengine_field: []})
|
||||
| RegexQ(**{mongoengine_field: None})
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -402,6 +522,8 @@ class GetMixin(PropsMixin):
|
||||
def validate_order_by(cls, parameters, search_text) -> Sequence:
|
||||
"""
|
||||
Validate and extract order_by params as a list
|
||||
If ordering is specified then make sure that id field is part of it
|
||||
This guarantees the unique order when paging
|
||||
"""
|
||||
order_by = parameters.get(cls._ordering_key)
|
||||
if not order_by:
|
||||
@@ -414,30 +536,47 @@ class GetMixin(PropsMixin):
|
||||
"text score cannot be used in order_by when search text is not used"
|
||||
)
|
||||
|
||||
if not any(id_field in order_by for id_field in ("id", "-id")):
|
||||
order_by.append("id")
|
||||
|
||||
return order_by
|
||||
|
||||
@classmethod
|
||||
def validate_paging(
|
||||
cls, parameters=None, default_page=None, default_page_size=None
|
||||
):
|
||||
""" Validate and extract paging info from from the provided dictionary. Supports default values. """
|
||||
if parameters is None:
|
||||
parameters = {}
|
||||
default_page = parameters.get("page", default_page)
|
||||
if default_page is None:
|
||||
return None, None
|
||||
default_page_size = parameters.get("page_size", default_page_size)
|
||||
if not default_page_size:
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"page_size is required when page is requested", field="page_size"
|
||||
)
|
||||
elif default_page < 0:
|
||||
def validate_paging(cls, parameters=None, default_page=0, default_page_size=None):
|
||||
"""
|
||||
Validate and extract paging info from from the provided dictionary. Supports default values.
|
||||
If page is specified then it should be non-negative, if page size is specified then it should be positive
|
||||
If page size is specified and page is not then 0 page is assumed
|
||||
If page is specified then page size should be specified too
|
||||
"""
|
||||
parameters = parameters or {}
|
||||
|
||||
start = parameters.get(cls._start_key)
|
||||
if start is not None:
|
||||
return start, cls.validate_scroll_size(parameters)
|
||||
|
||||
max_page_size = config.get("services._mongo.max_page_size", 500)
|
||||
page = parameters.get("page", default_page)
|
||||
if page is not None and page < 0:
|
||||
raise errors.bad_request.ValidationError("page must be >=0", field="page")
|
||||
elif default_page_size < 1:
|
||||
|
||||
page_size = parameters.get("page_size", default_page_size or max_page_size)
|
||||
if page_size is not None and page_size < 1:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"page_size must be >0", field="page_size"
|
||||
)
|
||||
return default_page, default_page_size
|
||||
|
||||
if page_size is not None:
|
||||
page = page or 0
|
||||
page_size = min(page_size, max_page_size)
|
||||
return page * page_size, page_size
|
||||
|
||||
if page is not None:
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"page_size is required when page is requested", field="page_size"
|
||||
)
|
||||
|
||||
return None, None
|
||||
|
||||
@classmethod
|
||||
def get_projection(cls, parameters, override_projection=None, **__):
|
||||
@@ -481,6 +620,54 @@ class GetMixin(PropsMixin):
|
||||
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
|
||||
cls.set_ordering(parameters, cls.get_ordering(parameters) or value)
|
||||
|
||||
@classmethod
|
||||
def validate_scroll_size(cls, query_dict: dict) -> int:
|
||||
size = query_dict.get(cls._size_key)
|
||||
if not size or not isinstance(size, int) or size < 1:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Integer size parameter greater than 1 should be provided when working with scroll"
|
||||
)
|
||||
return size
|
||||
|
||||
@classmethod
|
||||
def get_data_with_scroll_support(
|
||||
cls,
|
||||
query_dict: dict,
|
||||
data_getter: Callable[[], Sequence[dict]],
|
||||
ret_params: dict,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Retrieves the data by calling the provided data_getter api
|
||||
If scroll parameters are specified then put the query_dict 'start' parameter to the last
|
||||
scroll position and continue retrievals from that position
|
||||
If refresh_scroll is requested then bring once more the data from the beginning
|
||||
till the current scroll position
|
||||
In the end the scroll position is updated and accumulated frames are returned
|
||||
"""
|
||||
query_dict = query_dict or {}
|
||||
state: Optional[cls.GetManyScrollState] = None
|
||||
if "scroll_id" in query_dict:
|
||||
size = cls.validate_scroll_size(query_dict)
|
||||
state = cls.get_cache_manager().get_or_create_state_core(
|
||||
query_dict.get("scroll_id")
|
||||
)
|
||||
if query_dict.get("refresh_scroll"):
|
||||
query_dict[cls._size_key] = max(state.position, size)
|
||||
state.position = 0
|
||||
query_dict[cls._start_key] = state.position
|
||||
|
||||
data = data_getter()
|
||||
if cls._start_key in query_dict:
|
||||
query_dict[cls._start_key] = query_dict[cls._start_key] + len(data)
|
||||
|
||||
if state:
|
||||
state.position = query_dict[cls._start_key]
|
||||
cls.get_cache_manager().set_state(state)
|
||||
if ret_params is not None:
|
||||
ret_params["scroll_id"] = state.id
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def get_many_with_join(
|
||||
cls,
|
||||
@@ -491,6 +678,8 @@ class GetMixin(PropsMixin):
|
||||
allow_public=False,
|
||||
override_projection=None,
|
||||
expand_reference_ids=True,
|
||||
projection_fields: dict = None,
|
||||
ret_params: dict = None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query with support for joining referenced documents according to the
|
||||
@@ -526,6 +715,8 @@ class GetMixin(PropsMixin):
|
||||
query=query,
|
||||
query_options=query_options,
|
||||
allow_public=allow_public,
|
||||
projection_fields=projection_fields,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
def projection_func(doc_type, projection, ids):
|
||||
@@ -545,6 +736,45 @@ class GetMixin(PropsMixin):
|
||||
v for k, v in cls._field_collation_overrides.items() if field.startswith(k)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_count(
|
||||
cls: Union["GetMixin", Document],
|
||||
company,
|
||||
query_dict: dict = None,
|
||||
query_options: QueryParameterOptions = None,
|
||||
query: Q = None,
|
||||
allow_public=False,
|
||||
) -> int:
|
||||
_query = cls._get_combined_query(
|
||||
company=company,
|
||||
query_dict=query_dict,
|
||||
query_options=query_options,
|
||||
query=query,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
return cls.objects(_query).count()
|
||||
|
||||
@classmethod
|
||||
def _get_combined_query(
|
||||
cls,
|
||||
company,
|
||||
query_dict: dict = None,
|
||||
query_options: QueryParameterOptions = None,
|
||||
query: Q = None,
|
||||
allow_public=False,
|
||||
) -> Q:
|
||||
if query_dict is not None:
|
||||
q = cls.prepare_query(
|
||||
parameters=query_dict,
|
||||
company=company,
|
||||
parameters_options=query_options,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
else:
|
||||
q = cls._prepare_perm_query(company, allow_public=allow_public)
|
||||
|
||||
return (q & query) if query else q
|
||||
|
||||
@classmethod
|
||||
def get_many(
|
||||
cls,
|
||||
@@ -556,6 +786,8 @@ class GetMixin(PropsMixin):
|
||||
allow_public=False,
|
||||
override_projection: Collection[str] = None,
|
||||
return_dicts=True,
|
||||
projection_fields: dict = None,
|
||||
ret_params: dict = None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query. Supported several built-in options
|
||||
@@ -589,23 +821,25 @@ class GetMixin(PropsMixin):
|
||||
if override_collation:
|
||||
break
|
||||
|
||||
if query_dict is not None:
|
||||
q = cls.prepare_query(
|
||||
parameters=query_dict,
|
||||
company=company,
|
||||
parameters_options=query_options,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
else:
|
||||
q = cls._prepare_perm_query(company, allow_public=allow_public)
|
||||
_query = (q & query) if query else q
|
||||
_query = cls._get_combined_query(
|
||||
company=company,
|
||||
query_dict=query_dict,
|
||||
query_options=query_options,
|
||||
query=query,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
|
||||
if return_dicts:
|
||||
return cls._get_many_override_none_ordering(
|
||||
data_getter = partial(
|
||||
cls._get_many_override_none_ordering,
|
||||
query=_query,
|
||||
parameters=parameters,
|
||||
override_projection=override_projection,
|
||||
override_collation=override_collation,
|
||||
projection_fields=projection_fields,
|
||||
)
|
||||
return cls.get_data_with_scroll_support(
|
||||
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
|
||||
)
|
||||
|
||||
return cls._get_many_no_company(
|
||||
@@ -613,6 +847,7 @@ class GetMixin(PropsMixin):
|
||||
parameters=parameters,
|
||||
override_projection=override_projection,
|
||||
override_collation=override_collation,
|
||||
projection_fields=projection_fields,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -637,6 +872,7 @@ class GetMixin(PropsMixin):
|
||||
parameters=None,
|
||||
override_projection=None,
|
||||
override_collation=None,
|
||||
projection_fields: dict = None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query.
|
||||
@@ -658,7 +894,7 @@ class GetMixin(PropsMixin):
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
if order_by and not override_collation:
|
||||
override_collation = cls._get_collation_override(order_by[0])
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
start, size = cls.validate_paging(parameters=parameters)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
@@ -679,9 +915,12 @@ class GetMixin(PropsMixin):
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
|
||||
if page is not None and page_size:
|
||||
if projection_fields:
|
||||
qs = qs.fields(**projection_fields)
|
||||
|
||||
if start is not None and size:
|
||||
# add paging
|
||||
qs = qs.skip(page * page_size).limit(page_size)
|
||||
qs = qs.skip(start).limit(size)
|
||||
|
||||
return qs
|
||||
|
||||
@@ -720,6 +959,7 @@ class GetMixin(PropsMixin):
|
||||
parameters: dict = None,
|
||||
override_projection: Collection[str] = None,
|
||||
override_collation: dict = None,
|
||||
projection_fields: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Fetch all documents matching a provided query. For the first order by field
|
||||
@@ -742,7 +982,10 @@ class GetMixin(PropsMixin):
|
||||
parameters = parameters or {}
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
start, size = cls.validate_paging(parameters=parameters)
|
||||
if size is not None and size <= 0:
|
||||
return []
|
||||
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
@@ -774,25 +1017,35 @@ class GetMixin(PropsMixin):
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
|
||||
if page is None or not page_size:
|
||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||
if projection_fields:
|
||||
query_sets = [qs.fields(**projection_fields) for qs in query_sets]
|
||||
|
||||
if start is None or not size:
|
||||
return [
|
||||
obj.to_proper_dict(only=include, exclude=exclude)
|
||||
for qs in query_sets
|
||||
for obj in qs
|
||||
]
|
||||
|
||||
# add paging
|
||||
ret = []
|
||||
start = page * page_size
|
||||
for qs in query_sets:
|
||||
qs_size = qs.count()
|
||||
if qs_size < start:
|
||||
start -= qs_size
|
||||
continue
|
||||
last_set = len(query_sets) - 1
|
||||
for i, qs in enumerate(query_sets):
|
||||
last_size = len(ret)
|
||||
ret.extend(
|
||||
obj.to_proper_dict(only=include)
|
||||
for obj in qs.skip(start).limit(page_size)
|
||||
obj.to_proper_dict(only=include, exclude=exclude)
|
||||
for obj in (qs.skip(start) if start else qs).limit(size)
|
||||
)
|
||||
if len(ret) >= page_size:
|
||||
added = len(ret) - last_size
|
||||
|
||||
if added > 0:
|
||||
start = 0
|
||||
size = max(0, size - added)
|
||||
elif i != last_set:
|
||||
start -= min(start, qs.count())
|
||||
|
||||
if size <= 0:
|
||||
break
|
||||
start = 0
|
||||
page_size -= len(ret)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import (
|
||||
Document,
|
||||
StringField,
|
||||
DateTimeField,
|
||||
BooleanField,
|
||||
EmbeddedDocumentListField,
|
||||
EmbeddedDocumentField,
|
||||
)
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
@@ -13,18 +10,21 @@ from apiserver.database.fields import (
|
||||
StrippedStringField,
|
||||
SafeDictField,
|
||||
SafeSortedListField,
|
||||
SafeMapField,
|
||||
)
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.metadata import MetadataItem
|
||||
from apiserver.database.model.model_labels import ModelLabels
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.database.model.user import User
|
||||
|
||||
|
||||
class Model(DbModelMixin, Document):
|
||||
class Model(AttributedDocument):
|
||||
_field_collation_overrides = {
|
||||
"metadata.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
@@ -33,11 +33,10 @@ class Model(DbModelMixin, Document):
|
||||
"project",
|
||||
"task",
|
||||
"last_update",
|
||||
"metadata.key",
|
||||
"metadata.type",
|
||||
("company", "framework"),
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
("company", "uri"),
|
||||
{
|
||||
"name": "%s.model.main_text_index" % Database.backend,
|
||||
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
|
||||
@@ -66,6 +65,7 @@ class Model(DbModelMixin, Document):
|
||||
"project",
|
||||
"task",
|
||||
"parent",
|
||||
"metadata.*",
|
||||
),
|
||||
datetime_fields=("last_update",),
|
||||
)
|
||||
@@ -73,8 +73,6 @@ class Model(DbModelMixin, Document):
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(user_set_allowed=True, min_length=3)
|
||||
parent = StringField(reference_field="Model", required=False)
|
||||
user = StringField(required=True, reference_field=User)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
task = StringField(reference_field=Task)
|
||||
@@ -91,6 +89,9 @@ class Model(DbModelMixin, Document):
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
||||
MetadataItem, default=list, user_set_allowed=True
|
||||
metadata = SafeMapField(
|
||||
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||
)
|
||||
|
||||
def get_index_company(self) -> str:
|
||||
return self.company or self.company_origin or ""
|
||||
|
||||
@@ -9,8 +9,9 @@ from apiserver.database.model.base import GetMixin
|
||||
class Project(AttributedDocument):
|
||||
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "description"),
|
||||
pattern_fields=("name", "basename", "description"),
|
||||
list_fields=("tags", "system_tags", "id", "parent", "path"),
|
||||
range_fields=("last_update",),
|
||||
)
|
||||
|
||||
meta = {
|
||||
@@ -20,6 +21,7 @@ class Project(AttributedDocument):
|
||||
"parent",
|
||||
"path",
|
||||
("company", "name"),
|
||||
("company", "basename"),
|
||||
{
|
||||
"name": "%s.project.main_text_index" % Database.backend,
|
||||
"fields": ["$name", "$id", "$description"],
|
||||
@@ -36,6 +38,7 @@ class Project(AttributedDocument):
|
||||
min_length=3,
|
||||
sparse=True,
|
||||
)
|
||||
basename = StrippedStringField(required=True)
|
||||
description = StringField()
|
||||
created = DateTimeField(required=True)
|
||||
tags = SafeSortedListField(StringField(required=True))
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import (
|
||||
Document,
|
||||
EmbeddedDocument,
|
||||
StringField,
|
||||
DateTimeField,
|
||||
EmbeddedDocumentListField,
|
||||
EmbeddedDocumentField,
|
||||
)
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField, SafeSortedListField
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.fields import (
|
||||
StrippedStringField,
|
||||
SafeSortedListField,
|
||||
SafeMapField,
|
||||
)
|
||||
from apiserver.database.model import DbModelMixin, AttributedDocument
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.metadata import MetadataItem
|
||||
@@ -19,23 +22,25 @@ from apiserver.database.model.task.task import Task
|
||||
|
||||
class Entry(EmbeddedDocument, ProperDictMixin):
|
||||
""" Entry representing a task waiting in the queue """
|
||||
|
||||
task = StringField(required=True, reference_field=Task)
|
||||
''' Task ID '''
|
||||
""" Task ID """
|
||||
added = DateTimeField(required=True)
|
||||
''' Added to the queue '''
|
||||
""" Added to the queue """
|
||||
|
||||
|
||||
class Queue(DbModelMixin, Document):
|
||||
_field_collation_overrides = {
|
||||
"metadata.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name",),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
pattern_fields=("name",), list_fields=("tags", "system_tags", "id", "metadata.*"),
|
||||
)
|
||||
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
"indexes": ["metadata.key", "metadata.type"],
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
@@ -44,10 +49,12 @@ class Queue(DbModelMixin, Document):
|
||||
)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
created = DateTimeField(required=True)
|
||||
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
tags = SafeSortedListField(
|
||||
StringField(required=True), default=list, user_set_allowed=True
|
||||
)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||
last_update = DateTimeField()
|
||||
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
||||
MetadataItem, default=list, user_set_allowed=True
|
||||
metadata = SafeMapField(
|
||||
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from apiserver.database.fields import (
|
||||
UnionField,
|
||||
SafeSortedListField,
|
||||
EmbeddedDocumentListField,
|
||||
NullableStringField,
|
||||
)
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
@@ -158,11 +159,10 @@ external_task_types = set(get_options(TaskType))
|
||||
|
||||
|
||||
class Task(AttributedDocument):
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {
|
||||
"execution.parameters.": _numeric_locale,
|
||||
"last_metrics.": _numeric_locale,
|
||||
"hyperparams.": _numeric_locale,
|
||||
"execution.parameters.": AttributedDocument._numeric_locale,
|
||||
"last_metrics.": AttributedDocument._numeric_locale,
|
||||
"hyperparams.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
@@ -175,6 +175,8 @@ class Task(AttributedDocument):
|
||||
"active_duration",
|
||||
"parent",
|
||||
"project",
|
||||
"last_update",
|
||||
"status_changed",
|
||||
"models.input.model",
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
@@ -183,7 +185,10 @@ class Task(AttributedDocument):
|
||||
("company", "type", "system_tags", "status"),
|
||||
("company", "project", "type", "system_tags", "status"),
|
||||
("status", "last_update"), # for maintenance tasks
|
||||
{"fields": ["company", "project"], "collation": _numeric_locale},
|
||||
{
|
||||
"fields": ["company", "project"],
|
||||
"collation": AttributedDocument._numeric_locale,
|
||||
},
|
||||
{
|
||||
"name": "%s.task.main_text_index" % Database.backend,
|
||||
"fields": [
|
||||
@@ -218,6 +223,7 @@ class Task(AttributedDocument):
|
||||
"status",
|
||||
"project",
|
||||
"parent",
|
||||
"hyperparams.*",
|
||||
),
|
||||
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
|
||||
datetime_fields=("status_changed", "last_update"),
|
||||
@@ -232,7 +238,7 @@ class Task(AttributedDocument):
|
||||
type = StringField(required=True, choices=get_options(TaskType))
|
||||
status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
|
||||
status_reason = StringField()
|
||||
status_message = StringField()
|
||||
status_message = StringField(user_set_allowed=True)
|
||||
status_changed = DateTimeField()
|
||||
comment = StringField(user_set_allowed=True)
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
@@ -253,6 +259,7 @@ class Task(AttributedDocument):
|
||||
last_change = DateTimeField()
|
||||
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
|
||||
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
duration = IntField() # task duration in seconds
|
||||
@@ -260,7 +267,7 @@ class Task(AttributedDocument):
|
||||
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
|
||||
runtime = SafeDictField(default=dict)
|
||||
models: Models = EmbeddedDocumentField(Models, default=Models)
|
||||
container = SafeMapField(field=StringField(default=""))
|
||||
container = SafeMapField(field=NullableStringField())
|
||||
enqueue_status = StringField(
|
||||
choices=get_options(TaskStatus), exclude_by_default=True
|
||||
)
|
||||
|
||||
51
apiserver/database/model/url_to_delete.py
Normal file
51
apiserver/database/model/url_to_delete.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from enum import Enum
|
||||
|
||||
from mongoengine import StringField, DateTimeField, IntField, EnumField
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.model import AttributedDocument
|
||||
|
||||
|
||||
class StorageType(str, Enum):
|
||||
fileserver = "fileserver"
|
||||
unknown = "unknown"
|
||||
|
||||
|
||||
class FileType(str, Enum):
|
||||
file = "file"
|
||||
folder = "folder"
|
||||
|
||||
|
||||
class DeletionStatus(str, Enum):
|
||||
created = "created"
|
||||
retrying = "retrying"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class UrlToDelete(AttributedDocument):
|
||||
_field_collation_overrides = {
|
||||
"url": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
("company", "user", "task"),
|
||||
"storage_type",
|
||||
"created",
|
||||
"retry_count",
|
||||
"type",
|
||||
],
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
url = StringField(required=True, unique_with="company")
|
||||
task = StringField(required=True)
|
||||
created = DateTimeField(required=True)
|
||||
storage_type = EnumField(StorageType, default=StorageType.unknown)
|
||||
type = EnumField(FileType, default=FileType.file)
|
||||
retry_count = IntField(default=0)
|
||||
last_failure_time = DateTimeField()
|
||||
last_failure_reason = StringField()
|
||||
status = EnumField(DeletionStatus, default=DeletionStatus.created)
|
||||
@@ -1,4 +1,4 @@
|
||||
from mongoengine import Document, StringField, DynamicField
|
||||
from mongoengine import Document, StringField, DynamicField, DateTimeField
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.model import DbModelMixin
|
||||
@@ -20,3 +20,4 @@ class User(DbModelMixin, Document):
|
||||
given_name = StringField(user_set_allowed=True)
|
||||
avatar = StringField()
|
||||
preferences = DynamicField(default="", exclude_by_default=True)
|
||||
created = DateTimeField()
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import groupby, chain
|
||||
from typing import Sequence, Dict, Callable, Tuple, Any, Type
|
||||
|
||||
import dpath.path
|
||||
from typing import Sequence, Dict, Callable
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.props import PropsMixin
|
||||
@@ -11,65 +9,6 @@ from apiserver.database.props import PropsMixin
|
||||
SEP = "."
|
||||
|
||||
|
||||
def project_dict(data, projection, separator=SEP):
|
||||
"""
|
||||
Project partial data from a dictionary into a new dictionary
|
||||
:param data: Input dictionary
|
||||
:param projection: List of dictionary paths (each a string with field names separated using a separator)
|
||||
:param separator: Separator (default is '.')
|
||||
:return: A new dictionary containing only the projected parts from the original dictionary
|
||||
"""
|
||||
assert isinstance(data, dict)
|
||||
result = {}
|
||||
|
||||
def copy_path(path_parts, source, destination):
|
||||
src, dst = source, destination
|
||||
try:
|
||||
for depth, path_part in enumerate(path_parts[:-1]):
|
||||
src_part = src[path_part]
|
||||
if isinstance(src_part, dict):
|
||||
src = src_part
|
||||
dst = dst.setdefault(path_part, {})
|
||||
elif isinstance(src_part, (list, tuple)):
|
||||
if path_part not in dst:
|
||||
dst[path_part] = [{} for _ in range(len(src_part))]
|
||||
elif not isinstance(dst[path_part], (list, tuple)):
|
||||
raise TypeError(
|
||||
"Incompatible destination type %s for %s (list expected)"
|
||||
% (type(dst), separator.join(path_parts[: depth + 1]))
|
||||
)
|
||||
elif not len(dst[path_part]) == len(src_part):
|
||||
raise ValueError(
|
||||
"Destination list length differs from source length for %s"
|
||||
% separator.join(path_parts[: depth + 1])
|
||||
)
|
||||
|
||||
dst[path_part] = [
|
||||
copy_path(path_parts[depth + 1 :], s, d)
|
||||
for s, d in zip(src_part, dst[path_part])
|
||||
]
|
||||
|
||||
return destination
|
||||
else:
|
||||
raise TypeError(
|
||||
"Unsupported projection type %s for %s"
|
||||
% (type(src), separator.join(path_parts[: depth + 1]))
|
||||
)
|
||||
|
||||
last_part = path_parts[-1]
|
||||
dst[last_part] = src[last_part]
|
||||
except KeyError:
|
||||
# Projection field not in source, no biggie.
|
||||
pass
|
||||
return destination
|
||||
|
||||
for projection_path in sorted(projection):
|
||||
copy_path(
|
||||
path_parts=projection_path.split(separator), source=data, destination=result
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class _ReferenceProxy(dict):
|
||||
def __init__(self, id):
|
||||
super(_ReferenceProxy, self).__init__(**({"id": id} if id else {}))
|
||||
@@ -110,9 +49,6 @@ class ProjectionHelper(object):
|
||||
self._ref_projection = None
|
||||
self._proxy_manager = _ProxyManager()
|
||||
|
||||
# Cached dpath paths for each of the result documents
|
||||
self._cached_results_paths: Dict[int, Sequence[Tuple[Any, Type]]] = {}
|
||||
|
||||
self._parse_projection(projection)
|
||||
|
||||
def _collect_projection_fields(self, doc_cls, projection):
|
||||
@@ -275,25 +211,26 @@ class ProjectionHelper(object):
|
||||
norm_path = doc_cls.get_dpath_translated_path(path)
|
||||
globlist = norm_path.strip(SEP).split(SEP)
|
||||
|
||||
obj_paths = self._cached_results_paths.get(id(obj))
|
||||
if obj_paths is None:
|
||||
obj_paths = self._cached_results_paths[id(obj)] = list(
|
||||
dpath.path.paths(obj, dirs=True, skip=True)
|
||||
)
|
||||
|
||||
paths = [p for p in obj_paths if dpath.path.match(p, globlist)]
|
||||
|
||||
def search_and_replace(p: Sequence[Tuple[str, Type]]) -> Any:
|
||||
def _search_and_replace(target: dict, p: Sequence[str]) -> Sequence[str]:
|
||||
parent = None
|
||||
target = obj
|
||||
for part in p:
|
||||
parent = target
|
||||
target = target[part[0]]
|
||||
if parent and factory:
|
||||
parent[p[-1][0]] = factory(target)
|
||||
return target
|
||||
for idx, part in enumerate(p):
|
||||
if isinstance(target, dict) and part in target:
|
||||
parent = target
|
||||
target = target[part]
|
||||
elif isinstance(target, list) and part == "*":
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
_search_and_replace(t, p[idx + 1 :]) for t in target
|
||||
)
|
||||
)
|
||||
else:
|
||||
return []
|
||||
|
||||
return [search_and_replace(p) for p in paths]
|
||||
if parent and factory:
|
||||
parent[p[-1]] = factory(target)
|
||||
return [target]
|
||||
|
||||
return _search_and_replace(obj, globlist)
|
||||
|
||||
def project(self, results, projection_func):
|
||||
"""
|
||||
|
||||
@@ -4,7 +4,12 @@ from threading import Lock
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
|
||||
from mongoengine import (
|
||||
EmbeddedDocumentField,
|
||||
EmbeddedDocumentListField,
|
||||
EmbeddedDocument,
|
||||
Document,
|
||||
)
|
||||
from mongoengine.base import get_document
|
||||
|
||||
from apiserver.database.fields import (
|
||||
@@ -25,6 +30,13 @@ class PropsMixin(object):
|
||||
__cached_dpath_computed_fields_lock = Lock()
|
||||
__cached_dpath_computed_fields = None
|
||||
|
||||
_document_classes = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if issubclass(cls, (Document, EmbeddedDocument)):
|
||||
PropsMixin._document_classes[cls._class_name] = cls
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
if cls.__cached_fields is None:
|
||||
@@ -57,8 +69,14 @@ class PropsMixin(object):
|
||||
def resolve_doc(v):
|
||||
if not isinstance(v, six.string_types):
|
||||
return v
|
||||
if v == 'self':
|
||||
|
||||
if v == "self":
|
||||
return cls_.owner_document
|
||||
|
||||
doc_cls = PropsMixin._document_classes.get(v)
|
||||
if doc_cls:
|
||||
return doc_cls
|
||||
|
||||
return get_document(v)
|
||||
|
||||
fields = {k: resolve_doc(v) for k, v in res.items()}
|
||||
@@ -72,7 +90,7 @@ class PropsMixin(object):
|
||||
).document_type
|
||||
fields.update(
|
||||
{
|
||||
'.'.join((field, subfield)): doc
|
||||
".".join((field, subfield)): doc
|
||||
for subfield, doc in PropsMixin._get_fields_with_attr(
|
||||
embedded_doc_cls, attr
|
||||
).items()
|
||||
@@ -80,10 +98,10 @@ class PropsMixin(object):
|
||||
)
|
||||
|
||||
collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
|
||||
collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))
|
||||
collect_embedded_docs(EmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter("field"))
|
||||
|
||||
return fields
|
||||
|
||||
@@ -94,7 +112,7 @@ class PropsMixin(object):
|
||||
for depth, part in enumerate(parts):
|
||||
if current_cls is None:
|
||||
raise ValueError(
|
||||
'Invalid path (non-document encountered at %s)' % parts[: depth - 1]
|
||||
"Invalid path (non-document encountered at %s)" % parts[: depth - 1]
|
||||
)
|
||||
try:
|
||||
field_name, field = next(
|
||||
@@ -103,7 +121,7 @@ class PropsMixin(object):
|
||||
if k == part
|
||||
)
|
||||
except StopIteration:
|
||||
raise ValueError('Invalid field path %s' % parts[:depth])
|
||||
raise ValueError("Invalid field path %s" % parts[:depth])
|
||||
|
||||
translated_parts.append(part)
|
||||
|
||||
@@ -119,7 +137,7 @@ class PropsMixin(object):
|
||||
),
|
||||
):
|
||||
current_cls = field.field.document_type
|
||||
translated_parts.append('*')
|
||||
translated_parts.append("*")
|
||||
else:
|
||||
current_cls = None
|
||||
|
||||
@@ -128,7 +146,7 @@ class PropsMixin(object):
|
||||
@classmethod
|
||||
def get_reference_fields(cls):
|
||||
if cls.__cached_reference_fields is None:
|
||||
fields = cls._get_fields_with_attr(cls, 'reference_field')
|
||||
fields = cls._get_fields_with_attr(cls, "reference_field")
|
||||
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
|
||||
return cls.__cached_reference_fields
|
||||
|
||||
@@ -143,12 +161,12 @@ class PropsMixin(object):
|
||||
@classmethod
|
||||
def get_exclude_fields(cls):
|
||||
if cls.__cached_exclude_fields is None:
|
||||
fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
|
||||
fields = cls._get_fields_with_attr(cls, "exclude_by_default")
|
||||
cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
|
||||
return cls.__cached_exclude_fields
|
||||
|
||||
@classmethod
|
||||
def get_dpath_translated_path(cls, path, separator='.'):
|
||||
def get_dpath_translated_path(cls, path, separator="."):
|
||||
if cls.__cached_dpath_computed_fields is None:
|
||||
cls.__cached_dpath_computed_fields = {}
|
||||
if path not in cls.__cached_dpath_computed_fields:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import hashlib
|
||||
from inspect import ismethod, getmembers
|
||||
from typing import Sequence, Tuple, Set, Optional, Callable, Any
|
||||
from typing import Sequence, Tuple, Set, Optional, Callable, Any, Mapping
|
||||
from uuid import uuid4
|
||||
|
||||
from mongoengine import EmbeddedDocumentField, ListField, Document, Q
|
||||
@@ -203,18 +203,22 @@ def _names_set(*names: str) -> Set[str]:
|
||||
return set(names) | set(f"-{name}" for name in names)
|
||||
|
||||
|
||||
system_tag_names = {
|
||||
_system_tag_names = {
|
||||
"model": _names_set("active", "archived"),
|
||||
"project": _names_set("archived", "public", "default"),
|
||||
"task": _names_set("active", "archived", "development"),
|
||||
"queue": _names_set("default"),
|
||||
}
|
||||
|
||||
system_tag_prefixes = {"task": _names_set("annotat")}
|
||||
_system_tag_prefixes = {"task": _names_set("annotat")}
|
||||
|
||||
|
||||
def partition_tags(
|
||||
entity: str, tags: Sequence[str], system_tags: Optional[Sequence[str]] = ()
|
||||
entity: str,
|
||||
tags: Sequence[str],
|
||||
system_tags: Optional[Sequence[str]] = (),
|
||||
system_tag_names: Mapping = _system_tag_names,
|
||||
system_tag_prefixes: Mapping = _system_tag_prefixes,
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
"""
|
||||
Partition the given tags sequence into system and user-defined tags
|
||||
|
||||
@@ -5,7 +5,7 @@ Apply elasticsearch mappings to given hosts.
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
@@ -13,7 +13,7 @@ HERE = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def apply_mappings_to_cluster(
|
||||
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
|
||||
hosts: Sequence, key: Optional[str] = None, es_args: dict = None, http_auth: Tuple = None
|
||||
):
|
||||
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
|
||||
|
||||
@@ -30,7 +30,7 @@ def apply_mappings_to_cluster(
|
||||
else:
|
||||
files = p.glob("**/*.json")
|
||||
|
||||
es = Elasticsearch(hosts=hosts, **(es_args or {}))
|
||||
es = Elasticsearch(hosts=hosts, http_auth=http_auth, **(es_args or {}))
|
||||
return [_send_template(f) for f in files]
|
||||
|
||||
|
||||
|
||||
@@ -82,7 +82,11 @@ def check_elastic_empty() -> bool:
|
||||
es_logger.addFilter(log_filter)
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
|
||||
es = Elasticsearch(
|
||||
hosts=cluster_conf.get("hosts", None),
|
||||
http_auth=es_factory.get_credentials("events", cluster_conf),
|
||||
**cluster_conf.get("args", {})
|
||||
)
|
||||
return not es.indices.get_template(name="events*")
|
||||
except exceptions.NotFoundError as ex:
|
||||
log.error(ex)
|
||||
@@ -109,5 +113,7 @@ def init_es_data():
|
||||
|
||||
log.info(f"Applying mappings to ES host: {hosts_config}")
|
||||
args = cluster_conf.get("args", {})
|
||||
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
|
||||
http_auth = es_factory.get_credentials(name)
|
||||
|
||||
res = apply_mappings_to_cluster(hosts_config, name, es_args=args, http_auth=http_auth)
|
||||
log.info(res)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"index_patterns": "events-*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
@@ -34,6 +35,12 @@
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
},
|
||||
"company_id": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"model_event": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"index_patterns": "queue_metrics_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"index_patterns": "worker_stats_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from os import getenv
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch, Transport
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.config_repo import config
|
||||
|
||||
@@ -21,6 +22,10 @@ OVERRIDE_PORT_ENV_KEY = (
|
||||
"ELASTIC_SERVICE_PORT",
|
||||
)
|
||||
|
||||
OVERRIDE_USERNAME_ENV_KEY = ("CLEARML_ELASTIC_SERVICE_USERNAME",)
|
||||
|
||||
OVERRIDE_PASSWORD_ENV_KEY = ("CLEARML_ELASTIC_SERVICE_PASSWORD",)
|
||||
|
||||
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
|
||||
if OVERRIDE_HOST:
|
||||
log.info(f"Using override elastic host {OVERRIDE_HOST}")
|
||||
@@ -29,6 +34,14 @@ OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
|
||||
if OVERRIDE_PORT:
|
||||
log.info(f"Using override elastic port {OVERRIDE_PORT}")
|
||||
|
||||
OVERRIDE_USERNAME = first(filter(None, map(getenv, OVERRIDE_USERNAME_ENV_KEY)))
|
||||
if OVERRIDE_USERNAME:
|
||||
log.info(f"Using override elastic username {OVERRIDE_USERNAME}")
|
||||
|
||||
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
|
||||
if OVERRIDE_PASSWORD:
|
||||
log.info("Using override elastic password ********")
|
||||
|
||||
_instances = {}
|
||||
|
||||
|
||||
@@ -48,6 +61,10 @@ class InvalidClusterConfiguration(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MissingPasswordForElasticUser(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ESFactory:
|
||||
@classmethod
|
||||
def connect(cls, cluster_name):
|
||||
@@ -65,22 +82,45 @@ class ESFactory:
|
||||
if not hosts:
|
||||
raise InvalidClusterConfiguration(cluster_name)
|
||||
|
||||
http_auth = cls.get_credentials(cluster_name)
|
||||
|
||||
args = cluster_config.get("args", {})
|
||||
_instances[cluster_name] = Elasticsearch(
|
||||
hosts=hosts, transport_class=Transport, **args
|
||||
hosts=hosts, http_auth=http_auth, **args
|
||||
)
|
||||
|
||||
return _instances[cluster_name]
|
||||
|
||||
@classmethod
|
||||
def get_credentials(cls, cluster_name: str, cluster_config: dict = None) -> Optional[Tuple[str, str]]:
|
||||
cluster_config = cluster_config or cls.get_cluster_config(cluster_name)
|
||||
if not cluster_config.get("secure", True):
|
||||
return None
|
||||
|
||||
elastic_user = OVERRIDE_USERNAME or config.get("secure.elastic.user", None)
|
||||
if not elastic_user:
|
||||
return None
|
||||
|
||||
elastic_password = OVERRIDE_PASSWORD or config.get(
|
||||
"secure.elastic.password", None
|
||||
)
|
||||
if not elastic_password:
|
||||
raise MissingPasswordForElasticUser(
|
||||
f"cluster={cluster_name}, username={elastic_user}"
|
||||
)
|
||||
|
||||
return elastic_user, elastic_password
|
||||
|
||||
@classmethod
|
||||
def get_all_cluster_names(cls):
|
||||
return list(config.get("hosts.elastic"))
|
||||
|
||||
@classmethod
|
||||
def get_override(cls, cluster_name: str) -> Tuple[str, str]:
|
||||
def get_override_host(cls, cluster_name: str) -> Tuple[str, str]:
|
||||
return OVERRIDE_HOST, OVERRIDE_PORT
|
||||
|
||||
@classmethod
|
||||
@lru_cache()
|
||||
def get_cluster_config(cls, cluster_name):
|
||||
"""
|
||||
Returns cluster config for the specified cluster path
|
||||
@@ -97,7 +137,7 @@ class ESFactory:
|
||||
for entry in cluster_config.get("hosts", []):
|
||||
entry[key] = value
|
||||
|
||||
host, port = cls.get_override(cluster_name)
|
||||
host, port = cls.get_override_host(cluster_name)
|
||||
|
||||
if host:
|
||||
set_host_prop("host", host)
|
||||
|
||||
211
apiserver/jobs/async_urls_delete.py
Normal file
211
apiserver/jobs/async_urls_delete.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from argparse import ArgumentParser
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import requests
|
||||
from furl import furl
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database import db
|
||||
from apiserver.database.model.url_to_delete import (
|
||||
UrlToDelete,
|
||||
DeletionStatus,
|
||||
StorageType,
|
||||
)
|
||||
|
||||
log = config.logger(f"JOB-{Path(__file__).name}")
|
||||
conf = config.get("services.async_urls_delete")
|
||||
max_retries = conf.get("max_retries", 3)
|
||||
retry_timeout = timedelta(seconds=conf.get("retry_timeout_sec", 60))
|
||||
fileserver_timeout = conf.get("fileserver.timeout_sec", 300)
|
||||
UrlPrefix = Tuple[str, str]
|
||||
|
||||
|
||||
def validate_fileserver_access(fileserver_host: str) -> str:
|
||||
fileserver_host = fileserver_host or config.get("hosts.fileserver", None)
|
||||
if not fileserver_host:
|
||||
log.error(f"Fileserver host not configured")
|
||||
exit(1)
|
||||
|
||||
res = requests.get(url=fileserver_host)
|
||||
res.raise_for_status()
|
||||
|
||||
return fileserver_host
|
||||
|
||||
|
||||
def mark_retry_failed(ids: Sequence[str], reason: str):
|
||||
UrlToDelete.objects(id__in=ids).update(
|
||||
last_failure_time=datetime.utcnow(),
|
||||
last_failure_reason=reason,
|
||||
inc__retry_count=1,
|
||||
)
|
||||
UrlToDelete.objects(id__in=ids, retry_count__gte=max_retries).update(
|
||||
status=DeletionStatus.failed
|
||||
)
|
||||
|
||||
|
||||
def mark_failed(query: Q, reason: str):
|
||||
UrlToDelete.objects(query).update(
|
||||
status=DeletionStatus.failed,
|
||||
last_failure_time=datetime.utcnow(),
|
||||
last_failure_reason=reason,
|
||||
)
|
||||
|
||||
|
||||
def delete_fileserver_urls(
|
||||
urls_query: Q, fileserver_host: str, url_prefixes: Sequence[UrlPrefix]
|
||||
):
|
||||
to_delete = list(UrlToDelete.objects(urls_query).limit(10000))
|
||||
if not to_delete:
|
||||
return
|
||||
|
||||
def resolve_path(url_: UrlToDelete) -> str:
|
||||
parsed = furl(url_.url)
|
||||
url_host = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme else None
|
||||
url_path = str(parsed.path)
|
||||
|
||||
for host, path_prefix in url_prefixes:
|
||||
if host and url_host != host:
|
||||
continue
|
||||
if path_prefix and not url_path.startswith(path_prefix + "/"):
|
||||
continue
|
||||
return url_path[len(path_prefix or ""):]
|
||||
|
||||
raise ValueError("could not map path")
|
||||
|
||||
paths = set()
|
||||
path_to_id_mapping = defaultdict(list)
|
||||
for url in to_delete:
|
||||
try:
|
||||
path = resolve_path(url)
|
||||
path = path.strip("/")
|
||||
if not path:
|
||||
raise ValueError("Empty path")
|
||||
except Exception as ex:
|
||||
err = str(ex)
|
||||
log.warn(f"Error getting path for {url.url}: {err}")
|
||||
mark_failed(Q(id=url.id), err)
|
||||
continue
|
||||
|
||||
paths.add(path)
|
||||
path_to_id_mapping[path].append(url.id)
|
||||
|
||||
if not paths:
|
||||
return
|
||||
|
||||
ids_to_delete = set(chain.from_iterable(path_to_id_mapping.values()))
|
||||
try:
|
||||
res = requests.post(
|
||||
url=furl(fileserver_host).add(path="delete_many").url,
|
||||
json={"files": list(paths)},
|
||||
timeout=fileserver_timeout,
|
||||
)
|
||||
res.raise_for_status()
|
||||
except Exception as ex:
|
||||
err = str(ex)
|
||||
log.warn(f"Error deleting {len(paths)} files from fileserver: {err}")
|
||||
mark_retry_failed(list(ids_to_delete), err)
|
||||
return
|
||||
|
||||
res_data = res.json()
|
||||
deleted_ids = set(
|
||||
chain.from_iterable(
|
||||
path_to_id_mapping.get(path, [])
|
||||
for path in list(res_data.get("deleted", {}))
|
||||
)
|
||||
)
|
||||
if deleted_ids:
|
||||
UrlToDelete.objects(id__in=list(deleted_ids)).delete()
|
||||
log.info(f"{len(deleted_ids)} files deleted from the fileserver")
|
||||
|
||||
failed_ids = set()
|
||||
for err, error_ids in res_data.get("errors", {}).items():
|
||||
error_ids = list(
|
||||
chain.from_iterable(path_to_id_mapping.get(path, []) for path in error_ids)
|
||||
)
|
||||
mark_retry_failed(error_ids, err)
|
||||
log.warning(
|
||||
f"Failed to delete {len(error_ids)} files from the fileserver due to: {err}"
|
||||
)
|
||||
failed_ids.update(error_ids)
|
||||
|
||||
missing_ids = ids_to_delete - deleted_ids - failed_ids
|
||||
if missing_ids:
|
||||
mark_retry_failed(list(missing_ids), "Not succeeded")
|
||||
|
||||
|
||||
def _get_fileserver_url_prefixes(fileserver_host: str) -> Sequence[UrlPrefix]:
|
||||
def _parse_url_prefix(prefix) -> UrlPrefix:
|
||||
url = furl(prefix)
|
||||
host = f"{url.scheme}://{url.netloc}" if url.scheme else None
|
||||
return host, str(url.path).rstrip("/")
|
||||
|
||||
url_prefixes = [
|
||||
_parse_url_prefix(p) for p in conf.get("fileserver.url_prefixes", [])
|
||||
]
|
||||
if not any(fileserver_host == host for host, _ in url_prefixes):
|
||||
url_prefixes.append((fileserver_host, ""))
|
||||
|
||||
return url_prefixes
|
||||
|
||||
|
||||
def run_delete_loop(fileserver_host: str):
|
||||
fileserver_host = validate_fileserver_access(fileserver_host)
|
||||
|
||||
storage_delete_funcs = {
|
||||
StorageType.fileserver: partial(
|
||||
delete_fileserver_urls,
|
||||
fileserver_host=fileserver_host,
|
||||
url_prefixes=_get_fileserver_url_prefixes(fileserver_host),
|
||||
),
|
||||
}
|
||||
while True:
|
||||
now = datetime.utcnow()
|
||||
urls_query = (
|
||||
Q(status__ne=DeletionStatus.failed)
|
||||
& Q(retry_count__lt=max_retries)
|
||||
& (
|
||||
Q(last_failure_time__exists=False)
|
||||
| Q(last_failure_time__lt=now - retry_timeout)
|
||||
)
|
||||
)
|
||||
|
||||
url_to_delete: UrlToDelete = UrlToDelete.objects(
|
||||
urls_query & Q(storage_type__in=list(storage_delete_funcs))
|
||||
).order_by("retry_count").limit(1).first()
|
||||
if not url_to_delete:
|
||||
sleep(10)
|
||||
continue
|
||||
|
||||
company = url_to_delete.company
|
||||
user = url_to_delete.user
|
||||
storage_type = url_to_delete.storage_type
|
||||
log.info(
|
||||
f"Deleting {storage_type} objects for company: {company}, user: {user}"
|
||||
)
|
||||
company_storage_urls_query = urls_query & Q(
|
||||
company=company, storage_type=storage_type,
|
||||
)
|
||||
storage_delete_funcs[storage_type](urls_query=company_storage_urls_query)
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser(description=__doc__)
|
||||
|
||||
parser.add_argument(
|
||||
"--fileserver-host", "-fh", help="Fileserver host address", type=str,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
db.initialize()
|
||||
run_delete_loop(args.fileserver_host)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,8 +1,10 @@
|
||||
import importlib.util
|
||||
from datetime import datetime
|
||||
from inspect import signature
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
import pymongo.database
|
||||
from mongoengine.connection import get_db
|
||||
from packaging.version import Version, parse
|
||||
|
||||
@@ -80,8 +82,15 @@ def _apply_migrations(log: Logger):
|
||||
if not func:
|
||||
continue
|
||||
try:
|
||||
sig = signature(func)
|
||||
kwargs = {}
|
||||
if len(sig.parameters) == 2:
|
||||
name, param = list(sig.parameters.items())[-1]
|
||||
key = name.replace("_", "-")
|
||||
if issubclass(param.annotation, pymongo.database.Database) and key in dbs:
|
||||
kwargs[name] = get_db(key)
|
||||
log.info(f"Applying {script.stem}/{func_name}()")
|
||||
func(get_db(alias))
|
||||
func(get_db(alias), **kwargs)
|
||||
except Exception:
|
||||
log.exception(f"Failed applying {script}:{func_name}()")
|
||||
raise ValueError(
|
||||
|
||||
@@ -21,8 +21,10 @@ from typing import (
|
||||
Union,
|
||||
Mapping,
|
||||
IO,
|
||||
Callable,
|
||||
)
|
||||
from urllib.parse import unquote, urlparse
|
||||
from uuid import uuid4, UUID, uuid5
|
||||
from zipfile import ZipFile, ZIP_BZIP2
|
||||
|
||||
import mongoengine
|
||||
@@ -54,6 +56,7 @@ from apiserver.database.model.task.task import (
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
|
||||
class PrePopulate:
|
||||
@@ -69,6 +72,8 @@ class PrePopulate:
|
||||
r"['\"]source['\"]:\s?['\"](https?://(?:localhost:8081|files.*?)/.*?)['\"]",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
_name_guid_ns = UUID("bda3acc1-e612-506c-bade-80071b6cf039")
|
||||
_example_id_prefix = "e-"
|
||||
task_cls: Type[Task]
|
||||
project_cls: Type[Project]
|
||||
model_cls: Type[Model]
|
||||
@@ -298,8 +303,9 @@ class PrePopulate:
|
||||
if company_id is None:
|
||||
company_id = ""
|
||||
|
||||
# Always use a public user for pre-populated data
|
||||
cls.user_cls(id=user_id, name=user_name, company="").save()
|
||||
existing_user = cls.user_cls.objects(id=user_id).only("id").first()
|
||||
if not existing_user:
|
||||
cls.user_cls(id=user_id, name=user_name, company=company_id).save()
|
||||
|
||||
cls._import(zfile, company_id, user_id, metadata)
|
||||
|
||||
@@ -687,6 +693,58 @@ class PrePopulate:
|
||||
continue
|
||||
yield clean
|
||||
|
||||
@staticmethod
|
||||
def _new_id(_):
|
||||
return str(uuid4()).replace("-", "")
|
||||
|
||||
@classmethod
|
||||
def _hash_id(cls, name: str):
|
||||
return str(uuid5(cls._name_guid_ns, name)).replace("-", "")
|
||||
|
||||
@classmethod
|
||||
def _example_id(cls, orig_id: str):
|
||||
if not orig_id or orig_id.startswith(cls._example_id_prefix):
|
||||
return orig_id
|
||||
|
||||
return cls._example_id_prefix + orig_id
|
||||
|
||||
@classmethod
|
||||
def _private_id(cls, orig_id: str):
|
||||
if not orig_id or not orig_id.startswith(cls._example_id_prefix):
|
||||
return orig_id
|
||||
|
||||
return orig_id[len(cls._example_id_prefix) :]
|
||||
|
||||
@classmethod
|
||||
def _generate_new_ids(
|
||||
cls, reader: ZipFile, entity_files: Sequence, metadata: Mapping[str, Any],
|
||||
) -> Mapping[str, str]:
|
||||
if not metadata or not any(
|
||||
metadata.get(key) for key in ("new_ids", "example_ids", "private_ids")
|
||||
):
|
||||
return {}
|
||||
|
||||
ids = {}
|
||||
for entity_file in entity_files:
|
||||
with reader.open(entity_file) as f:
|
||||
is_project = splitext(entity_file.orig_filename)[0].endswith(".Project")
|
||||
if metadata.get("new_ids"):
|
||||
id_func = cls._new_id
|
||||
elif metadata.get("example_ids"):
|
||||
id_func = cls._example_id if not is_project else cls._hash_id
|
||||
elif metadata.get("private_ids"):
|
||||
id_func = cls._private_id if not is_project else cls._new_id
|
||||
for item in cls.json_lines(f):
|
||||
doc = json.loads(item)
|
||||
orig_id = doc.get("_id")
|
||||
if orig_id:
|
||||
ids[orig_id] = (
|
||||
id_func(orig_id)
|
||||
if id_func != cls._hash_id
|
||||
else id_func(doc.get("name"))
|
||||
)
|
||||
return ids
|
||||
|
||||
@classmethod
|
||||
def _import(
|
||||
cls,
|
||||
@@ -701,37 +759,42 @@ class PrePopulate:
|
||||
Start from entities since event import will require the tasks already in DB
|
||||
"""
|
||||
event_file_ending = cls.events_file_suffix + ".json"
|
||||
entity_files = (
|
||||
entity_files = [
|
||||
fi
|
||||
for fi in reader.filelist
|
||||
if not fi.orig_filename.endswith(event_file_ending)
|
||||
and fi.orig_filename != cls.metadata_filename
|
||||
)
|
||||
]
|
||||
metadata = metadata or {}
|
||||
old_to_new_ids = cls._generate_new_ids(reader, entity_files, metadata)
|
||||
tasks = []
|
||||
for entity_file in entity_files:
|
||||
with reader.open(entity_file) as f:
|
||||
full_name = splitext(entity_file.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
res = cls._import_entity(f, full_name, company_id, user_id, metadata)
|
||||
res = cls._import_entity(
|
||||
f, full_name, company_id, user_id, metadata, old_to_new_ids
|
||||
)
|
||||
if res:
|
||||
tasks = res
|
||||
|
||||
if sort_tasks_by_last_updated:
|
||||
tasks = sorted(tasks, key=attrgetter("last_update"))
|
||||
|
||||
new_to_old_ids = {v: k for k, v in old_to_new_ids.items()}
|
||||
for task in tasks:
|
||||
old_task_id = new_to_old_ids.get(task.id, task.id)
|
||||
events_file = first(
|
||||
fi
|
||||
for fi in reader.filelist
|
||||
if fi.orig_filename.endswith(task.id + event_file_ending)
|
||||
if fi.orig_filename.endswith(old_task_id + event_file_ending)
|
||||
)
|
||||
if not events_file:
|
||||
continue
|
||||
with reader.open(events_file) as f:
|
||||
full_name = splitext(events_file.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
cls._import_events(f, full_name, company_id, user_id)
|
||||
cls._import_events(f, company_id, user_id, task.id)
|
||||
|
||||
@classmethod
|
||||
def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]:
|
||||
@@ -743,6 +806,28 @@ class PrePopulate:
|
||||
module = importlib.import_module(module_name)
|
||||
return getattr(module, class_name)
|
||||
|
||||
@staticmethod
|
||||
def _upgrade_project_data(project_data: dict) -> dict:
|
||||
if not project_data.get("basename"):
|
||||
name: str = project_data["name"]
|
||||
_, _, basename = name.rpartition("/")
|
||||
project_data["basename"] = basename
|
||||
|
||||
return project_data
|
||||
|
||||
@staticmethod
|
||||
def _upgrade_model_data(model_data: dict) -> dict:
|
||||
metadata_key = "metadata"
|
||||
metadata = model_data.get(metadata_key)
|
||||
if isinstance(metadata, list):
|
||||
metadata = {
|
||||
ParameterKeyEscaper.escape(item["key"]): item
|
||||
for item in metadata
|
||||
if isinstance(item, dict) and "key" in item
|
||||
}
|
||||
model_data[metadata_key] = metadata
|
||||
return model_data
|
||||
|
||||
@staticmethod
|
||||
def _upgrade_task_data(task_data: dict) -> dict:
|
||||
"""
|
||||
@@ -822,14 +907,26 @@ class PrePopulate:
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
metadata: Mapping[str, Any],
|
||||
old_to_new_ids: Mapping[str, str] = None,
|
||||
) -> Optional[Sequence[Task]]:
|
||||
cls_ = cls._get_entity_type(full_name)
|
||||
print(f"Writing {cls_.__name__.lower()}s into database")
|
||||
tasks = []
|
||||
override_project_count = 0
|
||||
data_upgrade_funcs: Mapping[Type, Callable] = {
|
||||
cls.task_cls: cls._upgrade_task_data,
|
||||
cls.model_cls: cls._upgrade_model_data,
|
||||
cls.project_cls: cls._upgrade_project_data,
|
||||
}
|
||||
for item in cls.json_lines(f):
|
||||
if cls_ == cls.task_cls:
|
||||
item = json.dumps(cls._upgrade_task_data(task_data=json.loads(item)))
|
||||
if old_to_new_ids:
|
||||
for old_id, new_id in old_to_new_ids.items():
|
||||
# replace ids only when they are standalone strings
|
||||
# otherwise artifacts uris that contain old ids may get damaged
|
||||
item = item.replace(f'"{old_id}"', f'"{new_id}"')
|
||||
upgrade_func = data_upgrade_funcs.get(cls_)
|
||||
if upgrade_func:
|
||||
item = json.dumps(upgrade_func(json.loads(item)))
|
||||
|
||||
doc = cls_.from_json(item, created=True)
|
||||
if hasattr(doc, "user"):
|
||||
@@ -863,11 +960,13 @@ class PrePopulate:
|
||||
return tasks
|
||||
|
||||
@classmethod
|
||||
def _import_events(cls, f: IO[bytes], full_name: str, company_id: str, _):
|
||||
_, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
|
||||
def _import_events(cls, f: IO[bytes], company_id: str, _, task_id: str):
|
||||
print(f"Writing events for task {task_id} into database")
|
||||
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
|
||||
events = [json.loads(item) for item in events_chunk]
|
||||
for ev in events:
|
||||
ev["task"] = task_id
|
||||
ev["company_id"] = company_id
|
||||
cls.event_bll.add_events(
|
||||
company_id, events=events, worker="", allow_locked_tasks=True
|
||||
company_id, events=events, worker="", allow_locked=True
|
||||
)
|
||||
|
||||
@@ -53,6 +53,7 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
|
||||
name=user_name,
|
||||
given_name=given_name,
|
||||
family_name=family_name,
|
||||
created=datetime.utcnow(),
|
||||
).save()
|
||||
|
||||
return user_id
|
||||
|
||||
@@ -97,22 +97,6 @@ def _migrate_model_labels(db: Database):
|
||||
tasks.update_one({"_id": doc["_id"]}, {"$set": set_commands})
|
||||
|
||||
|
||||
def _migrate_project_description(db: Database):
|
||||
projects: Collection = db["project"]
|
||||
filter = {
|
||||
"$or": [
|
||||
{
|
||||
"$expr": {"$lt": [{"$strLenCP": "$description"}, 100]},
|
||||
"description": {"$regex": "^Auto-generated at ", "$options": "i"},
|
||||
},
|
||||
{"description": {"$regex": "^Auto-generated during move$", "$options": "i"}},
|
||||
{"description": {"$regex": "^Auto-generated while cloning$", "$options": "i"}},
|
||||
]
|
||||
}
|
||||
for doc in projects.find(filter=filter):
|
||||
projects.update_one({"_id": doc["_id"]}, {"$unset": {"description": 1}})
|
||||
|
||||
|
||||
def _migrate_project_names(db: Database):
|
||||
projects: Collection = db["project"]
|
||||
|
||||
@@ -141,5 +125,4 @@ def migrate_backend(db: Database):
|
||||
_migrate_docker_cmd(db)
|
||||
_migrate_model_labels(db)
|
||||
_migrate_project_names(db)
|
||||
_migrate_project_description(db)
|
||||
_drop_all_indices_from_collections(db, ["task*"])
|
||||
|
||||
22
apiserver/mongo/migrations/1_0_2.py
Normal file
22
apiserver/mongo/migrations/1_0_2.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
|
||||
|
||||
def _migrate_project_description(db: Database):
|
||||
projects: Collection = db["project"]
|
||||
filter = {
|
||||
"$or": [
|
||||
{
|
||||
"$expr": {"$lt": [{"$strLenCP": "$description"}, 100]},
|
||||
"description": {"$regex": "^Auto-generated at ", "$options": "i"},
|
||||
},
|
||||
{"description": {"$regex": "^Auto-generated during move$", "$options": "i"}},
|
||||
{"description": {"$regex": "^Auto-generated while cloning$", "$options": "i"}},
|
||||
]
|
||||
}
|
||||
for doc in projects.find(filter=filter):
|
||||
projects.update_one({"_id": doc["_id"]}, {"$unset": {"description": 1}})
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
_migrate_project_description(db)
|
||||
29
apiserver/mongo/migrations/1_3_0.py
Normal file
29
apiserver/mongo/migrations/1_3_0.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .utils import _drop_all_indices_from_collections
|
||||
|
||||
|
||||
def _convert_metadata(db: Database, name):
|
||||
collection: Collection = db[name]
|
||||
|
||||
metadata_field = "metadata"
|
||||
query = {metadata_field: {"$exists": True, "$type": 4}}
|
||||
for doc in collection.find(filter=query, projection=(metadata_field,)):
|
||||
metadata = {
|
||||
ParameterKeyEscaper.escape(item["key"]): item
|
||||
for item in doc.get(metadata_field, [])
|
||||
if isinstance(item, dict) and "key" in item
|
||||
}
|
||||
collection.update_one(
|
||||
{"_id": doc["_id"]}, {"$set": {"metadata": metadata}},
|
||||
)
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
collections = ["model", "queue"]
|
||||
for name in collections:
|
||||
_convert_metadata(db, name)
|
||||
|
||||
_drop_all_indices_from_collections(db, collections)
|
||||
12
apiserver/mongo/migrations/1_6_0.py
Normal file
12
apiserver/mongo/migrations/1_6_0.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
projects: Collection = db["project"]
|
||||
for doc in projects.find({"basename": None}):
|
||||
name: str = doc["name"]
|
||||
_, _, basename = name.rpartition("/")
|
||||
projects.update_one(
|
||||
{"_id": doc["_id"]}, {"$set": {"basename": basename}},
|
||||
)
|
||||
15
apiserver/mongo/migrations/1_7_0.py
Normal file
15
apiserver/mongo/migrations/1_7_0.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
|
||||
|
||||
def migrate_backend(db: Database, auth_db: Database):
|
||||
users: Collection = db["user"]
|
||||
auth_users: Collection = auth_db["user"]
|
||||
created_field = "created"
|
||||
for doc in users.find({created_field: {"$exists": False}}):
|
||||
auth_user = auth_users.find_one({"_id": doc["_id"]}, projection=[created_field])
|
||||
if not auth_user or created_field not in auth_user:
|
||||
continue
|
||||
users.update_one(
|
||||
{"_id": doc["_id"]}, {"$set": {created_field: auth_user[created_field]}}
|
||||
)
|
||||
@@ -1,10 +1,8 @@
|
||||
import threading
|
||||
from os import getenv
|
||||
from time import sleep
|
||||
|
||||
from boltons.iterutils import first
|
||||
from redis import StrictRedis
|
||||
from redis.sentinel import Sentinel, SentinelConnectionPool
|
||||
from rediscluster import RedisCluster
|
||||
|
||||
from apiserver.apierrors.errors.server_error import ConfigError, GeneralError
|
||||
from apiserver.config_repo import config
|
||||
@@ -21,6 +19,11 @@ OVERRIDE_PORT_ENV_KEY = (
|
||||
"TRAINS_REDIS_SERVICE_PORT",
|
||||
"REDIS_SERVICE_PORT",
|
||||
)
|
||||
OVERRIDE_PASSWORD_ENV_KEY = (
|
||||
"CLEARML_REDIS_SERVICE_PASSWORD",
|
||||
"TRAINS_REDIS_SERVICE_PASSWORD",
|
||||
"REDIS_SERVICE_PASSWORD",
|
||||
)
|
||||
|
||||
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
|
||||
if OVERRIDE_HOST:
|
||||
@@ -30,99 +33,7 @@ OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
|
||||
if OVERRIDE_PORT:
|
||||
log.info(f"Using override redis port {OVERRIDE_PORT}")
|
||||
|
||||
|
||||
class MyPubSubWorkerThread(threading.Thread):
|
||||
def __init__(self, sentinel, on_new_master, msg_sleep_time, daemon=True):
|
||||
super(MyPubSubWorkerThread, self).__init__()
|
||||
self.daemon = daemon
|
||||
self.sentinel = sentinel
|
||||
self.on_new_master = on_new_master
|
||||
self.sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
|
||||
self.msg_sleep_time = msg_sleep_time
|
||||
self._running = False
|
||||
self.pubsub = None
|
||||
|
||||
def subscribe(self):
|
||||
if self.pubsub:
|
||||
try:
|
||||
self.pubsub.unsubscribe()
|
||||
self.pubsub.punsubscribe()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self.pubsub = None
|
||||
|
||||
subscriptions = {"+switch-master": self.on_new_master}
|
||||
|
||||
while not self.pubsub or not self.pubsub.subscribed:
|
||||
try:
|
||||
self.pubsub = self.sentinel.pubsub()
|
||||
self.pubsub.subscribe(**subscriptions)
|
||||
except Exception as ex:
|
||||
log.warn(
|
||||
f"Error while subscribing to sentinel at {self.sentinel_host} ({ex.args[0]}) Sleeping and retrying"
|
||||
)
|
||||
sleep(3)
|
||||
log.info(f"Subscribed to sentinel {self.sentinel_host}")
|
||||
|
||||
def run(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
|
||||
self.subscribe()
|
||||
|
||||
while self.pubsub.subscribed:
|
||||
try:
|
||||
self.pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=self.msg_sleep_time
|
||||
)
|
||||
except Exception as ex:
|
||||
log.warn(
|
||||
f"Error while getting message from sentinel {self.sentinel_host} ({ex.args[0]}) Resubscribing"
|
||||
)
|
||||
self.subscribe()
|
||||
|
||||
self.pubsub.close()
|
||||
self._running = False
|
||||
|
||||
def stop(self):
|
||||
# stopping simply unsubscribes from all channels and patterns.
|
||||
# the unsubscribe responses that are generated will short circuit
|
||||
# the loop in run(), calling pubsub.close() to clean up the connection
|
||||
self.pubsub.unsubscribe()
|
||||
self.pubsub.punsubscribe()
|
||||
|
||||
|
||||
# todo,future - multi master clusters?
|
||||
class RedisCluster(object):
|
||||
def __init__(self, sentinel_hosts, service_name, **connection_kwargs):
|
||||
self.service_name = service_name
|
||||
self.sentinel = Sentinel(sentinel_hosts, **connection_kwargs)
|
||||
self.master = None
|
||||
self.master_host_port = None
|
||||
self.reconfigure()
|
||||
self.sentinel_threads = {}
|
||||
self.listen()
|
||||
|
||||
def reconfigure(self):
|
||||
try:
|
||||
self.master_host_port = self.sentinel.discover_master(self.service_name)
|
||||
self.master = self.sentinel.master_for(self.service_name)
|
||||
log.info(f"Reconfigured master to {self.master_host_port}")
|
||||
except Exception as ex:
|
||||
log.error(f"Error while reconfiguring. {ex.args[0]}")
|
||||
|
||||
def listen(self):
|
||||
def on_new_master(workerThread):
|
||||
self.reconfigure()
|
||||
|
||||
for sentinel in self.sentinel.sentinels:
|
||||
sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
|
||||
self.sentinel_threads[sentinel_host] = MyPubSubWorkerThread(
|
||||
sentinel, on_new_master, msg_sleep_time=0.001, daemon=True
|
||||
)
|
||||
self.sentinel_threads[sentinel_host].start()
|
||||
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
|
||||
|
||||
|
||||
class RedisManager(object):
|
||||
@@ -131,6 +42,9 @@ class RedisManager(object):
|
||||
for alias, alias_config in redis_config_dict.items():
|
||||
|
||||
alias_config = alias_config.as_plain_ordered_dict()
|
||||
alias_config["password"] = config.get(
|
||||
f"secure.redis.{alias}.password", None
|
||||
)
|
||||
|
||||
is_cluster = alias_config.get("cluster", False)
|
||||
|
||||
@@ -142,34 +56,19 @@ class RedisManager(object):
|
||||
if port:
|
||||
alias_config["port"] = port
|
||||
|
||||
db = alias_config.get("db", 0)
|
||||
password = OVERRIDE_PASSWORD or alias_config.get("password", None)
|
||||
if password:
|
||||
alias_config["password"] = password
|
||||
|
||||
sentinels = alias_config.get("sentinels", None)
|
||||
service_name = alias_config.get("service_name", None)
|
||||
|
||||
if not is_cluster and sentinels:
|
||||
raise ConfigError(
|
||||
"Redis configuration is invalid. mixed regular and cluster mode",
|
||||
alias=alias,
|
||||
)
|
||||
if is_cluster and (not sentinels or not service_name):
|
||||
raise ConfigError(
|
||||
"Redis configuration is invalid. missing sentinels or service_name",
|
||||
alias=alias,
|
||||
)
|
||||
if not is_cluster and (not port or not host):
|
||||
if not port or not host:
|
||||
raise ConfigError(
|
||||
"Redis configuration is invalid. missing port or host", alias=alias
|
||||
)
|
||||
|
||||
if is_cluster:
|
||||
# todo support all redis connection args via sentinel's connection_kwargs
|
||||
del alias_config["sentinels"]
|
||||
del alias_config["cluster"]
|
||||
del alias_config["service_name"]
|
||||
self.aliases[alias] = RedisCluster(
|
||||
sentinels, service_name, **alias_config
|
||||
)
|
||||
del alias_config["db"]
|
||||
self.aliases[alias] = RedisCluster(**alias_config)
|
||||
else:
|
||||
self.aliases[alias] = StrictRedis(**alias_config)
|
||||
|
||||
@@ -177,27 +76,21 @@ class RedisManager(object):
|
||||
obj = self.aliases.get(alias)
|
||||
if not obj:
|
||||
raise GeneralError(f"Invalid Redis alias {alias}")
|
||||
if isinstance(obj, RedisCluster):
|
||||
obj.master.get("health")
|
||||
return obj.master
|
||||
else:
|
||||
obj.get("health")
|
||||
return obj
|
||||
|
||||
obj.get("health")
|
||||
return obj
|
||||
|
||||
def host(self, alias):
|
||||
r = self.connection(alias)
|
||||
pool = r.connection_pool
|
||||
if isinstance(pool, SentinelConnectionPool):
|
||||
connections = pool.connection_kwargs[
|
||||
"connection_pool"
|
||||
]._available_connections
|
||||
if isinstance(r, RedisCluster):
|
||||
connections = first(r.connection_pool._available_connections.values())
|
||||
else:
|
||||
connections = pool._available_connections
|
||||
connections = r.connection_pool._available_connections
|
||||
|
||||
if len(connections) > 0:
|
||||
return connections[0].host
|
||||
else:
|
||||
if not connections:
|
||||
return None
|
||||
|
||||
return connections[0].host
|
||||
|
||||
|
||||
redman = RedisManager(config.get("hosts.redis"))
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
attrs>=19.1.0
|
||||
attrs>=22.1.0
|
||||
bcrypt>=3.1.4
|
||||
boltons>=19.1.0
|
||||
boto3==1.14.13
|
||||
dpath>=1.4.2,<2.0
|
||||
elasticsearch>=7.0.0,<8.0.0
|
||||
elasticsearch==7.13.3
|
||||
fastjsonschema>=2.8
|
||||
flask-compress>=1.4.0
|
||||
flask-cors>=3.0.5
|
||||
@@ -16,16 +16,16 @@ jinja2==2.11.3
|
||||
jsonmodels>=2.3
|
||||
jsonschema>=2.6.0
|
||||
luqum>=0.10.0
|
||||
mongoengine==0.19.1
|
||||
mongoengine==0.23.1
|
||||
nested_dict>=1.61
|
||||
packaging==20.3
|
||||
psutil>=5.6.5
|
||||
pyhocon>=0.3.35
|
||||
pyjwt<2.0.0
|
||||
pymongo==3.10.1
|
||||
pyjwt>=2.4.0
|
||||
pymongo[srv]==3.12.0
|
||||
python-rapidjson>=0.6.3
|
||||
redis>=2.10.5
|
||||
related>=0.7.2
|
||||
redis==3.5.3
|
||||
redis-py-cluster>=2.1.3
|
||||
requests>=2.13.0
|
||||
semantic_version>=2.8.3,<3
|
||||
six
|
||||
|
||||
@@ -26,6 +26,10 @@ credentials {
|
||||
type: string
|
||||
description: Credentials secret key
|
||||
}
|
||||
label {
|
||||
type: string
|
||||
description: Optional credentials label
|
||||
}
|
||||
}
|
||||
}
|
||||
batch_operation {
|
||||
|
||||
@@ -15,11 +15,19 @@ _definitions {
|
||||
type: string
|
||||
description: ""
|
||||
}
|
||||
label {
|
||||
type: string
|
||||
description: Optional credentials label
|
||||
}
|
||||
last_used {
|
||||
type: string
|
||||
description: ""
|
||||
format: "date-time"
|
||||
}
|
||||
last_used_from {
|
||||
type: string
|
||||
description: ""
|
||||
}
|
||||
}
|
||||
}
|
||||
role {
|
||||
@@ -222,6 +230,12 @@ create_credentials {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.17": ${create_credentials."2.1"} {
|
||||
request.properties.label {
|
||||
type: string
|
||||
description: Optional credentials label
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_credentials {
|
||||
@@ -248,6 +262,38 @@ get_credentials {
|
||||
}
|
||||
}
|
||||
|
||||
edit_credentials {
|
||||
allow_roles = [ "*" ]
|
||||
internal: false
|
||||
"2.19" {
|
||||
description: """Updates the label of the existing credentials for the authenticated user."""
|
||||
request {
|
||||
type: object
|
||||
required: [ access_key ]
|
||||
properties {
|
||||
access_key {
|
||||
type: string
|
||||
description: Existing credentials key
|
||||
}
|
||||
label {
|
||||
type: string
|
||||
description: New credentials label
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of credentials updated"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
revoke_credentials {
|
||||
allow_roles = [ "*" ]
|
||||
internal: false
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -61,14 +61,14 @@ _definitions {
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Should be identical to the framework of the task which created the model"
|
||||
@@ -98,9 +98,21 @@ _definitions {
|
||||
additionalProperties: true
|
||||
}
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
stats {
|
||||
description: "Model statistics"
|
||||
type: object
|
||||
properties {
|
||||
labels_count {
|
||||
description: Number of the model labels
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -199,6 +211,36 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.13"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of models to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
"2.20": ${get_all_ex."2.15"} {
|
||||
request.properties.include_stats {
|
||||
description: "If true, include models statistic in response"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
@@ -278,6 +320,14 @@ get_all {
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
last_update {
|
||||
description: "List of last_update constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
|
||||
type: array
|
||||
items {
|
||||
type: string
|
||||
pattern: "^(>=|>|<=|<)?.*$"
|
||||
}
|
||||
}
|
||||
_all_ {
|
||||
description: "Multi-field pattern condition (all fields match pattern)"
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
@@ -302,6 +352,29 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of models to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_frameworks {
|
||||
"2.8" {
|
||||
@@ -361,7 +434,7 @@ update_for_task {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
override_model_id {
|
||||
description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required."
|
||||
@@ -427,7 +500,7 @@ create {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
||||
@@ -483,9 +556,11 @@ create {
|
||||
}
|
||||
"2.13": ${create."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -522,7 +597,7 @@ edit {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
||||
@@ -578,9 +653,11 @@ edit {
|
||||
}
|
||||
"2.13": ${edit."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -611,7 +688,7 @@ update {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
ready {
|
||||
description: "Indication if the model is final and can be used by other tasks Default is false."
|
||||
@@ -661,9 +738,11 @@ update {
|
||||
}
|
||||
"2.13": ${update."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -672,7 +751,7 @@ publish_many {
|
||||
description: Publish models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to publish"
|
||||
ids.description: "IDs of the models to publish"
|
||||
force_publish_task {
|
||||
description: "Publish the associated tasks (if exist) even if they are not in the 'stopped' state. Optional, the default value is False."
|
||||
type: boolean
|
||||
@@ -733,7 +812,7 @@ archive_many {
|
||||
description: Archive models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to archive"
|
||||
ids.description: "IDs of the models to archive"
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -769,10 +848,9 @@ delete_many {
|
||||
description: Delete models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to delete"
|
||||
ids.description: "IDs of the models to delete"
|
||||
force {
|
||||
description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published.
|
||||
"""
|
||||
description: "Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published."
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
@@ -929,6 +1007,11 @@ add_or_update_metadata {
|
||||
description: "Metadata items to add or update"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
}
|
||||
replace_metadata {
|
||||
description: "If set then the all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
||||
@@ -102,4 +102,78 @@ get_user_companies {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_entities_count {
|
||||
"2.20": {
|
||||
description: "Get counts for the company entities according to the passed search criteria"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
projects {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
description: Search criteria for projects
|
||||
}
|
||||
tasks {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
description: Search criteria for experiments
|
||||
}
|
||||
models {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
description: Search criteria for models
|
||||
}
|
||||
pipelines {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
description: Search criteria for pipelines
|
||||
}
|
||||
datasets {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
description: Search criteria for datasets
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
projects {
|
||||
type: integer
|
||||
description: The number of projects matching the criteria
|
||||
}
|
||||
tasks {
|
||||
type: integer
|
||||
description: The number of experiments matching the criteria
|
||||
}
|
||||
models {
|
||||
type: integer
|
||||
description: The number of models matching the criteria
|
||||
}
|
||||
pipelines {
|
||||
type: integer
|
||||
description: The number of pipelines matching the criteria
|
||||
}
|
||||
datasets {
|
||||
type: integer
|
||||
description: The number of datasets matching the criteria
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.22": ${get_entities_count."2.20"} {
|
||||
request.properties {
|
||||
search_hidden {
|
||||
description: "If set to 'true' then hidden projects and tasks are included in the search results"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
active_users {
|
||||
descritpion: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
|
||||
type: array
|
||||
items: {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
47
apiserver/schema/services/pipelines.conf
Normal file
47
apiserver/schema/services/pipelines.conf
Normal file
@@ -0,0 +1,47 @@
|
||||
_description: "Provides a management API for pipelines in the system."
|
||||
_definitions {
|
||||
}
|
||||
|
||||
start_pipeline {
|
||||
"2.17" {
|
||||
description: "Start a pipeline"
|
||||
request {
|
||||
type: object
|
||||
required: [ task ]
|
||||
properties {
|
||||
task {
|
||||
description: "ID of the task on which the pipeline will be based"
|
||||
type: string
|
||||
}
|
||||
queue {
|
||||
description: "Queue ID in which the created pipeline task will be enqueued"
|
||||
type: string
|
||||
}
|
||||
args {
|
||||
description: "Task arguments, name/value to be placed in the hyperparameters Args section"
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
name: { type: string }
|
||||
value: { type: [string, null] }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
pipeline {
|
||||
description: "ID of the new pipeline task"
|
||||
type: string
|
||||
}
|
||||
enqueued {
|
||||
description: "True if the task was successfuly enqueued"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -25,6 +25,10 @@ _definitions {
|
||||
description: "Project name"
|
||||
type: string
|
||||
}
|
||||
basename {
|
||||
description: "Project base name"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description"
|
||||
type: string
|
||||
@@ -43,14 +47,14 @@ _definitions {
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@@ -70,6 +74,18 @@ _definitions {
|
||||
description: "Total run time of all tasks in project (in seconds)"
|
||||
type: integer
|
||||
}
|
||||
total_tasks {
|
||||
description: "Number of tasks"
|
||||
type: integer
|
||||
}
|
||||
completed_tasks_24h {
|
||||
description: "Number of tasks completed in the last 24 hours"
|
||||
type: integer
|
||||
}
|
||||
last_task_run {
|
||||
description: "The most recent started time of a task"
|
||||
type: integer
|
||||
}
|
||||
status_count {
|
||||
description: "Status counts"
|
||||
type: object
|
||||
@@ -78,6 +94,10 @@ _definitions {
|
||||
description: "Number of 'created' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
completed {
|
||||
description: "Number of 'completed' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
queued {
|
||||
description: "Number of 'queued' tasks in project"
|
||||
type: integer
|
||||
@@ -135,6 +155,10 @@ _definitions {
|
||||
description: "Project name"
|
||||
type: string
|
||||
}
|
||||
basename {
|
||||
description: "Project base name"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description"
|
||||
type: string
|
||||
@@ -153,19 +177,24 @@ _definitions {
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
type: string
|
||||
}
|
||||
last_update {
|
||||
description: """Last project update time. Reflects the last time the project metadata was changed or a task in this project has changed status"""
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
// extra properties
|
||||
stats {
|
||||
description: "Additional project stats"
|
||||
@@ -188,6 +217,28 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
own_tasks {
|
||||
description: "The amount of tasks under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
|
||||
type: integer
|
||||
}
|
||||
own_models {
|
||||
description: "The amount of models under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
|
||||
type: integer
|
||||
}
|
||||
dataset_stats {
|
||||
description: Project dataset statistics
|
||||
type: object
|
||||
properties {
|
||||
file_count {
|
||||
type: integer
|
||||
description: The number of files stored in the dataset
|
||||
}
|
||||
total_size {
|
||||
type: integer
|
||||
description: The total dataset size in bytes
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
metric_variant_result {
|
||||
@@ -294,14 +345,14 @@ create {
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@@ -359,6 +410,10 @@ get_all {
|
||||
description: "Get only projects whose name matches this pattern (python regular expression syntax)"
|
||||
type: string
|
||||
}
|
||||
basename {
|
||||
description: "Project base name"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Get only projects whose description matches this pattern (python regular expression syntax)"
|
||||
type: string
|
||||
@@ -379,7 +434,7 @@ get_all {
|
||||
items { type: string }
|
||||
}
|
||||
page {
|
||||
description: "Page number, returns a specific page out of the resulting list of dataviews"
|
||||
description: "Page number, returns a specific page out of the resulting list of projects"
|
||||
type: integer
|
||||
minimum: 0
|
||||
}
|
||||
@@ -414,7 +469,6 @@ get_all {
|
||||
description: "Projects list"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/projects_get_all_response_single" }
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -430,6 +484,36 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.14": ${get_all."2.13"} {
|
||||
request.properties.search_hidden {
|
||||
description: "If set to 'true' then hidden projects are included in the search results"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.14"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of projects to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all_ex {
|
||||
internal: true
|
||||
@@ -469,24 +553,63 @@ get_all_ex {
|
||||
default: false
|
||||
}
|
||||
check_own_contents {
|
||||
description: "If set to 'true' and project ids are passed to the query then for these projects their own tasks, models and dataviews are counted"
|
||||
description: "If set to 'true' and project ids are passed to the query then for these projects their own tasks and models are counted"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
}
|
||||
"2.14": ${get_all_ex."2.13"} {
|
||||
request.properties.search_hidden {
|
||||
description: "If set to 'true' then hidden projects are included in the search results"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.14"} {
|
||||
request {
|
||||
properties {
|
||||
own_tasks {
|
||||
description: "The amount of tasks under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
|
||||
type: integer
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
own_models {
|
||||
description: "The amount of models under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of projects to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
"2.16": ${get_all_ex."2.15"} {
|
||||
request.properties.stats_with_children {
|
||||
description: "If include_stats flag is set then this flag contols whether the child projects tasks are taken into statistics or not"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
"2.17": ${get_all_ex."2.16"} {
|
||||
request.properties.include_stats_filter {
|
||||
description: The filter for selecting entities that participate in statistics calculation. For each task field that you want to filter on pass the list of allowed values. Prepend the value with '-' to exclude
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
"2.20": ${get_all_ex."2.17"} {
|
||||
request.properties.include_dataset_stats {
|
||||
description: "If true, include project dataset statistic in response"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
@@ -504,23 +627,19 @@ update {
|
||||
description: "Project name. Unique within the company."
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description. "
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description"
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@@ -594,7 +713,7 @@ merge {
|
||||
type: object
|
||||
properties {
|
||||
moved_entities {
|
||||
description: "The number of tasks, models and dataviews moved from the merged project into the destination"
|
||||
description: "The number of tasks and models moved from the merged project into the destination"
|
||||
type: integer
|
||||
}
|
||||
moved_projects {
|
||||
@@ -605,6 +724,42 @@ merge {
|
||||
}
|
||||
}
|
||||
}
|
||||
validate_delete {
|
||||
"2.14" {
|
||||
description: "Validates that the project existis and can be deleted"
|
||||
request {
|
||||
type: object
|
||||
required: [ project ]
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
tasks {
|
||||
description: "The total number of tasks under the project and all its children"
|
||||
type: integer
|
||||
}
|
||||
non_archived_tasks {
|
||||
description: "The total number of non-archived tasks under the project and all its children"
|
||||
type: integer
|
||||
}
|
||||
models {
|
||||
description: "The total number of models under the project and all its children"
|
||||
type: integer
|
||||
}
|
||||
non_archived_models {
|
||||
description: "The total number of non-archived models under the project and all its children"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete {
|
||||
"2.1" {
|
||||
description: "Deletes a project"
|
||||
@@ -613,7 +768,7 @@ delete {
|
||||
required: [ project ]
|
||||
properties {
|
||||
project {
|
||||
description: "Project id"
|
||||
description: "Project ID"
|
||||
type: string
|
||||
}
|
||||
force {
|
||||
@@ -622,7 +777,6 @@ delete {
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -755,6 +909,7 @@ get_hyper_parameters {
|
||||
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
|
||||
request {
|
||||
type: object
|
||||
required: [project]
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
@@ -803,7 +958,105 @@ get_hyper_parameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
get_model_metadata_values {
|
||||
"2.17" {
|
||||
description: """Get a list of distinct values for the chosen model metadata key"""
|
||||
request {
|
||||
type: object
|
||||
required: [key]
|
||||
properties {
|
||||
projects {
|
||||
description: "Project IDs"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
key {
|
||||
description: "Metadata key"
|
||||
type: string
|
||||
}
|
||||
allow_public {
|
||||
description: "If set to 'true' then collect values from both company and public models otherwise company modeels only. The default is 'true'"
|
||||
type: boolean
|
||||
}
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes metadata values from the subproject models"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
total {
|
||||
description: "Total number of distinct values"
|
||||
type: integer
|
||||
}
|
||||
values {
|
||||
description: "The list of the unique values"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_model_metadata_keys {
|
||||
"2.17" {
|
||||
description: """Get a list of all metadata keys used in models within the given project."""
|
||||
request {
|
||||
type: object
|
||||
required: [project]
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
type: string
|
||||
}
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes metadate keys from the subproject models"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
|
||||
page {
|
||||
description: "Page number"
|
||||
default: 0
|
||||
type: integer
|
||||
}
|
||||
page_size {
|
||||
description: "Page size"
|
||||
default: 500
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
keys {
|
||||
description: "A list of model keys"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
remaining {
|
||||
description: "Remaining results"
|
||||
type: integer
|
||||
}
|
||||
total {
|
||||
description: "Total number of results"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_project_tags {
|
||||
"2.17" {
|
||||
description: "Get user and system tags used for the specified projects and their children"
|
||||
request = ${_definitions.tags_request}
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
get_task_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the tasks under the specified projects"
|
||||
@@ -811,7 +1064,6 @@ get_task_tags {
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
|
||||
get_model_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the models under the specified projects"
|
||||
@@ -933,4 +1185,4 @@ get_task_parents {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,9 +79,11 @@ _definitions {
|
||||
items { "$ref": "#/definitions/entry" }
|
||||
}
|
||||
metadata {
|
||||
type: array
|
||||
description: "Queue metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -110,11 +112,53 @@ get_by_id {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.20": ${get_by_id."2.4"} {
|
||||
request.properties.max_task_entries {
|
||||
description: Max number of queue task entries to return
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
// typescript generation hack
|
||||
get_all_ex {
|
||||
internal: true
|
||||
"2.4": ${get_all."2.4"}
|
||||
"2.15": ${get_all_ex."2.4"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of queues to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
"2.20": ${get_all_ex."2.15"} {
|
||||
request.properties.max_task_entries {
|
||||
description: Max number of queue task entries to return
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
"2.21": ${get_all_ex."2.20"} {
|
||||
request.properties.search_hidden {
|
||||
description: "If set to 'true' then hidden queues are included in the search results"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.4" {
|
||||
@@ -178,6 +222,42 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.4"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of queues to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
"2.20": ${get_all."2.15"} {
|
||||
request.properties.max_task_entries {
|
||||
description: Max number of queue task entries to return
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
"2.21": ${get_all."2.20"} {
|
||||
request.properties.search_hidden {
|
||||
description: "If set to 'true' then hidden queues are included in the search results"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_default {
|
||||
"2.4" {
|
||||
@@ -235,6 +315,15 @@ create {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${create."2.4"} {
|
||||
metadata {
|
||||
description: "Queue metadata"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.4" {
|
||||
@@ -276,7 +365,15 @@ update {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${update."2.4"} {
|
||||
metadata {
|
||||
description: "Queue metadata"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -366,6 +463,33 @@ get_next_task {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.14": ${get_next_task."2.4"} {
|
||||
request.properties.get_task_info {
|
||||
description: "If set then additional task info is returned"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
response.properties.task_info {
|
||||
description: "Info about the returned task. Returned only if get_task_info is set to True"
|
||||
type: object
|
||||
properties {
|
||||
company {
|
||||
description: Task company ID
|
||||
type: string
|
||||
}
|
||||
user {
|
||||
description: ID of the user who created the task
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.21": ${get_next_task."2.14"} {
|
||||
request.properties.task {
|
||||
description: Task company ID
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
remove_task {
|
||||
"2.4" {
|
||||
@@ -569,6 +693,13 @@ get_queue_metrics : {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.20": ${get_queue_metrics."2.4"} {
|
||||
request.properties.refresh {
|
||||
type: boolean
|
||||
default: false
|
||||
description: If set then the new queue metrics is taken
|
||||
}
|
||||
}
|
||||
}
|
||||
add_or_update_metadata {
|
||||
"2.13" {
|
||||
@@ -586,6 +717,11 @@ add_or_update_metadata {
|
||||
description: "Metadata items to add or update"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
}
|
||||
replace_metadata {
|
||||
description: "If set then the all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -630,3 +766,51 @@ delete_metadata {
|
||||
}
|
||||
}
|
||||
}
|
||||
peek_task {
|
||||
"2.15" {
|
||||
description: "Peek the next task from a given queue"
|
||||
request {
|
||||
type: object
|
||||
required: [ queue ]
|
||||
properties {
|
||||
queue {
|
||||
description: "ID of the queue"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_num_entries {
|
||||
"2.15" {
|
||||
description: "Get the number of task entries in the given queue"
|
||||
request {
|
||||
type: object
|
||||
required: [ queue ]
|
||||
properties {
|
||||
queue {
|
||||
description: "ID of the queue"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
num {
|
||||
description: "Number of entries"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user