mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
421 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
702b6dc9c8 | ||
|
|
db15f235e4 | ||
|
|
8c347f8fa9 | ||
|
|
768c3d80ff | ||
|
|
a5c3ef6385 | ||
|
|
11b7a384af | ||
|
|
9a70ade4a6 | ||
|
|
91ce140901 | ||
|
|
49084a9c49 | ||
|
|
8a99eb6812 | ||
|
|
811ab2bf4f | ||
|
|
3752db122b | ||
|
|
439911b84c | ||
|
|
262a301e28 | ||
|
|
a604451b01 | ||
|
|
88a7773621 | ||
|
|
35c4061992 | ||
|
|
4684fd5b74 | ||
|
|
e08123fcc0 | ||
|
|
e713e876eb | ||
|
|
c2cc788319 | ||
|
|
da8315d0db | ||
|
|
4ac6f88278 | ||
|
|
a7865ccbec | ||
|
|
ec14f327c6 | ||
|
|
a03b24d6b6 | ||
|
|
cb71ef8e47 | ||
|
|
8678fbc995 | ||
|
|
58df8f201a | ||
|
|
f4bf16c156 | ||
|
|
942f996237 | ||
|
|
c1e7f8f9c1 | ||
|
|
274c487b37 | ||
|
|
cc0129a800 | ||
|
|
388dd1b01f | ||
|
|
d62ecb5e6e | ||
|
|
6d507616b3 | ||
|
|
d0252a6dd9 | ||
|
|
2263e7cc1e | ||
|
|
81b93e6811 | ||
|
|
491e83d0f1 | ||
|
|
f84cc0a2cb | ||
|
|
6c5f966ed4 | ||
|
|
4eff657810 | ||
|
|
74acaa31df | ||
|
|
21ed8559bf | ||
|
|
3927604648 | ||
|
|
f7dcbd96ec | ||
|
|
5950b81f0b | ||
|
|
1e51e2e221 | ||
|
|
4c98b87554 | ||
|
|
c196043d2a | ||
|
|
752020c66a | ||
|
|
6885d07462 | ||
|
|
00552da1b0 | ||
|
|
eebe2eeffc | ||
|
|
bc2fe28bdd | ||
|
|
ed86750b24 | ||
|
|
6df69afb25 | ||
|
|
3f22423c3f | ||
|
|
3ad636c468 | ||
|
|
5c80336aa9 | ||
|
|
5cd59ea6e3 | ||
|
|
5d3ba4fa73 | ||
|
|
42556c8dbb | ||
|
|
dbe1c6f00f | ||
|
|
a17485b1bd | ||
|
|
a2b9fed92d | ||
|
|
ff34da3c88 | ||
|
|
5239755066 | ||
|
|
8061dfedbb | ||
|
|
011164ce9b | ||
|
|
8135cf5258 | ||
|
|
a83a932e84 | ||
|
|
db021f2863 | ||
|
|
1b650b1689 | ||
|
|
14d18a7aba | ||
|
|
a7ed46979f | ||
|
|
452f606889 | ||
|
|
fc47ccbf09 | ||
|
|
0206811342 | ||
|
|
a3ac1049a3 | ||
|
|
8488f63a3a | ||
|
|
9206a7c57d | ||
|
|
0c37ced2a1 | ||
|
|
b22f26129e | ||
|
|
d8b998ebd8 | ||
|
|
741fa84b52 | ||
|
|
d9579891c8 | ||
|
|
900414d0de | ||
|
|
5449b332d2 | ||
|
|
875f4b9536 | ||
|
|
95b8f22899 | ||
|
|
4058fb9ce5 | ||
|
|
cf8e847ed3 | ||
|
|
755cc803d9 | ||
|
|
3729afe014 | ||
|
|
dff2ed34e8 | ||
|
|
de9651d761 | ||
|
|
818496236b | ||
|
|
e99817b28b | ||
|
|
58465fbc17 | ||
|
|
2e4e060a82 | ||
|
|
5c5d9b6434 | ||
|
|
4291ad682a | ||
|
|
4c22757002 | ||
|
|
6e777e80b8 | ||
|
|
c8e4d9eeac | ||
|
|
b51aa5c29b | ||
|
|
e7c9daa42b | ||
|
|
7357654249 | ||
|
|
a6f671b46a | ||
|
|
17a8b440bd | ||
|
|
eb2b9cbd9a | ||
|
|
797e503e67 | ||
|
|
30cfdac8f2 | ||
|
|
24bb87aaee | ||
|
|
dd49ba180a | ||
|
|
bda903d0d8 | ||
|
|
9739eb2d5a | ||
|
|
cfbb37238f | ||
|
|
6664c6237e | ||
|
|
74200a24bd | ||
|
|
2fb9288a6c | ||
|
|
5d014d81af | ||
|
|
3a2675abe1 | ||
|
|
f0d68b1ce9 | ||
|
|
15db9cdaef | ||
|
|
a45d47f5d7 | ||
|
|
b1a50c1370 | ||
|
|
22a2a02760 | ||
|
|
ab798e4170 | ||
|
|
f09ac672d2 | ||
|
|
2149b76f63 | ||
|
|
d96420aa67 | ||
|
|
ed6c7b7bcb | ||
|
|
a392bc0bd7 | ||
|
|
7e97ec5555 | ||
|
|
9c41124b81 | ||
|
|
14ff639bb0 | ||
|
|
e66257761a | ||
|
|
0ffde24dc2 | ||
|
|
d4fdcd9b32 | ||
|
|
18570bfccb | ||
|
|
54ce6c34c6 | ||
|
|
ae4c33fa0e | ||
|
|
c7cd949fd0 | ||
|
|
1ce4058157 | ||
|
|
7b6f24b24d | ||
|
|
d03a931d84 | ||
|
|
5cc7199661 | ||
|
|
6537e9ef69 | ||
|
|
930aaff791 | ||
|
|
1999fb2479 | ||
|
|
9db14cc31d | ||
|
|
e3cc689528 | ||
|
|
9e0adc77dd | ||
|
|
58d9a64537 | ||
|
|
d397d2ae20 | ||
|
|
2d711e1500 | ||
|
|
97992b0d9e | ||
|
|
bc23f1b0cf | ||
|
|
6b3eff1426 | ||
|
|
caaf801cd0 | ||
|
|
c23e8a90d0 | ||
|
|
fa5b28ca0e | ||
|
|
bfb55a9463 | ||
|
|
37e485e1f2 | ||
|
|
3451ff441f | ||
|
|
53c9b5525e | ||
|
|
e5230edac3 | ||
|
|
a54dd8030c | ||
|
|
482a5c34bc | ||
|
|
ee2a72c70f | ||
|
|
a0d8aaf3b9 | ||
|
|
de1f823213 | ||
|
|
0c9e2f92ee | ||
|
|
6c49e96ff0 | ||
|
|
81e3fc6577 | ||
|
|
e6dc4b7557 | ||
|
|
238a47a197 | ||
|
|
04e7076628 | ||
|
|
0531612bf4 | ||
|
|
3ae410a1e9 | ||
|
|
98ed3075dd | ||
|
|
b871bf4224 | ||
|
|
8d4c02fc3c | ||
|
|
b986980c75 | ||
|
|
a4fa567be2 | ||
|
|
ddb91f226a | ||
|
|
7772f47773 | ||
|
|
9c118d14e0 | ||
|
|
efd56e085e | ||
|
|
4dff163af4 | ||
|
|
242a78a0fe | ||
|
|
78989fea91 | ||
|
|
5de7c12062 | ||
|
|
3f79c19079 | ||
|
|
fe29743c54 | ||
|
|
d760cf5835 | ||
|
|
3695f25a5f | ||
|
|
c6f1beafdd | ||
|
|
68a54c34f3 | ||
|
|
ab495ae586 | ||
|
|
b058770af1 | ||
|
|
f7e833bf6f | ||
|
|
36b9ab0453 | ||
|
|
ec0436d0da | ||
|
|
0f6c4e75b7 | ||
|
|
a41ae112a1 | ||
|
|
c28f478ea8 | ||
|
|
c18eb99d06 | ||
|
|
3a60f00d93 | ||
|
|
ee87778548 | ||
|
|
52c0c4d438 | ||
|
|
d117a4f022 | ||
|
|
6683d2d7a9 | ||
|
|
05357fe25e | ||
|
|
adc1825843 | ||
|
|
0c15169668 | ||
|
|
123dc1dcfb | ||
|
|
b2feafac09 | ||
|
|
b41ab8c550 | ||
|
|
62d5779bd5 | ||
|
|
f8b9d9802e | ||
|
|
dd8a1503b0 | ||
|
|
cff98ae900 | ||
|
|
9b108740da | ||
|
|
08a7bc7c9f | ||
|
|
fb256d7e5b | ||
|
|
710443b078 | ||
|
|
e0cde2f7c9 | ||
|
|
60b9c8de14 | ||
|
|
ecffe26be4 | ||
|
|
2570bd9e26 | ||
|
|
174f84514a | ||
|
|
65cb8d7b43 | ||
|
|
5f8ef808a3 | ||
|
|
4941ac70e0 | ||
|
|
67cd461145 | ||
|
|
92b5fc6f9a | ||
|
|
b90165b4e4 | ||
|
|
6c2dcb5c8a | ||
|
|
3efed32934 | ||
|
|
69737308fe | ||
|
|
a6dbea808a | ||
|
|
5131b17901 | ||
|
|
5f21c3a56d | ||
|
|
2350ac64ed | ||
|
|
d146127c18 | ||
|
|
abd65e103e | ||
|
|
bf65ea7bd0 | ||
|
|
73e278a8ed | ||
|
|
d92dfbbdb7 | ||
|
|
5c1e419eb5 | ||
|
|
124684f53f | ||
|
|
455b5d6758 | ||
|
|
c04e2e498b | ||
|
|
da8a45072f | ||
|
|
e1992e2054 | ||
|
|
c17cedd93a | ||
|
|
b6ad8f8790 | ||
|
|
5acc7eebc3 | ||
|
|
941927dfcd | ||
|
|
02933a9c93 | ||
|
|
e537651f29 | ||
|
|
af09fba755 | ||
|
|
04ea9018a3 | ||
|
|
ff7e1be24f | ||
|
|
fc4fd9e61c | ||
|
|
8908c7dcf9 | ||
|
|
b9996e2c1a | ||
|
|
afdc56f37c | ||
|
|
a25cd5dae8 | ||
|
|
447adb9090 | ||
|
|
92fd98d5ad | ||
|
|
c4001b4037 | ||
|
|
970a32287a | ||
|
|
17cd48dada | ||
|
|
ea3b6e955f | ||
|
|
843450bb9b | ||
|
|
e149af58b1 | ||
|
|
604a38035b | ||
|
|
cae38a365b | ||
|
|
e334246b46 | ||
|
|
36e013b40c | ||
|
|
f20cd6536e | ||
|
|
446bd35006 | ||
|
|
a377a7e315 | ||
|
|
3d046ac282 | ||
|
|
a08fa9a0e1 | ||
|
|
5856ed2836 | ||
|
|
d295355d99 | ||
|
|
77350f6119 | ||
|
|
bc2c2ebbfd | ||
|
|
1502e02a1a | ||
|
|
d0e2313a24 | ||
|
|
d8ba1a8ea7 | ||
|
|
ca7937fc4e | ||
|
|
df89bcceef | ||
|
|
cfccbe05c1 | ||
|
|
e352a6a1e7 | ||
|
|
8a3d992aaf | ||
|
|
c37f3d8d5b | ||
|
|
a96870e092 | ||
|
|
6bf1032237 | ||
|
|
3d816c747d | ||
|
|
3f2b96266b | ||
|
|
22b16d12eb | ||
|
|
c55b6f30df | ||
|
|
b7045d3d28 | ||
|
|
e31a404885 | ||
|
|
643588b71a | ||
|
|
a64c4d264d | ||
|
|
567780e188 | ||
|
|
1bc8529d83 | ||
|
|
6b480d7e87 | ||
|
|
083fd315e9 | ||
|
|
ef20e76174 | ||
|
|
8c8910808e | ||
|
|
f6ad379310 | ||
|
|
c5d6ce3e65 | ||
|
|
694dbc31c4 | ||
|
|
6488dc54e6 | ||
|
|
158da9b480 | ||
|
|
ec2e071ab7 | ||
|
|
465e270342 | ||
|
|
6705aff56f | ||
|
|
9069cfe1da | ||
|
|
677bb3ba6d | ||
|
|
cb253cff9e | ||
|
|
39ceb5ac5c | ||
|
|
d4edeaaf1b | ||
|
|
56aea1ffb8 | ||
|
|
09ab2af34c | ||
|
|
8bb26a6b0b | ||
|
|
3f2304549d | ||
|
|
ad72a435f1 | ||
|
|
f34332344e | ||
|
|
d324b57dd7 | ||
|
|
2216bfe875 | ||
|
|
9beefa7473 | ||
|
|
8ebc334889 | ||
|
|
e662c850af | ||
|
|
1e5163e530 | ||
|
|
1567774765 | ||
|
|
babfcbb707 | ||
|
|
027edd86bb | ||
|
|
cc83aadae6 | ||
|
|
8c18660a82 | ||
|
|
4fe61ee25c | ||
|
|
e18b21639c | ||
|
|
1cef03b8c2 | ||
|
|
d60d6dfe99 | ||
|
|
27d086bca2 | ||
|
|
add3f011a0 | ||
|
|
ee90b0b024 | ||
|
|
9bf107866f | ||
|
|
4d2f282950 | ||
|
|
b55fad1b59 | ||
|
|
ba77ff11e9 | ||
|
|
b67aa05d6f | ||
|
|
6b0c45a861 | ||
|
|
dc9623e964 | ||
|
|
3d73d60826 | ||
|
|
9f0c9c3690 | ||
|
|
1a3d3494ce | ||
|
|
b99f620073 | ||
|
|
e2f265b4bc | ||
|
|
251ee57ffd | ||
|
|
7e03104f1c | ||
|
|
f1a258208e | ||
|
|
66cc49313b | ||
|
|
9ae2943f7d | ||
|
|
54326f707b | ||
|
|
3a3b57c15f | ||
|
|
8ea8ad34e6 | ||
|
|
179661a0d4 | ||
|
|
3d22ca1888 | ||
|
|
fdf6798d0c | ||
|
|
9d9a44b927 | ||
|
|
dad935e81d | ||
|
|
a75534ec34 | ||
|
|
eab33de97e | ||
|
|
29de110abb | ||
|
|
2e7f418ee2 | ||
|
|
dadb996d22 | ||
|
|
174f692edf | ||
|
|
f4d5168a20 | ||
|
|
5a438e8435 | ||
|
|
ce4814dc47 | ||
|
|
ef42d0265d | ||
|
|
3c5195028e | ||
|
|
0d5174c453 | ||
|
|
c034c1a986 | ||
|
|
1b49da8748 | ||
|
|
26bda01a28 | ||
|
|
f5008d80ad | ||
|
|
8b464e7ae6 | ||
|
|
78e4a58c91 | ||
|
|
7a4a5eb03e | ||
|
|
d029d56508 | ||
|
|
6411954002 | ||
|
|
7f4ad0d1ca | ||
|
|
4cd4b2914d | ||
|
|
1d55710a0b | ||
|
|
8f646043bb | ||
|
|
4b11a6efcd | ||
|
|
cb3a7c90a8 | ||
|
|
074842a122 | ||
|
|
749ff4a44f | ||
|
|
7d6918ecb0 | ||
|
|
47184c2833 | ||
|
|
6434f1028e | ||
|
|
daade08940 | ||
|
|
a1d289822f | ||
|
|
1ce34f2c74 | ||
|
|
c2dc73a71f | ||
|
|
07bb3b5df8 | ||
|
|
067ef82576 | ||
|
|
59fc98e0c4 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,7 +12,6 @@ test-reports
|
||||
.pytest_cache
|
||||
venv
|
||||
*.noseids
|
||||
build
|
||||
*.egg-info
|
||||
.cache
|
||||
.mypy_cache
|
||||
|
||||
66
README.md
66
README.md
@@ -8,28 +8,43 @@
|
||||
[](https://img.shields.io/badge/license-SSPL-green.svg)
|
||||
[](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
|
||||
[](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
|
||||
[](https://artifacthub.io/packages/search?repo=allegroai)
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
<div align="center">
|
||||
|
||||
**v0.16 Upgrade Notice**
|
||||
**Note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31**
|
||||
|
||||
</div>
|
||||
|
||||
In v0.16, the Elasticsearch subsystem of ClearML Server has been upgraded from version 5.6 to version 7.6. This change necessitates the migration of the database contents to accommodate the change in index structure across the different versions.
|
||||
According to [ElasticSearch's latest report](https://discuss.elastic.co/t/apache-log4j2-remote-code-execution-rce-vulnerability-cve-2021-44228-esa-2021-31/291476),
|
||||
supported versions of Elasticsearch (6.8.9+, 7.8+) used with recent versions of the JDK (JDK9+) **are not susceptible to either remote code execution or information leakage**
|
||||
due to Elasticsearch’s usage of the Java Security Manager.
|
||||
|
||||
Follow [this procedure](https://allegro.ai/docs/deploying_trains/trains_server_es7_migration/) to migrate existing data.
|
||||
**As the latest version of ClearML Server uses Elasticsearch 7.10+ with JDK15, it is not affected by these vulnerabilities.**
|
||||
|
||||
As a precaution, we've upgraded the ES version to 7.16.2 and added the mitigation recommended by ElasticSearch to our latest [docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
|
||||
|
||||
While previous Elasticsearch versions (5.6.11+, 6.4.0+ and 7.0.0+) used by older ClearML Server versions are only susceptible to the information leakage vulnerability
|
||||
(which in any case **does not permit access to data within the Elasticsearch cluster**),
|
||||
we still recommend upgrading to the latest version of ClearML Server. Alternatively, you can apply the mitigation as implemented in our latest
|
||||
[docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
|
||||
|
||||
**Update 15 December**: A further vulnerability (CVE-2021-45046) was disclosed on December 14th.
|
||||
ElasticSearch's guidance for Elasticsearch remains unchanged by this new vulnerability, thus **not affecting ClearML Server**.
|
||||
|
||||
**Update 22 December**: To keep with ElasticSearch's recommendations, we've upgraded the ES version to the newly released 7.16.2
|
||||
|
||||
---
|
||||
|
||||
### ClearML Server
|
||||
## ClearML Server
|
||||
#### *Formerly known as Trains Server*
|
||||
|
||||
The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/allegroai/clearml).
|
||||
It allows multiple users to collaborate and manage their experiments.
|
||||
By default, **ClearML** is set up to work with the **ClearML** demo server, which is open to anyone and resets periodically.
|
||||
**ClearML** offers a [free hosted service](https://app.clear.ml/), which is maintained by **ClearML** and open to anyone.
|
||||
In order to host your own server, you will need to launch the **ClearML Server** and point **ClearML** to it.
|
||||
|
||||
The **ClearML Server** contains the following components:
|
||||
@@ -45,7 +60,7 @@ You can quickly [deploy](#launching-the-clearml-server) your **ClearML Server**
|
||||
## System design
|
||||
|
||||
|
||||

|
||||

|
||||
|
||||
The **ClearML Server** has two supported configurations:
|
||||
- Single IP (domain) with the following open ports
|
||||
@@ -78,20 +93,19 @@ For example, to see if port `8080` is in use:
|
||||
|
||||
Launch The **ClearML Server** in any of the following formats:
|
||||
|
||||
- Pre-built [AWS EC2 AMI](https://allegro.ai/docs/deploying_trains/trains_server_aws_ec2_ami/)
|
||||
- Pre-built [GCP Custom Image](https://allegro.ai/docs/deploying_trains/trains_server_gcp/)
|
||||
- Pre-built [AWS EC2 AMI](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_aws_ec2_ami)
|
||||
- Pre-built [GCP Custom Image](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_gcp)
|
||||
- Pre-built Docker Image
|
||||
- [Linux](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
|
||||
- [macOS](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
|
||||
- [Windows 10](https://allegro.ai/docs/deploying_trains/trains_server_win/)
|
||||
- [Linux](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
|
||||
- [macOS](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
|
||||
- [Windows 10](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_win)
|
||||
- Kubernetes
|
||||
- [Kubernetes Helm](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes_helm/)
|
||||
- Manual [Kubernetes installation](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes/)
|
||||
- [Kubernetes Helm](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes_helm)
|
||||
- Manual [Kubernetes installation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes)
|
||||
|
||||
## Connecting ClearML to your ClearML Server
|
||||
|
||||
By default, the **ClearML** client is set up to work with the [**ClearML** demo server](https://demoapp.demo.clear.ml/).
|
||||
To have the **ClearML** client use your **ClearML Server** instead:
|
||||
In order to set up the **ClearML** client to work with your **ClearML Server**:
|
||||
- Run the `clearml-init` command for an interactive setup.
|
||||
- Or manually edit `~/clearml.conf` file, making sure the server settings (`api_server`, `web_server`, `file_server`) are configured correctly, for example:
|
||||
|
||||
@@ -138,8 +152,8 @@ Do not enqueue training / inference tasks into the `services` queue, as it will
|
||||
|
||||
The **ClearML Server** provides a few additional useful features, which can be manually enabled:
|
||||
|
||||
* [Web login authentication](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#web-login-authentication)
|
||||
* [Non-responsive experiments watchdog](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#task_watchdog)
|
||||
* [Web login authentication](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#web-login-authentication)
|
||||
* [Non-responsive experiments watchdog](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#non-responsive-task-watchdog)
|
||||
|
||||
## Restarting ClearML Server
|
||||
|
||||
@@ -189,14 +203,14 @@ To upgrade your existing **ClearML Server** deployment:
|
||||
```
|
||||
|
||||
1. Configure the ClearML-Agent Services (not supported on Windows installation).
|
||||
If `TRAINS_HOST_IP` is not provided, ClearML-Agent Services will use the external
|
||||
public address of the **ClearML Server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
|
||||
If `CLEARML_HOST_IP` is not provided, ClearML-Agent Services will use the external
|
||||
public address of the **ClearML Server**. If `CLEARML_AGENT_GIT_USER` / `CLEARML_AGENT_GIT_PASS` are not provided,
|
||||
the ClearML-Agent Services will not be able to access any private repositories for running service tasks.
|
||||
|
||||
```bash
|
||||
export TRAINS_HOST_IP=server_host_ip_here
|
||||
export TRAINS_AGENT_GIT_USER=git_username_here
|
||||
export TRAINS_AGENT_GIT_PASS=git_password_here
|
||||
export CLEARML_HOST_IP=server_host_ip_here
|
||||
export CLEARML_AGENT_GIT_USER=git_username_here
|
||||
export CLEARML_AGENT_GIT_PASS=git_password_here
|
||||
```
|
||||
|
||||
1. Spin up the docker containers, it will automatically pull the latest **ClearML Server** build
|
||||
@@ -205,15 +219,15 @@ To upgrade your existing **ClearML Server** deployment:
|
||||
docker-compose -f docker-compose.yml up
|
||||
```
|
||||
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/clearml/docs/docs/faq/faq.html).**
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://clear.ml/docs/latest/docs/faq/).**
|
||||
|
||||
|
||||
## Community & Support
|
||||
|
||||
If you have any questions, look to the ClearML [FAQ](https://allegro.ai/clearml/docs/docs/faq/faq.html), or
|
||||
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/trains) with '**trains**' tag.
|
||||
If you have any questions, look to the ClearML [FAQ](https://clear.ml/docs/latest/docs/faq), or
|
||||
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/clearml) with '**clearml**' tag.
|
||||
|
||||
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/trains-server/issues).
|
||||
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/clearml-server/issues).
|
||||
|
||||
Additionally, you can always find us at *clearml@allegro.ai*
|
||||
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
301 {
|
||||
_: "moved_permanently"
|
||||
1: ["not_supported", "this endpoint is no longer supported for the requested API version"]
|
||||
}
|
||||
|
||||
400 {
|
||||
_: "bad_request"
|
||||
1: ["not_supported", "endpoint is not supported"]
|
||||
@@ -21,6 +26,9 @@
|
||||
23: ["invalid_domain_name", "malformed domain name"]
|
||||
24: ["not_public_object", "object is not public"]
|
||||
|
||||
# Auth / Login
|
||||
75: ["invalid_access_key", "access key not found"]
|
||||
|
||||
# Tasks
|
||||
100: ["task_error", "general task error"]
|
||||
101: ["invalid_task_id", "invalid task id"]
|
||||
@@ -42,6 +50,12 @@
|
||||
130: ["task_not_found", "task not found"]
|
||||
131: ["events_not_added", "events not added"]
|
||||
|
||||
# Reports
|
||||
150: ["operation_supported_on_reports_only", "passed task is not report"]
|
||||
|
||||
# Pipelines
|
||||
160: ["cannot_remove_all_runs", "at least one pipeline run should be left"]
|
||||
|
||||
# Models
|
||||
200: ["model_error", "general task error"]
|
||||
201: ["invalid_model_id", "invalid model id"]
|
||||
@@ -62,6 +76,14 @@
|
||||
402: ["project_has_tasks", "project has associated tasks"]
|
||||
403: ["project_not_found", "project not found"]
|
||||
405: ["project_has_models", "project has associated models"]
|
||||
406: ["project_has_datasets", "project has associated non-empty datasets"]
|
||||
407: ["invalid_project_name", "invalid project name"]
|
||||
408: ["cannot_update_project_location", "Cannot update project location. Use projects.move instead"]
|
||||
409: ["project_path_exceeds_max", "Project path exceed the maximum allowed depth"]
|
||||
410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"]
|
||||
411: ["project_cannot_be_moved_under_itself", "Project can not be moved under itself in the projects hierarchy"]
|
||||
412: ["project_cannot_be_merged_into_its_child", "Project can not be merged into its own child"]
|
||||
413: ["project_has_pipelines", "project has associated pipelines with active controllers"]
|
||||
|
||||
# Queues
|
||||
701: ["invalid_queue_id", "invalid queue id"]
|
||||
@@ -75,7 +97,7 @@
|
||||
|
||||
# Database
|
||||
800: ["data_validation_error", "data validation error"]
|
||||
801: ["expected_unique_data", "value combination already exists"]
|
||||
801: ["expected_unique_data", "value combination already exists (unique field already contains this value)"]
|
||||
|
||||
# Workers
|
||||
1001: ["invalid_worker_id", "invalid worker id"]
|
||||
@@ -108,6 +130,11 @@
|
||||
21: ["no_write_permission", "forbidden (modification not allowed)"]
|
||||
}
|
||||
|
||||
410: {
|
||||
_: "gone"
|
||||
1: ["not_supported", "thus endpoint is not supported any more"]
|
||||
}
|
||||
|
||||
500 {
|
||||
_: "server_error"
|
||||
0: ["general_error", "general server error"]
|
||||
|
||||
@@ -61,10 +61,11 @@ class ListField(fields.ListField):
|
||||
item.validate()
|
||||
|
||||
|
||||
# since there is no distinction between None and empty DictField
|
||||
# this value can be used as sentinel in order to distinguish
|
||||
# between not set and empty DictField
|
||||
DictFieldNotSet = {}
|
||||
class ScalarField(fields.BaseField):
|
||||
|
||||
"""String field."""
|
||||
|
||||
types = (str, int, float, bool)
|
||||
|
||||
|
||||
class DictField(fields.BaseField):
|
||||
@@ -218,7 +219,7 @@ class ActualEnumField(fields.StringField):
|
||||
)
|
||||
|
||||
def parse_value(self, value):
|
||||
if value is None and not self.required:
|
||||
if value is NotSet and not self.required:
|
||||
return self.get_default_value()
|
||||
try:
|
||||
# noinspection PyArgumentList
|
||||
|
||||
@@ -75,11 +75,17 @@ class CreateUserResponse(Base):
|
||||
class Credentials(Base):
|
||||
access_key = StringField(required=True)
|
||||
secret_key = StringField(required=True)
|
||||
label = StringField()
|
||||
|
||||
|
||||
class CredentialsResponse(Credentials):
|
||||
secret_key = StringField()
|
||||
last_used = DateTimeField(default=None)
|
||||
last_used_from = StringField()
|
||||
|
||||
|
||||
class CreateCredentialsRequest(Base):
|
||||
label = StringField()
|
||||
|
||||
|
||||
class CreateCredentialsResponse(Base):
|
||||
@@ -90,6 +96,11 @@ class GetCredentialsResponse(Base):
|
||||
credentials = ListField(CredentialsResponse)
|
||||
|
||||
|
||||
class EditCredentialsRequest(Base):
|
||||
access_key = StringField(required=True)
|
||||
label = StringField()
|
||||
|
||||
|
||||
class RevokeCredentialsRequest(Base):
|
||||
access_key = StringField(required=True)
|
||||
|
||||
|
||||
25
apiserver/apimodels/batch.py
Normal file
25
apiserver/apimodels/batch.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels.fields import StringField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
|
||||
|
||||
class BatchRequest(Base):
|
||||
ids: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class BatchResponse(Base):
|
||||
succeeded: Sequence[dict] = ListField([dict])
|
||||
failed: Sequence[dict] = ListField([dict])
|
||||
|
||||
|
||||
class UpdateBatchItem(UpdateResponse):
|
||||
id: str = StringField()
|
||||
|
||||
|
||||
class UpdateBatchResponse(BatchResponse):
|
||||
succeeded: Sequence[UpdateBatchItem] = ListField(UpdateBatchItem)
|
||||
@@ -2,7 +2,7 @@ from enum import auto
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.fields import StringField, BoolField, EmbeddedField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length, Min, Max
|
||||
|
||||
@@ -14,12 +14,19 @@ from apiserver.utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class HistogramRequestBase(Base):
|
||||
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
|
||||
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
|
||||
|
||||
class MetricVariants(Base):
|
||||
metric: str = StringField(required=True)
|
||||
variants: Sequence[str] = ListField(items_types=str)
|
||||
|
||||
|
||||
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
task: str = StringField(required=True)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
@@ -29,19 +36,22 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
Length(
|
||||
minimum_value=1,
|
||||
maximum_value=config.get(
|
||||
"services.tasks.multi_task_histogram_limit", 10
|
||||
"services.tasks.multi_task_histogram_limit", 100
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class TaskMetric(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
metric: str = StringField(default=None)
|
||||
variants: Sequence[str] = ListField(items_types=str)
|
||||
|
||||
|
||||
class DebugImagesRequest(Base):
|
||||
class MetricEventsRequest(Base):
|
||||
metrics: Sequence[TaskMetric] = ListField(
|
||||
items_types=TaskMetric, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
@@ -49,24 +59,36 @@ class DebugImagesRequest(Base):
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
model_events: bool = BoolField()
|
||||
|
||||
|
||||
class TaskMetricVariant(Base):
|
||||
class GetVariantSampleRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
variant: str = StringField(required=True)
|
||||
|
||||
|
||||
class GetDebugImageSampleRequest(TaskMetricVariant):
|
||||
iteration: Optional[int] = IntField()
|
||||
scroll_id: Optional[str] = StringField()
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_current_metric: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class NextDebugImageSampleRequest(Base):
|
||||
class GetMetricSamplesRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
iteration: Optional[int] = IntField()
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_current_metric: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class NextHistorySampleRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
next_iteration: bool = BoolField(default=False)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class LogOrderEnum(StringEnum):
|
||||
@@ -74,14 +96,36 @@ class LogOrderEnum(StringEnum):
|
||||
desc = auto()
|
||||
|
||||
|
||||
class LogEventsRequest(Base):
|
||||
class TaskEventsRequestBase(Base):
|
||||
task: str = StringField(required=True)
|
||||
batch_size: int = IntField(default=500)
|
||||
|
||||
|
||||
class TaskEventsRequest(TaskEventsRequestBase):
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
|
||||
scroll_id: str = StringField()
|
||||
count_total: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class LogEventsRequest(TaskEventsRequestBase):
|
||||
batch_size: int = IntField(default=5000)
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
from_timestamp: Optional[int] = IntField()
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum)
|
||||
|
||||
|
||||
class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
|
||||
batch_size: int = IntField()
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
|
||||
count_total: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class IterationEvents(Base):
|
||||
iter: int = IntField()
|
||||
events: Sequence[dict] = ListField(items_types=dict)
|
||||
@@ -89,17 +133,55 @@ class IterationEvents(Base):
|
||||
|
||||
class MetricEvents(Base):
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
|
||||
|
||||
|
||||
class DebugImageResponse(Base):
|
||||
class MetricEventsResponse(Base):
|
||||
metrics: Sequence[MetricEvents] = ListField(items_types=MetricEvents)
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class TaskMetricsRequest(Base):
|
||||
class MultiTasksRequestBase(Base):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class SingleValueMetricsRequest(MultiTasksRequestBase):
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
|
||||
class TaskMetricsRequest(MultiTasksRequestBase):
|
||||
event_type: EventType = ActualEnumField(EventType, required=True)
|
||||
|
||||
|
||||
class MultiTaskMetricsRequest(MultiTasksRequestBase):
|
||||
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
|
||||
|
||||
|
||||
class MultiTaskPlotsRequest(MultiTasksRequestBase):
|
||||
iters: int = IntField(default=1)
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
last_iters_per_task_metric: bool = BoolField(default=True)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
|
||||
class TaskPlotsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
iters: int = IntField(default=1)
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class ClearScrollRequest(Base):
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class ClearTaskLogRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
threshold_sec = IntField()
|
||||
allow_locked = BoolField(default=False)
|
||||
|
||||
@@ -5,8 +5,9 @@ from apiserver.apimodels import DictField, callable_default
|
||||
|
||||
|
||||
class GetSupportedModesRequest(Base):
|
||||
state = StringField(help_text="ASCII base64 encoded application state")
|
||||
callback_url_prefix = StringField()
|
||||
pass
|
||||
# state = StringField(help_text="ASCII base64 encoded application state")
|
||||
# callback_url_prefix = StringField()
|
||||
|
||||
|
||||
class BasicGuestMode(Base):
|
||||
@@ -31,3 +32,4 @@ class GetSupportedModesResponse(Base):
|
||||
server_errors = EmbeddedField(ServerErrors)
|
||||
sso = DictField([str, type(None)])
|
||||
sso_providers = ListField([dict])
|
||||
authenticated = BoolField(default=False)
|
||||
|
||||
24
apiserver/apimodels/metadata.py
Normal file
24
apiserver/apimodels/metadata.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
|
||||
|
||||
class MetadataItem(Base):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
|
||||
|
||||
class DeleteMetadata(Base):
|
||||
keys: Sequence[str] = ListField(str, validators=validators.Length(minimum_value=1))
|
||||
|
||||
|
||||
class AddOrUpdateMetadata(Base):
|
||||
metadata: Sequence[MetadataItem] = ListField(
|
||||
[MetadataItem], validators=validators.Length(minimum_value=1)
|
||||
)
|
||||
replace_metadata = BoolField(default=False)
|
||||
@@ -3,7 +3,12 @@ from six import string_types
|
||||
|
||||
from apiserver.apimodels import ListField, DictField
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.apimodels.tasks import PublishResponse as TaskPublishResponse
|
||||
from apiserver.apimodels.batch import BatchRequest
|
||||
from apiserver.apimodels.metadata import (
|
||||
MetadataItem,
|
||||
DeleteMetadata,
|
||||
AddOrUpdateMetadata,
|
||||
)
|
||||
|
||||
|
||||
class GetFrameworksRequest(models.Base):
|
||||
@@ -13,7 +18,7 @@ class GetFrameworksRequest(models.Base):
|
||||
class CreateModelRequest(models.Base):
|
||||
name = fields.StringField(required=True)
|
||||
uri = fields.StringField(required=True)
|
||||
labels = DictField(value_types=string_types+(int,))
|
||||
labels = DictField(value_types=string_types + (int,))
|
||||
tags = ListField(items_types=string_types)
|
||||
system_tags = ListField(items_types=string_types)
|
||||
comment = fields.StringField()
|
||||
@@ -25,6 +30,7 @@ class CreateModelRequest(models.Base):
|
||||
ready = fields.BoolField(default=True)
|
||||
ui_cache = DictField()
|
||||
task = fields.StringField()
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class CreateModelResponse(models.Base):
|
||||
@@ -32,17 +38,47 @@ class CreateModelResponse(models.Base):
|
||||
created = fields.BoolField(required=True)
|
||||
|
||||
|
||||
class PublishModelRequest(models.Base):
|
||||
class ModelRequest(models.Base):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
|
||||
class DeleteModelRequest(ModelRequest):
|
||||
force = fields.BoolField(default=False)
|
||||
delete_external_artifacts = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ModelsDeleteManyRequest(BatchRequest):
|
||||
force = fields.BoolField(default=False)
|
||||
delete_external_artifacts = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class PublishModelRequest(ModelRequest):
|
||||
force_publish_task = fields.BoolField(default=False)
|
||||
publish_task = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ModelTaskPublishResponse(models.Base):
|
||||
id = fields.StringField(required=True)
|
||||
data = fields.EmbeddedField(TaskPublishResponse)
|
||||
data = fields.EmbeddedField(UpdateResponse)
|
||||
|
||||
|
||||
class PublishModelResponse(UpdateResponse):
|
||||
published_task = fields.EmbeddedField(ModelTaskPublishResponse)
|
||||
updated = fields.IntField()
|
||||
|
||||
|
||||
class ModelsPublishManyRequest(BatchRequest):
|
||||
force_publish_task = fields.BoolField(default=False)
|
||||
publish_task = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class DeleteMetadataRequest(DeleteMetadata):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
|
||||
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
|
||||
class ModelsGetRequest(models.Base):
|
||||
include_stats = fields.BoolField(default=False)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
from enum import auto
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels import fields, models
|
||||
from jsonmodels.validators import Length
|
||||
|
||||
from apiserver.apimodels import DictField, ActualEnumField, ScalarField
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class Filter(models.Base):
|
||||
@@ -9,3 +16,47 @@ class Filter(models.Base):
|
||||
class TagsRequest(models.Base):
|
||||
include_system = fields.BoolField(default=False)
|
||||
filter = fields.EmbeddedField(Filter)
|
||||
|
||||
|
||||
class EntitiesCountRequest(models.Base):
|
||||
projects = DictField()
|
||||
tasks = DictField()
|
||||
models = DictField()
|
||||
pipelines = DictField()
|
||||
datasets = DictField()
|
||||
reports = DictField()
|
||||
active_users = fields.ListField(str)
|
||||
search_hidden = fields.BoolField(default=False)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class EntityType(StringEnum):
|
||||
task = auto()
|
||||
model = auto()
|
||||
|
||||
|
||||
class ValueMapping(models.Base):
|
||||
key = ScalarField(nullable=True)
|
||||
value = ScalarField(nullable=True)
|
||||
|
||||
|
||||
class FieldMapping(models.Base):
|
||||
field = fields.StringField(required=True)
|
||||
name = fields.StringField()
|
||||
values: Sequence[ValueMapping] = fields.ListField(items_types=[ValueMapping])
|
||||
|
||||
|
||||
class PrepareDownloadForGetAllRequest(models.Base):
|
||||
entity_type = ActualEnumField(EntityType)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
search_hidden = fields.BoolField(default=False)
|
||||
only_fields = fields.ListField(
|
||||
items_types=[str], validators=[Length(1)], required=True
|
||||
)
|
||||
field_mappings: Sequence[FieldMapping] = fields.ListField(
|
||||
items_types=[FieldMapping], validators=[Length(1)], required=True
|
||||
)
|
||||
|
||||
|
||||
class DownloadForGetAllRequest(models.Base):
|
||||
prepare_id = fields.StringField(required=True)
|
||||
|
||||
21
apiserver/apimodels/pipelines.py
Normal file
21
apiserver/apimodels/pipelines.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from jsonmodels import models, fields
|
||||
from jsonmodels.validators import Length
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
|
||||
|
||||
class Arg(models.Base):
|
||||
name = fields.StringField(required=True)
|
||||
value = fields.StringField(required=True)
|
||||
|
||||
|
||||
class DeleteRunsRequest(models.Base):
|
||||
project = fields.StringField(required=True)
|
||||
ids = ListField([str], required=True, validators=[Length(1)])
|
||||
|
||||
|
||||
class StartPipelineRequest(models.Base):
|
||||
task = fields.StringField(required=True)
|
||||
queue = fields.StringField(required=True)
|
||||
args = ListField(Arg)
|
||||
verify_watched_queue = fields.BoolField(default=False)
|
||||
@@ -1,15 +1,42 @@
|
||||
from enum import Enum, auto
|
||||
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apiserver.apimodels import ListField, ActualEnumField
|
||||
from apiserver.apimodels import ListField, ActualEnumField, DictField
|
||||
from apiserver.apimodels.organization import TagsRequest
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class ProjectReq(models.Base):
|
||||
class ProjectRequest(models.Base):
|
||||
project = fields.StringField(required=True)
|
||||
|
||||
|
||||
class MergeRequest(ProjectRequest):
|
||||
destination_project = fields.StringField()
|
||||
|
||||
|
||||
class MoveRequest(ProjectRequest):
|
||||
new_location = fields.StringField()
|
||||
|
||||
|
||||
class DeleteRequest(ProjectRequest):
|
||||
force = fields.BoolField(default=False)
|
||||
delete_contents = fields.BoolField(default=False)
|
||||
delete_external_artifacts = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectOrNoneRequest(models.Base):
|
||||
project = fields.StringField()
|
||||
include_subprojects = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class GetHyperParamReq(ProjectReq):
|
||||
class GetUniqueMetricsRequest(ProjectOrNoneRequest):
|
||||
model_metrics = fields.BoolField(default=False)
|
||||
ids = fields.ListField(str)
|
||||
|
||||
|
||||
class GetParamsRequest(ProjectOrNoneRequest):
|
||||
page = fields.IntField(default=0)
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
@@ -18,7 +45,59 @@ class ProjectTagsRequest(TagsRequest):
|
||||
projects = ListField(str)
|
||||
|
||||
|
||||
class ProjectTaskParentsRequest(ProjectReq):
|
||||
projects = ListField(str)
|
||||
tasks_state = ActualEnumField(EntityVisibility)
|
||||
class MultiProjectRequest(models.Base):
|
||||
projects = fields.ListField(str)
|
||||
include_subprojects = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectTaskParentsRequest(MultiProjectRequest):
|
||||
tasks_state = ActualEnumField(EntityVisibility)
|
||||
task_name = fields.StringField()
|
||||
|
||||
|
||||
class EntityTypeEnum(StringEnum):
|
||||
task = auto()
|
||||
model = auto()
|
||||
|
||||
|
||||
class ProjectUserNamesRequest(MultiProjectRequest):
|
||||
entity = ActualEnumField(EntityTypeEnum, default=EntityTypeEnum.task)
|
||||
|
||||
|
||||
class MultiProjectPagedRequest(MultiProjectRequest):
|
||||
allow_public = fields.BoolField(default=True)
|
||||
page = fields.IntField(default=0)
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
|
||||
class ProjectHyperparamValuesRequest(MultiProjectPagedRequest):
|
||||
section = fields.StringField(required=True)
|
||||
name = fields.StringField(required=True)
|
||||
pattern = fields.StringField()
|
||||
|
||||
|
||||
class ProjectModelMetadataValuesRequest(MultiProjectPagedRequest):
|
||||
key = fields.StringField(required=True)
|
||||
|
||||
|
||||
class ProjectChildrenType(Enum):
|
||||
pipeline = "pipeline"
|
||||
report = "report"
|
||||
dataset = "dataset"
|
||||
|
||||
|
||||
class ProjectsGetRequest(models.Base):
|
||||
include_dataset_stats = fields.BoolField(default=False)
|
||||
include_stats = fields.BoolField(default=False)
|
||||
include_stats_filter = DictField()
|
||||
stats_with_children = fields.BoolField(default=True)
|
||||
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
|
||||
non_public = fields.BoolField(default=False) # legacy, use allow_public instead
|
||||
active_users = fields.ListField(str)
|
||||
check_own_contents = fields.BoolField(default=False)
|
||||
shallow_search = fields.BoolField(default=False)
|
||||
search_hidden = fields.BoolField(default=False)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
children_type = ActualEnumField(ProjectChildrenType)
|
||||
children_tags = fields.ListField(str)
|
||||
children_tags_filter = DictField()
|
||||
|
||||
@@ -2,7 +2,12 @@ from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
from apiserver.apimodels import ListField, DictField
|
||||
from apiserver.apimodels.metadata import (
|
||||
MetadataItem,
|
||||
DeleteMetadata,
|
||||
AddOrUpdateMetadata,
|
||||
)
|
||||
|
||||
|
||||
class GetDefaultResp(Base):
|
||||
@@ -14,12 +19,28 @@ class CreateRequest(Base):
|
||||
name = StringField(required=True)
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class QueueRequest(Base):
|
||||
queue = StringField(required=True)
|
||||
|
||||
|
||||
class GetByIdRequest(QueueRequest):
|
||||
max_task_entries = IntField()
|
||||
|
||||
|
||||
class GetAllRequest(Base):
|
||||
max_task_entries = IntField()
|
||||
search_hidden = BoolField(default=False)
|
||||
|
||||
|
||||
class GetNextTaskRequest(QueueRequest):
|
||||
queue = StringField(required=True)
|
||||
get_task_info = BoolField(default=False)
|
||||
task = StringField()
|
||||
|
||||
|
||||
class DeleteRequest(QueueRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
@@ -28,6 +49,7 @@ class UpdateRequest(QueueRequest):
|
||||
name = StringField()
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class TaskRequest(QueueRequest):
|
||||
@@ -47,6 +69,7 @@ class GetMetricsRequest(Base):
|
||||
from_date = FloatField(required=True, validators=validators.Min(0))
|
||||
to_date = FloatField(required=True, validators=validators.Min(0))
|
||||
interval = IntField(required=True, validators=validators.Min(1))
|
||||
refresh = BoolField(default=False)
|
||||
|
||||
|
||||
class QueueMetrics(Base):
|
||||
@@ -58,3 +81,11 @@ class QueueMetrics(Base):
|
||||
|
||||
class GetMetricsResponse(Base):
|
||||
queues = ListField(QueueMetrics)
|
||||
|
||||
|
||||
class DeleteMetadataRequest(DeleteMetadata):
|
||||
queue = StringField(required=True)
|
||||
|
||||
|
||||
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
||||
queue = StringField(required=True)
|
||||
|
||||
84
apiserver/apimodels/reports.py
Normal file
84
apiserver/apimodels/reports.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, ListField, BoolField, EmbeddedField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length
|
||||
|
||||
from apiserver.apimodels.events import MetricVariants, HistogramRequestBase
|
||||
|
||||
|
||||
class UpdateReportRequest(Base):
|
||||
task = StringField(required=True)
|
||||
name = StringField(nullable=True, validators=Length(minimum_value=3))
|
||||
tags = ListField(items_types=[str])
|
||||
comment = StringField()
|
||||
report = StringField()
|
||||
report_assets = ListField(items_types=[str])
|
||||
|
||||
|
||||
class CreateReportRequest(Base):
|
||||
name = StringField(required=True, validators=Length(minimum_value=3))
|
||||
tags = ListField(items_types=[str])
|
||||
comment = StringField()
|
||||
report = StringField()
|
||||
project = StringField()
|
||||
report_assets = ListField(items_types=[str])
|
||||
|
||||
|
||||
class PublishReportRequest(Base):
|
||||
task = StringField(required=True)
|
||||
message = StringField(default="")
|
||||
|
||||
|
||||
class ArchiveReportRequest(Base):
|
||||
task = StringField(required=True)
|
||||
message = StringField(default="")
|
||||
|
||||
|
||||
class ShareReportRequest(Base):
|
||||
task = StringField(required=True)
|
||||
share = BoolField(default=True)
|
||||
|
||||
|
||||
class DeleteReportRequest(Base):
|
||||
task = StringField(required=True)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class MoveReportRequest(Base):
|
||||
task = StringField(required=True)
|
||||
project = StringField()
|
||||
project_name = StringField()
|
||||
|
||||
|
||||
class EventsRequest(Base):
|
||||
iters = IntField(default=1, validators=validators.Min(1))
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
|
||||
class PlotEventsRequest(EventsRequest):
|
||||
last_iters_per_task_metric: bool = BoolField(default=True)
|
||||
|
||||
|
||||
class ScalarMetricsIterHistogram(HistogramRequestBase):
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
|
||||
class SingleValueMetrics(Base):
|
||||
pass
|
||||
|
||||
|
||||
class GetTasksDataRequest(Base):
|
||||
debug_images: EventsRequest = EmbeddedField(EventsRequest)
|
||||
plots: PlotEventsRequest = EmbeddedField(PlotEventsRequest)
|
||||
scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(
|
||||
ScalarMetricsIterHistogram
|
||||
)
|
||||
single_value_metrics: SingleValueMetrics = EmbeddedField(SingleValueMetrics)
|
||||
allow_public = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class GetAllRequest(Base):
|
||||
allow_public = BoolField(default=True)
|
||||
@@ -1,16 +1,17 @@
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from jsonmodels import models
|
||||
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
|
||||
from jsonmodels.validators import Enum, Length
|
||||
|
||||
from apiserver.apimodels import DictField, ListField
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.apimodels.batch import BatchRequest, UpdateBatchItem, BatchResponse
|
||||
from apiserver.database.model.task.task import (
|
||||
TaskType,
|
||||
ArtifactModes,
|
||||
DEFAULT_ARTIFACT_MODE,
|
||||
TaskModelTypes,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
|
||||
@@ -41,50 +42,94 @@ class StartedResponse(UpdateResponse):
|
||||
|
||||
class EnqueueResponse(UpdateResponse):
|
||||
queued = IntField()
|
||||
queue_watched = BoolField()
|
||||
|
||||
|
||||
class EnqueueBatchItem(UpdateBatchItem):
|
||||
queued: bool = BoolField()
|
||||
|
||||
|
||||
class EnqueueManyResponse(BatchResponse):
|
||||
succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem)
|
||||
queue_watched = BoolField()
|
||||
|
||||
|
||||
class DequeueResponse(UpdateResponse):
|
||||
dequeued = IntField()
|
||||
|
||||
|
||||
class DequeueBatchItem(UpdateBatchItem):
|
||||
dequeued: bool = BoolField()
|
||||
|
||||
|
||||
class DequeueManyResponse(BatchResponse):
|
||||
succeeded: Sequence[DequeueBatchItem] = ListField(DequeueBatchItem)
|
||||
|
||||
|
||||
class ResetResponse(UpdateResponse):
|
||||
deleted_indices = ListField(items_types=six.string_types)
|
||||
dequeued = DictField()
|
||||
frames = DictField()
|
||||
events = DictField()
|
||||
model_deleted = IntField()
|
||||
deleted_models = IntField()
|
||||
urls = DictField()
|
||||
|
||||
|
||||
class ResetBatchItem(UpdateBatchItem):
|
||||
dequeued: bool = BoolField()
|
||||
deleted_models = IntField()
|
||||
urls = DictField()
|
||||
|
||||
|
||||
class ResetManyResponse(BatchResponse):
|
||||
succeeded: Sequence[ResetBatchItem] = ListField(ResetBatchItem)
|
||||
|
||||
|
||||
class TaskRequest(models.Base):
|
||||
task = StringField(required=True)
|
||||
|
||||
|
||||
class UpdateRequest(TaskRequest):
|
||||
class TaskUpdateRequest(TaskRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class UpdateRequest(TaskUpdateRequest):
|
||||
status_reason = StringField(default="")
|
||||
status_message = StringField(default="")
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class DequeueRequest(UpdateRequest):
|
||||
remove_from_all_queues = BoolField(default=False)
|
||||
new_status = StringField()
|
||||
|
||||
|
||||
class EnqueueRequest(UpdateRequest):
|
||||
queue = StringField()
|
||||
queue_name = StringField()
|
||||
verify_watched_queue = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteRequest(UpdateRequest):
|
||||
move_to_trash = BoolField(default=True)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class SetRequirementsRequest(TaskRequest):
|
||||
requirements = DictField(required=True)
|
||||
|
||||
|
||||
class CompletedRequest(UpdateRequest):
|
||||
publish = BoolField(default=False)
|
||||
|
||||
|
||||
class CompletedResponse(UpdateResponse):
|
||||
published = IntField(default=0)
|
||||
|
||||
|
||||
class PublishRequest(UpdateRequest):
|
||||
publish_model = BoolField(default=True)
|
||||
|
||||
|
||||
class PublishResponse(UpdateResponse):
|
||||
pass
|
||||
|
||||
|
||||
class TaskData(models.Base):
|
||||
"""
|
||||
This is a partial description of task can be updated incrementally
|
||||
@@ -104,6 +149,11 @@ class GetTypesRequest(models.Base):
|
||||
projects = ListField(items_types=[str])
|
||||
|
||||
|
||||
class TaskInputModel(models.Base):
|
||||
name = StringField()
|
||||
model = StringField()
|
||||
|
||||
|
||||
class CloneRequest(TaskRequest):
|
||||
new_task_name = StringField()
|
||||
new_task_comment = StringField()
|
||||
@@ -113,14 +163,15 @@ class CloneRequest(TaskRequest):
|
||||
new_task_project = StringField()
|
||||
new_task_hyperparams = DictField()
|
||||
new_task_configuration = DictField()
|
||||
new_task_container = DictField()
|
||||
new_task_input_models = ListField([TaskInputModel])
|
||||
execution_overrides = DictField()
|
||||
validate_references = BoolField(default=False)
|
||||
new_project_name = StringField()
|
||||
|
||||
|
||||
class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||
class AddOrUpdateArtifactsRequest(TaskUpdateRequest):
|
||||
artifacts = ListField([Artifact], validators=Length(minimum_value=1))
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class ArtifactId(models.Base):
|
||||
@@ -130,13 +181,15 @@ class ArtifactId(models.Base):
|
||||
)
|
||||
|
||||
|
||||
class DeleteArtifactsRequest(TaskRequest):
|
||||
class DeleteArtifactsRequest(TaskUpdateRequest):
|
||||
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class ResetRequest(UpdateRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class MultiTaskRequest(models.Base):
|
||||
@@ -161,7 +214,7 @@ class ReplaceHyperparams(object):
|
||||
all = "all"
|
||||
|
||||
|
||||
class EditHyperParamsRequest(TaskRequest):
|
||||
class EditHyperParamsRequest(TaskUpdateRequest):
|
||||
hyperparams: Sequence[HyperParamItem] = ListField(
|
||||
[HyperParamItem], validators=Length(minimum_value=1)
|
||||
)
|
||||
@@ -169,7 +222,6 @@ class EditHyperParamsRequest(TaskRequest):
|
||||
validators=Enum(*get_options(ReplaceHyperparams)),
|
||||
default=ReplaceHyperparams.none,
|
||||
)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class HyperParamKey(models.Base):
|
||||
@@ -177,11 +229,10 @@ class HyperParamKey(models.Base):
|
||||
name = StringField(nullable=True)
|
||||
|
||||
|
||||
class DeleteHyperParamsRequest(TaskRequest):
|
||||
class DeleteHyperParamsRequest(TaskUpdateRequest):
|
||||
hyperparams: Sequence[HyperParamKey] = ListField(
|
||||
[HyperParamKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class GetConfigurationsRequest(MultiTaskRequest):
|
||||
@@ -189,7 +240,7 @@ class GetConfigurationsRequest(MultiTaskRequest):
|
||||
|
||||
|
||||
class GetConfigurationNamesRequest(MultiTaskRequest):
|
||||
pass
|
||||
skip_empty = BoolField(default=True)
|
||||
|
||||
|
||||
class Configuration(models.Base):
|
||||
@@ -199,17 +250,15 @@ class Configuration(models.Base):
|
||||
description = StringField()
|
||||
|
||||
|
||||
class EditConfigurationRequest(TaskRequest):
|
||||
class EditConfigurationRequest(TaskUpdateRequest):
|
||||
configuration: Sequence[Configuration] = ListField(
|
||||
[Configuration], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_configuration = BoolField(default=False)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteConfigurationRequest(TaskRequest):
|
||||
class DeleteConfigurationRequest(TaskUpdateRequest):
|
||||
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class ArchiveRequest(MultiTaskRequest):
|
||||
@@ -219,3 +268,73 @@ class ArchiveRequest(MultiTaskRequest):
|
||||
|
||||
class ArchiveResponse(models.Base):
|
||||
archived = IntField()
|
||||
|
||||
|
||||
class TaskBatchRequest(BatchRequest):
|
||||
status_reason = StringField(default="")
|
||||
status_message = StringField(default="")
|
||||
|
||||
|
||||
class StopManyRequest(TaskBatchRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class DequeueManyRequest(TaskBatchRequest):
|
||||
remove_from_all_queues = BoolField(default=False)
|
||||
new_status = StringField()
|
||||
|
||||
|
||||
class EnqueueManyRequest(TaskBatchRequest):
|
||||
queue = StringField()
|
||||
queue_name = StringField()
|
||||
validate_tasks = BoolField(default=False)
|
||||
verify_watched_queue = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteManyRequest(TaskBatchRequest):
|
||||
move_to_trash = BoolField(default=True)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
force = BoolField(default=False)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class ResetManyRequest(TaskBatchRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
force = BoolField(default=False)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class PublishManyRequest(TaskBatchRequest):
|
||||
publish_model = BoolField(default=True)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class AddUpdateModelRequest(TaskRequest):
|
||||
name = StringField(required=True)
|
||||
model = StringField(required=True)
|
||||
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
|
||||
iteration = IntField()
|
||||
|
||||
|
||||
class ModelItemKey(models.Base):
|
||||
name = StringField(required=True)
|
||||
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
|
||||
|
||||
|
||||
class DeleteModelsRequest(TaskRequest):
|
||||
models: Sequence[ModelItemKey] = ListField(
|
||||
[ModelItemKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
|
||||
|
||||
class GetAllReq(models.Base):
|
||||
allow_public = BoolField(default=True)
|
||||
search_hidden = BoolField(default=False)
|
||||
|
||||
|
||||
class UpdateTagsRequest(BatchRequest):
|
||||
add_tags = ListField([str])
|
||||
remove_tags = ListField([str])
|
||||
|
||||
@@ -12,20 +12,21 @@ from jsonmodels.fields import (
|
||||
)
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import make_default, ListField, EnumField, JsonSerializableMixin
|
||||
|
||||
DEFAULT_TIMEOUT = 10 * 60
|
||||
from apiserver.apimodels import ListField, EnumField, JsonSerializableMixin
|
||||
from apiserver.config_repo import config
|
||||
|
||||
|
||||
class WorkerRequest(Base):
|
||||
worker = StringField(required=True)
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class RegisterRequest(WorkerRequest):
|
||||
timeout = make_default(
|
||||
IntField, DEFAULT_TIMEOUT
|
||||
)() # registration timeout in seconds (default is 10min)
|
||||
timeout = IntField(
|
||||
default=int(config.get("services.workers.default_worker_timeout_sec", 10 * 60))
|
||||
)
|
||||
""" registration timeout in seconds (default is 10min) """
|
||||
queues = ListField(six.string_types) # list of queues this worker listens to
|
||||
|
||||
|
||||
@@ -76,6 +77,7 @@ class WorkerEntry(Base, JsonSerializableMixin):
|
||||
last_activity_time = DateTimeField(required=True)
|
||||
last_report_time = DateTimeField()
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class CurrentTaskEntry(IdNameEntry):
|
||||
@@ -96,12 +98,18 @@ class WorkerResponseEntry(WorkerEntry):
|
||||
|
||||
class GetAllRequest(Base):
|
||||
last_seen = IntField(default=3600)
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class GetAllResponse(Base):
|
||||
workers = ListField(WorkerResponseEntry)
|
||||
|
||||
|
||||
class GetCountRequest(GetAllRequest):
|
||||
last_seen = IntField(default=0)
|
||||
|
||||
|
||||
class StatsBase(Base):
|
||||
worker_ids = ListField(str)
|
||||
|
||||
|
||||
@@ -2,7 +2,11 @@ from datetime import datetime
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.auth import GetTokenResponse, CreateUserRequest, Credentials as CredModel
|
||||
from apiserver.apimodels.auth import (
|
||||
GetTokenResponse,
|
||||
CreateUserRequest,
|
||||
Credentials as CredModel,
|
||||
)
|
||||
from apiserver.apimodels.users import CreateRequest as Users_CreateRequest
|
||||
from apiserver.bll.user import UserBLL
|
||||
from apiserver.config_repo import config
|
||||
@@ -57,9 +61,10 @@ class AuthBLL:
|
||||
api_version=str(ServiceRepo.max_endpoint_version()),
|
||||
server_version=str(get_version()),
|
||||
server_build=str(get_build_number()),
|
||||
feature_set="basic",
|
||||
)
|
||||
|
||||
return GetTokenResponse(token=token.decode("ascii"))
|
||||
return GetTokenResponse(token=token)
|
||||
|
||||
@staticmethod
|
||||
def create_user(request: CreateUserRequest, call: APICall = None) -> str:
|
||||
@@ -144,7 +149,7 @@ class AuthBLL:
|
||||
|
||||
@classmethod
|
||||
def create_credentials(
|
||||
cls, user_id: str, company_id: str, role: str = None
|
||||
cls, user_id: str, company_id: str, role: str = None, label: str = None,
|
||||
) -> CredModel:
|
||||
|
||||
with translate_errors_context():
|
||||
@@ -153,9 +158,11 @@ class AuthBLL:
|
||||
if not user:
|
||||
raise errors.bad_request.InvalidUserId(**query)
|
||||
|
||||
cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key())
|
||||
cred = CredModel(
|
||||
access_key=get_client_id(), secret_key=get_secret_key(), label=label
|
||||
)
|
||||
user.credentials.append(
|
||||
Credentials(key=cred.access_key, secret=cred.secret_key)
|
||||
Credentials(key=cred.access_key, secret=cred.secret_key, label=label)
|
||||
)
|
||||
user.save()
|
||||
|
||||
|
||||
@@ -1,476 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from operator import attrgetter, itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.metrics import MetricEventStats
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
class VariantScrollState(Base):
|
||||
name: str = StringField(required=True)
|
||||
recycle_url_marker: str = StringField()
|
||||
last_invalid_iteration: int = IntField()
|
||||
|
||||
|
||||
class MetricScrollState(Base):
|
||||
task: str = StringField(required=True)
|
||||
name: str = StringField(required=True)
|
||||
last_min_iter: Optional[int] = IntField()
|
||||
last_max_iter: Optional[int] = IntField()
|
||||
timestamp: int = IntField(default=0)
|
||||
variants: Sequence[VariantScrollState] = ListField([VariantScrollState])
|
||||
|
||||
def reset(self):
|
||||
"""Reset the scrolling state for the metric"""
|
||||
self.last_min_iter = self.last_max_iter = None
|
||||
|
||||
|
||||
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
|
||||
warning: str = StringField()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class DebugImagesResult(object):
|
||||
metric_events: Sequence[tuple] = []
|
||||
next_scroll_id: str = None
|
||||
|
||||
|
||||
class DebugImagesIterator:
|
||||
EVENT_TYPE = EventType.metrics_image
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugImageEventsScrollState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> DebugImagesResult:
|
||||
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
|
||||
return DebugImagesResult()
|
||||
|
||||
def init_state(state_: DebugImageEventsScrollState):
|
||||
unique_metrics = set(metrics)
|
||||
state_.metrics = self._init_metric_states(company_id, list(unique_metrics))
|
||||
|
||||
def validate_state(state_: DebugImageEventsScrollState):
|
||||
"""
|
||||
Validate that the metrics stored in the state are the same
|
||||
as requested in the current call.
|
||||
Refresh the state if requested
|
||||
"""
|
||||
state_metrics = set((m.task, m.name) for m in state_.metrics)
|
||||
if state_metrics != set(metrics):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task metrics stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reinit_outdated_metric_states(company_id, state_)
|
||||
for metric_state in state_.metrics:
|
||||
metric_state.reset()
|
||||
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
) as state:
|
||||
res = DebugImagesResult(next_scroll_id=state.id)
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
res.metric_events = list(
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_task_metric_events,
|
||||
company_id=company_id,
|
||||
iter_count=iter_count,
|
||||
navigate_earlier=navigate_earlier,
|
||||
),
|
||||
state.metrics,
|
||||
)
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def _reinit_outdated_metric_states(
|
||||
self, company_id, state: DebugImageEventsScrollState
|
||||
):
|
||||
"""
|
||||
Determines the metrics for which new debug image events were added
|
||||
since their states were initialized and reinits these states
|
||||
"""
|
||||
task_ids = set(metric.task for metric in state.metrics)
|
||||
tasks = Task.objects(id__in=list(task_ids), company=company_id).only(
|
||||
"id", "metric_stats"
|
||||
)
|
||||
|
||||
def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
|
||||
"""For metrics that reported debug image events get tuples of task_id/metric_name and last update times"""
|
||||
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
|
||||
if not metric_stats:
|
||||
return []
|
||||
|
||||
return [
|
||||
(
|
||||
(task.id, stats.metric),
|
||||
stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
|
||||
)
|
||||
for stats in metric_stats.values()
|
||||
if self.EVENT_TYPE.value in stats.event_stats_by_type
|
||||
]
|
||||
|
||||
update_times = dict(
|
||||
chain.from_iterable(
|
||||
get_last_update_times_for_task_metrics(task) for task in tasks
|
||||
)
|
||||
)
|
||||
outdated_metrics = [
|
||||
metric
|
||||
for metric in state.metrics
|
||||
if (metric.task, metric.name) in update_times
|
||||
and update_times[metric.task, metric.name] > metric.timestamp
|
||||
]
|
||||
state.metrics = [
|
||||
*(metric for metric in state.metrics if metric not in outdated_metrics),
|
||||
*(
|
||||
self._init_metric_states(
|
||||
company_id,
|
||||
[(metric.task, metric.name) for metric in outdated_metrics],
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
def _init_metric_states(
|
||||
self, company_id: str, metrics: Sequence[Tuple[str, str]]
|
||||
) -> Sequence[MetricScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
"""
|
||||
tasks = defaultdict(list)
|
||||
for (task, metric) in metrics:
|
||||
tasks[task].append(metric)
|
||||
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
pool.map(
|
||||
partial(
|
||||
self._init_metric_states_for_task, company_id=company_id
|
||||
),
|
||||
tasks.items(),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, Sequence[str]], company_id: str
|
||||
) -> Sequence[MetricScrollState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
for the variants that reported any debug images
|
||||
"""
|
||||
task, metrics = task_metrics
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"terms": {"metric": metrics}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_event_timestamp": {"max": {"field": "timestamp"}},
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"urls": {
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "desc"},
|
||||
"size": 1, # we need only one url from the most recent iteration
|
||||
},
|
||||
"aggs": {
|
||||
"max_iter": {"max": {"field": "iter"}},
|
||||
"iters": {
|
||||
"top_hits": {
|
||||
"sort": {"iter": {"order": "desc"}},
|
||||
"size": 2, # need two last iterations so that we can take
|
||||
# the second one as invalid
|
||||
"_source": "iter",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
def init_variant_scroll_state(variant: dict):
|
||||
"""
|
||||
Return new variant scroll state for the passed variant bucket
|
||||
If the image urls get recycled then fill the last_invalid_iteration field
|
||||
"""
|
||||
state = VariantScrollState(name=variant["key"])
|
||||
top_iter_url = dpath.get(variant, "urls/buckets")[0]
|
||||
iters = dpath.get(top_iter_url, "iters/hits/hits")
|
||||
if len(iters) > 1:
|
||||
state.last_invalid_iteration = dpath.get(iters[1], "_source/iter")
|
||||
return state
|
||||
|
||||
return [
|
||||
MetricScrollState(
|
||||
task=task,
|
||||
name=metric["key"],
|
||||
variants=[
|
||||
init_variant_scroll_state(variant)
|
||||
for variant in dpath.get(metric, "variants/buckets")
|
||||
],
|
||||
timestamp=dpath.get(metric, "last_event_timestamp/value"),
|
||||
)
|
||||
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
|
||||
]
|
||||
|
||||
def _get_task_metric_events(
|
||||
self,
|
||||
metric: MetricScrollState,
|
||||
company_id: str,
|
||||
iter_count: int,
|
||||
navigate_earlier: bool,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Return task metric events grouped by iterations
|
||||
Update metric scroll state
|
||||
"""
|
||||
if metric.last_max_iter is None:
|
||||
# the first fetch is always from the latest iteration to the earlier ones
|
||||
navigate_earlier = True
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": metric.task}},
|
||||
{"term": {"metric": metric.name}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
must_not_conditions = []
|
||||
|
||||
range_condition = None
|
||||
if navigate_earlier and metric.last_min_iter is not None:
|
||||
range_condition = {"lt": metric.last_min_iter}
|
||||
elif not navigate_earlier and metric.last_max_iter is not None:
|
||||
range_condition = {"gt": metric.last_max_iter}
|
||||
if range_condition:
|
||||
must_conditions.append({"range": {"iter": range_condition}})
|
||||
|
||||
if navigate_earlier:
|
||||
"""
|
||||
When navigating to earlier iterations consider only
|
||||
variants whose invalid iterations border is lower than
|
||||
our starting iteration. For these variants make sure
|
||||
that only events from the valid iterations are returned
|
||||
"""
|
||||
if not metric.last_min_iter:
|
||||
variants = metric.variants
|
||||
else:
|
||||
variants = list(
|
||||
v
|
||||
for v in metric.variants
|
||||
if v.last_invalid_iteration is None
|
||||
or v.last_invalid_iteration < metric.last_min_iter
|
||||
)
|
||||
if not variants:
|
||||
return metric.task, metric.name, []
|
||||
must_conditions.append(
|
||||
{"terms": {"variant": list(v.name for v in variants)}}
|
||||
)
|
||||
else:
|
||||
"""
|
||||
When navigating to later iterations all variants may be relevant.
|
||||
For the variants whose invalid border is higher than our starting
|
||||
iteration make sure that only events from valid iterations are returned
|
||||
"""
|
||||
variants = list(
|
||||
v
|
||||
for v in metric.variants
|
||||
if v.last_invalid_iteration is not None
|
||||
and v.last_invalid_iteration > metric.last_max_iter
|
||||
)
|
||||
|
||||
variants_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"variant": v.name}},
|
||||
{"range": {"iter": {"lte": v.last_invalid_iteration}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for v in variants
|
||||
if v.last_invalid_iteration is not None
|
||||
]
|
||||
if variants_conditions:
|
||||
must_not_conditions.append({"bool": {"should": variants_conditions}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {"must": must_conditions, "must_not": must_not_conditions}
|
||||
},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iter_count,
|
||||
"order": {"_key": "desc" if navigate_earlier else "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {"sort": {"url": {"order": "desc"}}}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return metric.task, metric.name, []
|
||||
|
||||
def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
|
||||
return [
|
||||
ev["_source"]
|
||||
for v in variant_buckets
|
||||
for ev in dpath.get(v, "events/hits/hits")
|
||||
]
|
||||
|
||||
iterations = [
|
||||
{
|
||||
"iter": it["key"],
|
||||
"events": get_iteration_events(dpath.get(it, "variants/buckets")),
|
||||
}
|
||||
for it in dpath.get(es_res, "aggregations/iters/buckets")
|
||||
]
|
||||
if not navigate_earlier:
|
||||
iterations.sort(key=itemgetter("iter"), reverse=True)
|
||||
if iterations:
|
||||
metric.last_max_iter = iterations[0]["iter"]
|
||||
metric.last_min_iter = iterations[-1]["iter"]
|
||||
|
||||
# Commented for now since the last invalid iteration is calculated in the beginning
|
||||
# if navigate_earlier and any(
|
||||
# variant.last_invalid_iteration is None for variant in variants
|
||||
# ):
|
||||
# """
|
||||
# Variants validation flags due to recycling can
|
||||
# be set only on navigation to earlier frames
|
||||
# """
|
||||
# iterations = self._update_variants_invalid_iterations(variants, iterations)
|
||||
|
||||
return metric.task, metric.name, iterations
|
||||
|
||||
@staticmethod
|
||||
def _update_variants_invalid_iterations(
|
||||
variants: Sequence[VariantScrollState], iterations: Sequence[dict]
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
This code is currently not in used since the invalid iterations
|
||||
are calculated during MetricState initialization
|
||||
For variants that do not have recycle url marker set it from the
|
||||
first event
|
||||
For variants that do not have last_invalid_iteration set check if the
|
||||
recycle marker was reached on a certain iteration and set it to the
|
||||
corresponding iteration
|
||||
For variants that have a newly set last_invalid_iteration remove
|
||||
events from the invalid iterations
|
||||
Return the updated iterations list
|
||||
"""
|
||||
variants_lookup = bucketize(variants, attrgetter("name"))
|
||||
for it in iterations:
|
||||
iteration = it["iter"]
|
||||
events_to_remove = []
|
||||
for event in it["events"]:
|
||||
variant = variants_lookup[event["variant"]][0]
|
||||
if (
|
||||
variant.last_invalid_iteration
|
||||
and variant.last_invalid_iteration >= iteration
|
||||
):
|
||||
events_to_remove.append(event)
|
||||
continue
|
||||
event_url = event.get("url")
|
||||
if not variant.recycle_url_marker:
|
||||
variant.recycle_url_marker = event_url
|
||||
elif variant.recycle_url_marker == event_url:
|
||||
variant.last_invalid_iteration = iteration
|
||||
events_to_remove.append(event)
|
||||
if events_to_remove:
|
||||
it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
|
||||
return [it for it in iterations if it["events"]]
|
||||
@@ -1,375 +0,0 @@
|
||||
import operator
|
||||
from typing import Sequence, Tuple, Optional
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
EventType,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
|
||||
class VariantState(Base):
|
||||
name: str = StringField(required=True)
|
||||
min_iteration: int = IntField()
|
||||
max_iteration: int = IntField()
|
||||
|
||||
|
||||
class DebugSampleHistoryState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
iteration: int = IntField()
|
||||
variant: str = StringField()
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
reached_first: bool = BoolField()
|
||||
reached_last: bool = BoolField()
|
||||
variant_states: Sequence[VariantState] = ListField([VariantState])
|
||||
warning: str = StringField()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class DebugSampleHistoryResult(object):
|
||||
scroll_id: str = None
|
||||
event: dict = None
|
||||
min_iteration: int = None
|
||||
max_iteration: int = None
|
||||
|
||||
|
||||
class DebugSampleHistory:
|
||||
EVENT_TYPE = EventType.metrics_image
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugSampleHistoryState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_next_debug_image(
|
||||
self, company_id: str, task: str, state_id: str, navigate_earlier: bool
|
||||
) -> DebugSampleHistoryResult:
|
||||
"""
|
||||
Get the debug image for next/prev variant on the current iteration
|
||||
If does not exist then try getting image for the first/last variant from next/prev iteration
|
||||
"""
|
||||
res = DebugSampleHistoryResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
|
||||
return res
|
||||
|
||||
image = self._get_next_for_current_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
) or self._get_next_for_another_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
)
|
||||
if not image:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(image=image, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
def _fill_res_and_update_state(
|
||||
self, image: dict, res: DebugSampleHistoryResult, state: DebugSampleHistoryState
|
||||
):
|
||||
state.variant = image["variant"]
|
||||
state.iteration = image["iter"]
|
||||
res.event = image
|
||||
var_state = first(s for s in state.variant_states if s.name == state.variant)
|
||||
if var_state:
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
def _get_next_for_current_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the image for next (if navigated earlier is False) or previous variant sorted by name for the same iteration
|
||||
Only variants for which the iteration falls into their valid range are considered
|
||||
Return None if no such variant or image is found
|
||||
"""
|
||||
cmp = operator.lt if navigate_earlier else operator.gt
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if cmp(var_state.name, state.variant)
|
||||
and var_state.min_iteration <= state.iteration
|
||||
]
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"metric": state.metric}},
|
||||
{"terms": {"variant": [v.name for v in variants]}},
|
||||
{"term": {"iter": state.iteration}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": {"variant": "desc" if navigate_earlier else "asc"},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_next_for_current_iteration"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def _get_next_for_another_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the image for the first variant for the next iteration (if navigate_earlier is set to False)
|
||||
or from the last variant for the previous iteration (otherwise)
|
||||
The variants for which the image falls in invalid range are discarded
|
||||
If no suitable image is found then None is returned
|
||||
"""
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"metric": state.metric}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
|
||||
if navigate_earlier:
|
||||
range_operator = "lt"
|
||||
order = "desc"
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.min_iteration < state.iteration
|
||||
]
|
||||
else:
|
||||
range_operator = "gt"
|
||||
order = "asc"
|
||||
variants = state.variant_states
|
||||
|
||||
if not variants:
|
||||
return
|
||||
|
||||
variants_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"variant": v.name}},
|
||||
{"range": {"iter": {"gte": v.min_iteration}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for v in variants
|
||||
]
|
||||
must_conditions.append({"bool": {"should": variants_conditions}})
|
||||
must_conditions.append({"range": {"iter": {range_operator: state.iteration}}},)
|
||||
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"iter": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_next_for_another_iteration"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def get_debug_image_for_variant(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
variant: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> DebugSampleHistoryResult:
|
||||
"""
|
||||
Get the debug image for the requested iteration or the latest before it
|
||||
If the iteration is not passed then get the latest event
|
||||
"""
|
||||
res = DebugSampleHistoryResult()
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
|
||||
return res
|
||||
|
||||
def init_state(state_: DebugSampleHistoryState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
def validate_state(state_: DebugSampleHistoryState):
|
||||
if state_.task != task or state_.metric != metric:
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
state: DebugSampleHistoryState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res.scroll_id = state.id
|
||||
|
||||
var_state = first(s for s in state.variant_states if s.name == variant)
|
||||
if not var_state:
|
||||
return res
|
||||
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if iteration is not None:
|
||||
must_conditions.append(
|
||||
{
|
||||
"range": {
|
||||
"iter": {"lte": iteration, "gte": var_state.min_iteration}
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
must_conditions.append(
|
||||
{"range": {"iter": {"gte": var_state.min_iteration}}}
|
||||
)
|
||||
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": {"iter": "desc"},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_debug_image_for_variant"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(
|
||||
image=hits[0]["_source"], res=res, state=state
|
||||
)
|
||||
return res
|
||||
|
||||
def _reset_variant_states(self, company_id: str, state: DebugSampleHistoryState):
|
||||
variant_iterations = self._get_variant_iterations(
|
||||
company_id=company_id, task=state.task, metric=state.metric
|
||||
)
|
||||
state.variant_states = [
|
||||
VariantState(name=var_name, min_iteration=min_iter, max_iteration=max_iter)
|
||||
for var_name, min_iter, max_iter in variant_iterations
|
||||
]
|
||||
|
||||
def _get_variant_iterations(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
variants: Optional[Sequence[str]] = None,
|
||||
) -> Sequence[Tuple[str, int, int]]:
|
||||
"""
|
||||
Return valid min and max iterations that the task reported images
|
||||
The min iteration is the lowest iteration that contains non-recycled image url
|
||||
"""
|
||||
must = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if variants:
|
||||
must.append({"terms": {"variant": variants}})
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must}},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
# all variants that sent debug images
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"urls": {
|
||||
# group by urls and choose the minimal iteration
|
||||
# from all the maximal iterations per url
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "asc"},
|
||||
"size": 1,
|
||||
},
|
||||
"aggs": {
|
||||
# find max iteration for each url
|
||||
"max_iter": {"max": {"field": "iter"}}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_debug_image_iterations"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
|
||||
)
|
||||
|
||||
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
|
||||
variant = variant_bucket["key"]
|
||||
urls = nested_get(variant_bucket, ("urls", "buckets"))
|
||||
min_iter = int(urls[0]["max_iter"]["value"])
|
||||
max_iter = int(variant_bucket["last_iter"]["value"])
|
||||
return variant, min_iter, max_iter
|
||||
|
||||
return [
|
||||
get_variant_data(variant_bucket)
|
||||
for variant_bucket in nested_get(
|
||||
es_res, ("aggregations", "variants", "buckets")
|
||||
)
|
||||
]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,15 @@
|
||||
import base64
|
||||
import zlib
|
||||
from enum import Enum
|
||||
from typing import Union, Sequence
|
||||
from typing import Union, Sequence, Mapping, Tuple
|
||||
|
||||
from boltons.typeutils import classproperty
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
|
||||
|
||||
class EventType(Enum):
|
||||
@@ -16,7 +21,14 @@ class EventType(Enum):
|
||||
all = "*"
|
||||
|
||||
|
||||
SINGLE_SCALAR_ITERATION = -(2 ** 31)
|
||||
MetricVariants = Mapping[str, Sequence[str]]
|
||||
TaskCompanies = Mapping[str, Sequence[Task]]
|
||||
|
||||
|
||||
class EventSettings:
|
||||
_max_es_allowed_aggregation_buckets = 10000
|
||||
|
||||
@classproperty
|
||||
def max_workers(self):
|
||||
return config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
|
||||
@@ -28,22 +40,31 @@ class EventSettings:
|
||||
)
|
||||
|
||||
@classproperty
|
||||
def max_metrics_count(self):
|
||||
return config.get("services.events.events_retrieval.max_metrics_count", 100)
|
||||
|
||||
@classproperty
|
||||
def max_variants_count(self):
|
||||
return config.get("services.events.events_retrieval.max_variants_count", 100)
|
||||
def max_es_buckets(self):
|
||||
percentage = (
|
||||
min(
|
||||
100,
|
||||
config.get(
|
||||
"services.events.events_retrieval.dynamic_metrics_count_threshold",
|
||||
80,
|
||||
),
|
||||
)
|
||||
/ 100
|
||||
)
|
||||
return int(self._max_es_allowed_aggregation_buckets * percentage)
|
||||
|
||||
|
||||
def get_index_name(company_id: str, event_type: str):
|
||||
def get_index_name(company_id: Union[str, Sequence[str]], event_type: str):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
return f"events-{event_type}-{company_id}"
|
||||
if isinstance(company_id, str):
|
||||
company_id = [company_id]
|
||||
|
||||
return ",".join(f"events-{event_type}-{(c_id or '').lower()}" for c_id in company_id)
|
||||
|
||||
|
||||
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
if not es.indices.exists(es_index):
|
||||
if not es.indices.exists(index=es_index):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -63,4 +84,83 @@ def delete_company_events(
|
||||
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.delete_by_query(index=es_index, body=body, **kwargs)
|
||||
return es.delete_by_query(index=es_index, body=body, conflicts="proceed", **kwargs)
|
||||
|
||||
|
||||
def count_company_events(
|
||||
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.count(index=es_index, body=body, **kwargs)
|
||||
|
||||
|
||||
def get_max_metric_and_variant_counts(
|
||||
es: Elasticsearch,
|
||||
company_id: Union[str, Sequence[str]],
|
||||
event_type: EventType,
|
||||
query: dict,
|
||||
**kwargs,
|
||||
) -> Tuple[int, int]:
|
||||
dynamic = config.get(
|
||||
"services.events.events_retrieval.dynamic_metrics_count", False
|
||||
)
|
||||
max_metrics_count = config.get(
|
||||
"services.events.events_retrieval.max_metrics_count", 100
|
||||
)
|
||||
max_variants_count = config.get(
|
||||
"services.events.events_retrieval.max_variants_count", 100
|
||||
)
|
||||
if not dynamic:
|
||||
return max_metrics_count, max_variants_count
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {"metrics_count": {"cardinality": {"field": "metric"}}},
|
||||
}
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
|
||||
)
|
||||
|
||||
metrics_count = safe_get(
|
||||
es_res, "aggregations/metrics_count/value", max_metrics_count
|
||||
)
|
||||
if not metrics_count:
|
||||
return max_metrics_count, max_variants_count
|
||||
|
||||
return metrics_count, int(EventSettings.max_es_buckets / metrics_count)
|
||||
|
||||
|
||||
def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
|
||||
conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
{"terms": {"variant": variants}},
|
||||
]
|
||||
}
|
||||
}
|
||||
if variants
|
||||
else {"term": {"metric": metric}}
|
||||
for metric, variants in metric_variants.items()
|
||||
]
|
||||
|
||||
return {"bool": {"should": conditions}}
|
||||
|
||||
|
||||
class PlotFields:
|
||||
valid_plot = "valid_plot"
|
||||
plot_len = "plot_len"
|
||||
plot_str = "plot_str"
|
||||
plot_data = "plot_data"
|
||||
source_urls = "source_urls"
|
||||
|
||||
|
||||
def uncompress_plot(event: dict):
|
||||
plot_data = event.pop(PlotFields.plot_data, None)
|
||||
if plot_data and event.get(PlotFields.plot_str) is None:
|
||||
event[PlotFields.plot_str] = zlib.decompress(
|
||||
base64.b64decode(plot_data)
|
||||
).decode()
|
||||
|
||||
@@ -4,23 +4,26 @@ from collections import defaultdict
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple
|
||||
from typing import Sequence, Tuple, Mapping
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventType,
|
||||
EventSettings,
|
||||
search_company_events,
|
||||
check_empty_data,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
get_max_metric_and_variant_counts,
|
||||
SINGLE_SCALAR_ITERATION,
|
||||
TaskCompanies,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||
from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
@@ -34,7 +37,12 @@ class EventMetrics:
|
||||
self.es = es
|
||||
|
||||
def get_scalar_metrics_average_per_iter(
|
||||
self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
samples: int,
|
||||
key: ScalarKeyEnum,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Get scalar metric histogram per metric and variant
|
||||
@@ -46,7 +54,12 @@ class EventMetrics:
|
||||
return {}
|
||||
|
||||
return self._get_scalar_average_per_iter_core(
|
||||
task_id, company_id, event_type, samples, ScalarKey.resolve(key)
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
|
||||
def _get_scalar_average_per_iter_core(
|
||||
@@ -57,6 +70,7 @@ class EventMetrics:
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
run_parallel: bool = True,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> dict:
|
||||
intervals = self._get_task_metric_intervals(
|
||||
company_id=company_id,
|
||||
@@ -64,6 +78,7 @@ class EventMetrics:
|
||||
task_id=task_id,
|
||||
samples=samples,
|
||||
field=key.field,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
if not intervals:
|
||||
return {}
|
||||
@@ -94,57 +109,51 @@ class EventMetrics:
|
||||
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self,
|
||||
company_id,
|
||||
task_ids: Sequence[str],
|
||||
companies: TaskCompanies,
|
||||
samples,
|
||||
key: ScalarKeyEnum,
|
||||
allow_public=True,
|
||||
metric_variants: MetricVariants = None,
|
||||
):
|
||||
"""
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id__in=task_ids),
|
||||
allow_public=allow_public,
|
||||
override_projection=("id", "name", "company", "company_origin"),
|
||||
return_dicts=False,
|
||||
)
|
||||
if len(task_objs) < len(task_ids):
|
||||
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
|
||||
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
|
||||
task_name_by_id = {t.id: t.name for t in task_objs}
|
||||
|
||||
companies = {t.get_index_company() for t in task_objs}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
event_type = EventType.metrics_scalar
|
||||
company_id = next(iter(companies))
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
companies = {
|
||||
company_id: tasks
|
||||
for company_id, tasks in companies.items()
|
||||
if not check_empty_data(
|
||||
self.es, company_id=company_id, event_type=event_type
|
||||
)
|
||||
}
|
||||
if not companies:
|
||||
return {}
|
||||
|
||||
get_scalar_average_per_iter = partial(
|
||||
self._get_scalar_average_per_iter_core,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
metric_variants=metric_variants,
|
||||
run_parallel=False,
|
||||
)
|
||||
task_ids, company_ids = zip(
|
||||
*(
|
||||
(t.id, t.company)
|
||||
for t in itertools.chain.from_iterable(companies.values())
|
||||
)
|
||||
)
|
||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||
task_metrics = zip(
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids, company_ids)
|
||||
)
|
||||
|
||||
task_names = {
|
||||
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
|
||||
}
|
||||
res = defaultdict(lambda: defaultdict(dict))
|
||||
for task_id, task_data in task_metrics:
|
||||
task_name = task_name_by_id[task_id]
|
||||
task_name = task_names[task_id]
|
||||
for metric_key, metric_data in task_data.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
variant_data["name"] = task_name
|
||||
@@ -152,6 +161,75 @@ class EventMetrics:
|
||||
|
||||
return res
|
||||
|
||||
def get_task_single_value_metrics(
|
||||
self,
|
||||
companies: TaskCompanies,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> Mapping[str, dict]:
|
||||
"""
|
||||
For the requested tasks return all the events delivered for the single iteration (-2**31)
|
||||
"""
|
||||
companies = {
|
||||
company_id: [t.id for t in tasks]
|
||||
for company_id, tasks in companies.items()
|
||||
if not check_empty_data(
|
||||
self.es, company_id=company_id, event_type=EventType.metrics_scalar
|
||||
)
|
||||
}
|
||||
if not companies:
|
||||
return {}
|
||||
|
||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||
task_events = list(
|
||||
itertools.chain.from_iterable(
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_task_single_value_metrics,
|
||||
metric_variants=metric_variants,
|
||||
),
|
||||
companies.items(),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def _get_value(event: dict):
|
||||
return {
|
||||
field: event.get(field)
|
||||
for field in ("metric", "variant", "value", "timestamp")
|
||||
}
|
||||
|
||||
return {
|
||||
task: [_get_value(e) for e in events]
|
||||
for task, events in bucketize(task_events, itemgetter("task")).items()
|
||||
}
|
||||
|
||||
def _get_task_single_value_metrics(
|
||||
self, tasks: Tuple[str, Sequence[str]], metric_variants: MetricVariants = None
|
||||
) -> Sequence[dict]:
|
||||
company_id, task_ids = tasks
|
||||
must = [
|
||||
{"terms": {"task": task_ids}},
|
||||
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
|
||||
]
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
body=es_req,
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=EventType.metrics_scalar,
|
||||
)
|
||||
if not es_res["hits"]["total"]["value"]:
|
||||
return []
|
||||
|
||||
return [hit["_source"] for hit in es_res["hits"]["hits"]]
|
||||
|
||||
MetricInterval = Tuple[str, str, int, int]
|
||||
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
|
||||
|
||||
@@ -197,6 +275,7 @@ class EventMetrics:
|
||||
task_id: str,
|
||||
samples: int,
|
||||
field: str = "iter",
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> Sequence[MetricInterval]:
|
||||
"""
|
||||
Calculate interval per task metric variant so that the resulting
|
||||
@@ -204,21 +283,31 @@ class EventMetrics:
|
||||
Return the list og metric variant intervals as the following tuple:
|
||||
(metric, variant, interval, samples)
|
||||
"""
|
||||
must = self._task_conditions(task_id)
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
query = {"bool": {"must": must}}
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query,
|
||||
**search_args,
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"term": {"task": task_id}},
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
@@ -232,10 +321,7 @@ class EventMetrics:
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
@@ -287,33 +373,41 @@ class EventMetrics:
|
||||
"""
|
||||
interval, metrics = metrics_interval
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
aggs_result = self._query_aggregation_for_task_metrics(
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
aggs=aggs,
|
||||
task_id=task_id,
|
||||
metrics=metrics,
|
||||
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query,
|
||||
**search_args,
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
@@ -340,19 +434,20 @@ class EventMetrics:
|
||||
for key, value in aggregation.items()
|
||||
}
|
||||
|
||||
def _query_aggregation_for_task_metrics(
|
||||
self,
|
||||
company_id: str,
|
||||
event_type: EventType,
|
||||
aggs: dict,
|
||||
@staticmethod
|
||||
def _task_conditions(task_id: str) -> list:
|
||||
return [
|
||||
{"term": {"task": task_id}},
|
||||
{"range": {"iter": {"gt": SINGLE_SCALAR_ITERATION}}},
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _get_task_metrics_query(
|
||||
cls,
|
||||
task_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
) -> dict:
|
||||
"""
|
||||
Return the result of elastic search query for the given aggregation filtered
|
||||
by the given task_ids and metrics
|
||||
"""
|
||||
must = [{"term": {"task": task_id}}]
|
||||
):
|
||||
must = cls._task_conditions(task_id)
|
||||
if metrics:
|
||||
should = [
|
||||
{
|
||||
@@ -367,25 +462,98 @@ class EventMetrics:
|
||||
]
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must}},
|
||||
"aggs": aggs,
|
||||
}
|
||||
return {"bool": {"must": must}}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
def get_multi_task_metrics(self, companies: TaskCompanies, event_type: EventType) -> Mapping[str, list]:
|
||||
"""
|
||||
For the requested tasks return reported metrics and variants
|
||||
"""
|
||||
tasks_ids = {
|
||||
company: [t.id for t in tasks]
|
||||
for company, tasks in companies.items()
|
||||
}
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
companies_res: Sequence = list(
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_multi_task_metrics,
|
||||
event_type=event_type,
|
||||
),
|
||||
tasks_ids.items(),
|
||||
)
|
||||
)
|
||||
|
||||
return es_res.get("aggregations")
|
||||
if len(companies_res) == 1:
|
||||
return companies_res[0]
|
||||
|
||||
def get_tasks_metrics(
|
||||
res = defaultdict(set)
|
||||
for c_res in companies_res:
|
||||
for m, vars_ in c_res.items():
|
||||
res[m].update(vars_)
|
||||
|
||||
return {
|
||||
k: list(v)
|
||||
for k, v in res.items()
|
||||
}
|
||||
|
||||
def _get_multi_task_metrics(
|
||||
self, company_tasks: Tuple[str, Sequence[str]], event_type: EventType
|
||||
) -> Mapping[str, list]:
|
||||
company_id, task_ids = company_tasks
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return {}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
)
|
||||
query = QueryBuilder.terms("task", task_ids)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query,
|
||||
**search_args,
|
||||
)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
body=es_req,
|
||||
**search_args,
|
||||
)
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
return {
|
||||
mb["key"]: [vb["key"] for vb in mb["variants"]["buckets"]]
|
||||
for mb in aggs_result["metrics"]["buckets"]
|
||||
}
|
||||
|
||||
def get_task_metrics(
|
||||
self, company_id, task_ids: Sequence, event_type: EventType
|
||||
) -> Sequence:
|
||||
"""
|
||||
For the requested tasks return all the metrics that
|
||||
reported events of the requested types
|
||||
For the requested tasks return reported metrics per task
|
||||
"""
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return {}
|
||||
@@ -406,22 +574,21 @@ class EventMetrics:
|
||||
) -> Sequence:
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
"query": {"bool": {"must": self._task_conditions(task_id)}},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"size": EventSettings.max_es_buckets,
|
||||
"order": {"_key": "asc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
|
||||
return [
|
||||
metric["key"]
|
||||
|
||||
197
apiserver/bll/event/events_iterator.py
Normal file
197
apiserver/bll/event/events_iterator.py
Normal file
@@ -0,0 +1,197 @@
|
||||
from typing import Optional, Tuple, Sequence, Any
|
||||
|
||||
import attr
|
||||
import jsonmodels.models
|
||||
import jwt
|
||||
from elasticsearch import Elasticsearch
|
||||
from jwt.algorithms import get_default_algorithms
|
||||
|
||||
from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
count_company_events,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.Factory(list)
|
||||
|
||||
|
||||
class EventsIterator:
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
from_key_value: Optional[Any] = None,
|
||||
metric_variants: MetricVariants = None,
|
||||
key: ScalarKeyEnum = ScalarKeyEnum.timestamp,
|
||||
**kwargs,
|
||||
) -> TaskEventsResult:
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return TaskEventsResult()
|
||||
|
||||
from_key_value = kwargs.pop("from_timestamp", from_key_value)
|
||||
|
||||
res = TaskEventsResult()
|
||||
res.events, res.total_events = self._get_events(
|
||||
event_type=event_type,
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_key_value=from_key_value,
|
||||
metric_variants=metric_variants,
|
||||
key=ScalarKey.resolve(key),
|
||||
)
|
||||
return res
|
||||
|
||||
def count_task_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_ids: Sequence[str],
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> int:
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return 0
|
||||
|
||||
query, _ = self._get_initial_query_and_must(task_ids, metric_variants)
|
||||
es_req = {
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_result = count_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
|
||||
return es_result["count"]
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
key: ScalarKey,
|
||||
from_key_value: Optional[Any],
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous key-field value (timestamp or iter) either in the
|
||||
direction of earlier events (navigate_earlier=True) or in the direction of later events.
|
||||
If from_key_field is not set then start either from latest or earliest.
|
||||
For the last key-field value all the events are brought (even if the resulting size exceeds batch_size)
|
||||
so that events with this value will not be lost between the calls.
|
||||
"""
|
||||
query, must = self._get_initial_query_and_must([task_id], metric_variants)
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": query,
|
||||
"sort": {key.field: "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if from_key_value:
|
||||
es_req["search_after"] = [from_key_value]
|
||||
|
||||
with translate_errors_context():
|
||||
es_result = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": must + [{"term": {key.field: events[-1][key.field]}}]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
last_second_hits = es_result["hits"]["hits"]
|
||||
if not last_second_hits or len(last_second_hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
already_present_ids = set(hit["_id"] for hit in hits)
|
||||
last_second_events = [
|
||||
hit["_source"]
|
||||
for hit in last_second_hits
|
||||
if hit["_id"] not in already_present_ids
|
||||
]
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[*events, *last_second_events],
|
||||
hits_total,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_initial_query_and_must(
|
||||
task_ids: Sequence[str], metric_variants: MetricVariants = None
|
||||
) -> Tuple[dict, list]:
|
||||
if not metric_variants:
|
||||
query = {"terms": {"task": task_ids}}
|
||||
must = [query]
|
||||
else:
|
||||
must = [
|
||||
{"terms": {"task": task_ids}},
|
||||
get_metric_variants_condition(metric_variants),
|
||||
]
|
||||
query = {"bool": {"must": must}}
|
||||
return query, must
|
||||
|
||||
|
||||
class Scroll(jsonmodels.models.Base):
|
||||
def get_scroll_id(self) -> str:
|
||||
return jwt.encode(
|
||||
self.to_struct(),
|
||||
key=config.get(
|
||||
"services.events.events_retrieval.scroll_id_key", "1234567890"
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_scroll_id(cls, scroll_id: str):
|
||||
try:
|
||||
return cls(
|
||||
**jwt.decode(
|
||||
scroll_id,
|
||||
key=config.get(
|
||||
"services.events.events_retrieval.scroll_id_key", "1234567890"
|
||||
),
|
||||
algorithms=get_default_algorithms(),
|
||||
)
|
||||
)
|
||||
except jwt.PyJWTError:
|
||||
raise ValueError("Invalid Scroll ID")
|
||||
455
apiserver/bll/event/history_debug_image_iterator.py
Normal file
455
apiserver/bll/event/history_debug_image_iterator.py
Normal file
@@ -0,0 +1,455 @@
|
||||
import operator
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first, bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, IntField, BoolField, ListField
|
||||
from jsonmodels.models import Base
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .event_common import (
|
||||
EventType,
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.apierrors import errors
|
||||
|
||||
|
||||
class VariantState(Base):
|
||||
name: str = StringField(required=True)
|
||||
metric: str = StringField(default=None)
|
||||
min_iteration: int = IntField()
|
||||
max_iteration: int = IntField()
|
||||
|
||||
|
||||
class DebugImageSampleState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
iteration: int = IntField()
|
||||
variant: str = StringField()
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
variant_states: Sequence[VariantState] = ListField([VariantState])
|
||||
warning: str = StringField()
|
||||
navigate_current_metric = BoolField(default=True)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class VariantSampleResult(object):
|
||||
scroll_id: str = None
|
||||
event: dict = None
|
||||
min_iteration: int = None
|
||||
max_iteration: int = None
|
||||
|
||||
|
||||
class HistoryDebugImageIterator:
|
||||
event_type = EventType.metrics_image
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugImageSampleState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_next_sample(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
state_id: str,
|
||||
navigate_earlier: bool,
|
||||
next_iteration: bool,
|
||||
) -> VariantSampleResult:
|
||||
"""
|
||||
Get the sample for next/prev variant on the current iteration
|
||||
If does not exist then try getting sample for the first/last variant from next/prev iteration
|
||||
"""
|
||||
res = VariantSampleResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
if next_iteration:
|
||||
event = self._get_next_for_another_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
)
|
||||
else:
|
||||
# noinspection PyArgumentList
|
||||
event = first(
|
||||
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
|
||||
for f in (
|
||||
self._get_next_for_current_iteration,
|
||||
self._get_next_for_another_iteration,
|
||||
)
|
||||
)
|
||||
if not event:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(event=event, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def _fill_res_and_update_state(
|
||||
event: dict, res: VariantSampleResult, state: DebugImageSampleState
|
||||
):
|
||||
state.variant = event["variant"]
|
||||
state.metric = event["metric"]
|
||||
state.iteration = event["iter"]
|
||||
res.event = event
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == state.variant and vs.metric == state.metric
|
||||
)
|
||||
if var_state:
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
@staticmethod
|
||||
def _get_metric_conditions(variants: Sequence[VariantState]) -> dict:
|
||||
metrics = bucketize(variants, key=attrgetter("metric"))
|
||||
|
||||
def _get_variants_conditions(metric_variants: Sequence[VariantState]) -> dict:
|
||||
variants_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"variant": v.name}},
|
||||
{"range": {"iter": {"gte": v.min_iteration}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for v in metric_variants
|
||||
]
|
||||
return {"bool": {"should": variants_conditions}}
|
||||
|
||||
metrics_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
_get_variants_conditions(metric_variants),
|
||||
]
|
||||
}
|
||||
}
|
||||
for metric, metric_variants in metrics.items()
|
||||
]
|
||||
return {"bool": {"should": metrics_conditions}}
|
||||
|
||||
def _get_next_for_current_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
|
||||
Only variants for which the iteration falls into their valid range are considered
|
||||
Return None if no such variant or sample is found
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
cmp = operator.lt if navigate_earlier else operator.gt
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
|
||||
and var_state.min_iteration <= state.iteration
|
||||
]
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"iter": state.iteration}},
|
||||
self._get_metric_conditions(variants),
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
order = "desc" if navigate_earlier else "asc"
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def _get_next_for_another_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
|
||||
or from the last variant for the previous iteration (otherwise)
|
||||
The variants for which the sample falls in invalid range are discarded
|
||||
If no suitable sample is found then None is returned
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
if navigate_earlier:
|
||||
range_operator = "lt"
|
||||
order = "desc"
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if var_state.min_iteration < state.iteration
|
||||
]
|
||||
else:
|
||||
range_operator = "gt"
|
||||
order = "asc"
|
||||
variants = variants
|
||||
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
self._get_metric_conditions(variants),
|
||||
{"range": {"iter": {range_operator: state.iteration}}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def get_sample_for_variant(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
variant: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
navigate_current_metric: bool = True,
|
||||
) -> VariantSampleResult:
|
||||
"""
|
||||
Get the sample for the requested iteration or the latest before it
|
||||
If the iteration is not passed then get the latest event
|
||||
"""
|
||||
res = VariantSampleResult()
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
def init_state(state_: DebugImageSampleState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
state_.navigate_current_metric = navigate_current_metric
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
def validate_state(state_: DebugImageSampleState):
|
||||
if (
|
||||
state_.task != task
|
||||
or state_.navigate_current_metric != navigate_current_metric
|
||||
or (state_.navigate_current_metric and state_.metric != metric)
|
||||
):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
# fix old variant states:
|
||||
for vs in state_.variant_states:
|
||||
if vs.metric is None:
|
||||
vs.metric = metric
|
||||
if refresh:
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
state: DebugImageSampleState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res.scroll_id = state.id
|
||||
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == variant and vs.metric == metric
|
||||
)
|
||||
if not var_state:
|
||||
return res
|
||||
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if iteration is not None:
|
||||
must_conditions.append(
|
||||
{
|
||||
"range": {
|
||||
"iter": {"lte": iteration, "gte": var_state.min_iteration}
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
must_conditions.append(
|
||||
{"range": {"iter": {"gte": var_state.min_iteration}}}
|
||||
)
|
||||
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": {"iter": "desc"},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(
|
||||
event=hits[0]["_source"], res=res, state=state
|
||||
)
|
||||
return res
|
||||
|
||||
def _reset_variant_states(self, company_id: str, state: DebugImageSampleState):
|
||||
metrics = self._get_metric_variant_iterations(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
metric=state.metric if state.navigate_current_metric else None,
|
||||
)
|
||||
state.variant_states = [
|
||||
VariantState(
|
||||
metric=metric,
|
||||
name=var_name,
|
||||
min_iteration=min_iter,
|
||||
max_iteration=max_iter,
|
||||
)
|
||||
for metric, variants in metrics.items()
|
||||
for var_name, min_iter, max_iter in variants
|
||||
]
|
||||
|
||||
def _get_metric_variant_iterations(
|
||||
self, company_id: str, task: str, metric: str,
|
||||
) -> Mapping[str, Sequence[Tuple[str, int, int]]]:
|
||||
"""
|
||||
Return valid min and max iterations that the task reported events of the required type
|
||||
"""
|
||||
must = [
|
||||
{"term": {"task": task}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if metric is not None:
|
||||
must.append({"term": {"metric": metric}})
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=self.event_type,
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"urls": {
|
||||
# group by urls and choose the minimal iteration
|
||||
# from all the maximal iterations per url
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "asc"},
|
||||
"size": 1,
|
||||
},
|
||||
"aggs": {
|
||||
# find max iteration for each url
|
||||
"max_iter": {"max": {"field": "iter"}}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
|
||||
variant = variant_bucket["key"]
|
||||
urls = nested_get(variant_bucket, ("urls", "buckets"))
|
||||
min_iter = int(urls[0]["max_iter"]["value"])
|
||||
max_iter = int(variant_bucket["last_iter"]["value"])
|
||||
return variant, min_iter, max_iter
|
||||
|
||||
return {
|
||||
metric_bucket["key"]: [
|
||||
get_variant_data(variant_bucket)
|
||||
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
|
||||
]
|
||||
for metric_bucket in nested_get(
|
||||
es_res, ("aggregations", "metrics", "buckets")
|
||||
)
|
||||
}
|
||||
316
apiserver/bll/event/history_plots_iterator.py
Normal file
316
apiserver/bll/event/history_plots_iterator.py
Normal file
@@ -0,0 +1,316 @@
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, IntField, ListField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from .event_common import (
|
||||
EventType,
|
||||
uncompress_plot,
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
)
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.apierrors import errors
|
||||
|
||||
|
||||
class MetricState(Base):
|
||||
name: str = StringField(default=None)
|
||||
min_iteration: int = IntField()
|
||||
max_iteration: int = IntField()
|
||||
|
||||
|
||||
class PlotsSampleState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
iteration: int = IntField()
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
metric_states: Sequence[MetricState] = ListField([MetricState])
|
||||
warning: str = StringField()
|
||||
navigate_current_metric = BoolField(default=True)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class MetricSamplesResult(object):
|
||||
scroll_id: str = None
|
||||
events: list = []
|
||||
min_iteration: int = None
|
||||
max_iteration: int = None
|
||||
|
||||
|
||||
class HistoryPlotsIterator:
|
||||
event_type = EventType.metrics_plot
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=PlotsSampleState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_next_sample(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
state_id: str,
|
||||
navigate_earlier: bool,
|
||||
next_iteration: bool,
|
||||
) -> MetricSamplesResult:
|
||||
"""
|
||||
Get the samples for next/prev metric on the current iteration
|
||||
If does not exist then try getting sample for the first/last metric from next/prev iteration
|
||||
"""
|
||||
res = MetricSamplesResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
if navigate_earlier:
|
||||
range_operator = "lt"
|
||||
order = "desc"
|
||||
else:
|
||||
range_operator = "gt"
|
||||
order = "asc"
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
]
|
||||
if state.navigate_current_metric:
|
||||
must_conditions.append({"term": {"metric": state.metric}})
|
||||
|
||||
next_iteration_condition = {
|
||||
"range": {"iter": {range_operator: state.iteration}}
|
||||
}
|
||||
if next_iteration or state.navigate_current_metric:
|
||||
must_conditions.append(next_iteration_condition)
|
||||
else:
|
||||
next_metric_condition = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"iter": state.iteration}},
|
||||
{"range": {"metric": {range_operator: state.metric}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
must_conditions.append(
|
||||
{"bool": {"should": [next_metric_condition, next_iteration_condition]}}
|
||||
)
|
||||
|
||||
events = self._get_metric_events_for_condition(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
order=order,
|
||||
must_conditions=must_conditions,
|
||||
)
|
||||
|
||||
if not events:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(events=events, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
def get_samples_for_metric(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
navigate_current_metric: bool = True,
|
||||
) -> MetricSamplesResult:
|
||||
"""
|
||||
Get the sample for the requested iteration or the latest before it
|
||||
If the iteration is not passed then get the latest event
|
||||
"""
|
||||
res = MetricSamplesResult()
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
def init_state(state_: PlotsSampleState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
state_.navigate_current_metric = navigate_current_metric
|
||||
self._reset_metric_states(company_id=company_id, state=state_)
|
||||
|
||||
def validate_state(state_: PlotsSampleState):
|
||||
if (
|
||||
state_.task != task
|
||||
or state_.navigate_current_metric != navigate_current_metric
|
||||
or (state_.navigate_current_metric and state_.metric != metric)
|
||||
):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reset_metric_states(company_id=company_id, state=state_)
|
||||
|
||||
state: PlotsSampleState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res.scroll_id = state.id
|
||||
|
||||
metric_state = first(ms for ms in state.metric_states if ms.name == metric)
|
||||
if not metric_state:
|
||||
return res
|
||||
|
||||
res.min_iteration = metric_state.min_iteration
|
||||
res.max_iteration = metric_state.max_iteration
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
]
|
||||
if iteration is not None:
|
||||
must_conditions.append({"range": {"iter": {"lte": iteration}}})
|
||||
|
||||
events = self._get_metric_events_for_condition(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
order="desc",
|
||||
must_conditions=must_conditions,
|
||||
)
|
||||
if not events:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(events=events, res=res, state=state)
|
||||
return res
|
||||
|
||||
def _reset_metric_states(self, company_id: str, state: PlotsSampleState):
|
||||
metrics = self._get_metric_iterations(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
metric=state.metric if state.navigate_current_metric else None,
|
||||
)
|
||||
state.metric_states = [
|
||||
MetricState(name=metric, min_iteration=min_iter, max_iteration=max_iter)
|
||||
for metric, (min_iter, max_iter) in metrics.items()
|
||||
]
|
||||
|
||||
def _get_metric_iterations(
|
||||
self, company_id: str, task: str, metric: str,
|
||||
) -> Mapping[str, Tuple[int, int]]:
|
||||
"""
|
||||
Return valid min and max iterations that the task reported events of the required type
|
||||
"""
|
||||
must = [
|
||||
{"term": {"task": task}},
|
||||
]
|
||||
if metric is not None:
|
||||
must.append({"term": {"metric": metric}})
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": 5000,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"first_iter": {"min": {"field": "iter"}},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
body=es_req,
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
)
|
||||
|
||||
return {
|
||||
metric_bucket["key"]: (
|
||||
int(metric_bucket["first_iter"]["value"]),
|
||||
int(metric_bucket["last_iter"]["value"]),
|
||||
)
|
||||
for metric_bucket in nested_get(
|
||||
es_res, ("aggregations", "metrics", "buckets")
|
||||
)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _fill_res_and_update_state(
|
||||
events: Sequence[dict], res: MetricSamplesResult, state: PlotsSampleState
|
||||
):
|
||||
for event in events:
|
||||
uncompress_plot(event)
|
||||
state.metric = events[0]["metric"]
|
||||
state.iteration = events[0]["iter"]
|
||||
res.events = events
|
||||
metric_state = first(
|
||||
ms for ms in state.metric_states if ms.name == state.metric
|
||||
)
|
||||
if metric_state:
|
||||
res.min_iteration = metric_state.min_iteration
|
||||
res.max_iteration = metric_state.max_iteration
|
||||
|
||||
def _get_metric_events_for_condition(
|
||||
self, company_id: str, task: str, order: str, must_conditions: Sequence
|
||||
) -> Sequence:
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {"field": "iter", "size": 1, "order": {"_key": order}},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": 1,
|
||||
"order": {"_key": order},
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {
|
||||
"sort": {"variant": {"order": "asc"}},
|
||||
"size": 100,
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return []
|
||||
|
||||
for level in ("iters", "metrics"):
|
||||
level_data = aggs_result[level]["buckets"]
|
||||
if not level_data:
|
||||
return []
|
||||
aggs_result = level_data[0]
|
||||
|
||||
return [
|
||||
hit["_source"]
|
||||
for hit in nested_get(aggs_result, ("events", "hits", "hits"))
|
||||
]
|
||||
@@ -1,127 +0,0 @@
|
||||
from typing import Optional, Tuple, Sequence
|
||||
|
||||
import attr
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
)
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.Factory(list)
|
||||
|
||||
|
||||
class LogEventsIterator:
|
||||
EVENT_TYPE = EventType.task_log
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
from_timestamp: Optional[int] = None,
|
||||
) -> TaskEventsResult:
|
||||
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
|
||||
return TaskEventsResult()
|
||||
|
||||
res = TaskEventsResult()
|
||||
res.events, res.total_events = self._get_events(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_timestamp=from_timestamp,
|
||||
)
|
||||
return res
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
from_timestamp: Optional[int],
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous timestamp either in the
|
||||
direction of earlier events (navigate_earlier=True) or in the direction of later events.
|
||||
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
|
||||
For the last timestamp all the events are brought (even if the resulting size
|
||||
exceeds batch_size) so that this timestamp events will not be lost between the calls.
|
||||
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
|
||||
"""
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": {"term": {"task": task_id}},
|
||||
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if from_timestamp:
|
||||
es_req["search_after"] = [from_timestamp]
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_result = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"timestamp": events[-1]["timestamp"]}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
last_second_hits = es_result["hits"]["hits"]
|
||||
if not last_second_hits or len(last_second_hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
already_present_ids = set(hit["_id"] for hit in hits)
|
||||
last_second_events = [
|
||||
hit["_source"]
|
||||
for hit in last_second_hits
|
||||
if hit["_id"] not in already_present_ids
|
||||
]
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[*events, *last_second_events],
|
||||
hits_total,
|
||||
)
|
||||
53
apiserver/bll/event/metric_debug_images_iterator.py
Normal file
53
apiserver/bll/event/metric_debug_images_iterator.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import Sequence, Tuple, Callable
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .event_common import EventType
|
||||
from .metric_events_iterator import MetricEventsIterator, VariantState
|
||||
|
||||
|
||||
class MetricDebugImagesIterator(MetricEventsIterator):
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
super().__init__(redis, es, EventType.metrics_image)
|
||||
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
return [{"exists": {"field": "url"}}]
|
||||
|
||||
def _get_variant_state_aggs(self) -> Tuple[dict, Callable[[dict, VariantState], None]]:
|
||||
aggs = {
|
||||
"urls": {
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "desc"},
|
||||
"size": 1, # we need only one url from the most recent iteration
|
||||
},
|
||||
"aggs": {
|
||||
"max_iter": {"max": {"field": "iter"}},
|
||||
"iters": {
|
||||
"top_hits": {
|
||||
"sort": {"iter": {"order": "desc"}},
|
||||
"size": 2, # need two last iterations so that we can take
|
||||
# the second one as invalid
|
||||
"_source": "iter",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
def fill_variant_state_data(variant_bucket: dict, state: VariantState):
|
||||
"""If the image urls get recycled then fill the last_invalid_iteration field"""
|
||||
top_iter_url = nested_get(variant_bucket, ("urls", "buckets"))[0]
|
||||
iters = nested_get(top_iter_url, ("iters", "hits", "hits"))
|
||||
if len(iters) > 1:
|
||||
state.last_invalid_iteration = nested_get(iters[1], ("_source", "iter"))
|
||||
|
||||
return aggs, fill_variant_state_data
|
||||
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
return event
|
||||
|
||||
def _get_same_variant_events_order(self) -> dict:
|
||||
return {"url": {"order": "desc"}}
|
||||
451
apiserver/bll/event/metric_events_iterator.py
Normal file
451
apiserver/bll/event/metric_events_iterator.py
Normal file
@@ -0,0 +1,451 @@
|
||||
import abc
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping, Callable
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
get_metric_variants_condition,
|
||||
get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.metrics import MetricEventStats
|
||||
from apiserver.database.model.task.task import Task
|
||||
|
||||
|
||||
class VariantState(Base):
|
||||
variant: str = StringField(required=True)
|
||||
last_invalid_iteration: int = IntField()
|
||||
|
||||
|
||||
class MetricState(Base):
|
||||
metric: str = StringField(required=True)
|
||||
variants: Sequence[VariantState] = ListField([VariantState], required=True)
|
||||
timestamp: int = IntField(default=0)
|
||||
|
||||
|
||||
class TaskScrollState(Base):
|
||||
task: str = StringField(required=True)
|
||||
metrics: Sequence[MetricState] = ListField([MetricState], required=True)
|
||||
last_min_iter: Optional[int] = IntField()
|
||||
last_max_iter: Optional[int] = IntField()
|
||||
|
||||
def reset(self):
|
||||
"""Reset the scrolling state for the metric"""
|
||||
self.last_min_iter = self.last_max_iter = None
|
||||
|
||||
|
||||
class MetricEventsScrollState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
|
||||
warning: str = StringField()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class MetricEventsResult(object):
|
||||
metric_events: Sequence[tuple] = []
|
||||
next_scroll_id: str = None
|
||||
|
||||
|
||||
class MetricEventsIterator:
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
|
||||
self.es = es
|
||||
self.event_type = event_type
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=MetricEventsScrollState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
companies: Mapping[str, str],
|
||||
task_metrics: Mapping[str, dict],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> MetricEventsResult:
|
||||
companies = {
|
||||
task_id: company_id
|
||||
for task_id, company_id in companies.items()
|
||||
if not check_empty_data(
|
||||
self.es, company_id=company_id, event_type=self.event_type
|
||||
)
|
||||
}
|
||||
if not companies:
|
||||
return MetricEventsResult()
|
||||
|
||||
def init_state(state_: MetricEventsScrollState):
|
||||
state_.tasks = self._init_task_states(companies, task_metrics)
|
||||
|
||||
def validate_state(state_: MetricEventsScrollState):
|
||||
"""
|
||||
Validate that the metrics stored in the state are the same
|
||||
as requested in the current call.
|
||||
Refresh the state if requested
|
||||
"""
|
||||
if refresh:
|
||||
self._reinit_outdated_task_states(companies, state_, task_metrics)
|
||||
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
) as state:
|
||||
res = MetricEventsResult(next_scroll_id=state.id)
|
||||
specific_variants_requested = any(
|
||||
variants
|
||||
for t, metrics in task_metrics.items()
|
||||
if metrics
|
||||
for m, variants in metrics.items()
|
||||
)
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
res.metric_events = list(
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_task_metric_events,
|
||||
companies=companies,
|
||||
iter_count=iter_count,
|
||||
navigate_earlier=navigate_earlier,
|
||||
specific_variants_requested=specific_variants_requested,
|
||||
),
|
||||
state.tasks,
|
||||
)
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def _reinit_outdated_task_states(
|
||||
self,
|
||||
companies: Mapping[str, str],
|
||||
state: MetricEventsScrollState,
|
||||
task_metrics: Mapping[str, dict],
|
||||
):
|
||||
"""
|
||||
Determine the metrics for which new event_type events were added
|
||||
since their states were initialized and re-init these states
|
||||
"""
|
||||
tasks = Task.objects(id__in=list(task_metrics)).only("id", "metric_stats")
|
||||
|
||||
def get_last_update_times_for_task_metrics(
|
||||
task: Task,
|
||||
) -> Mapping[str, datetime]:
|
||||
"""For metrics that reported event_type events get mapping of the metric name to the last update times"""
|
||||
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
|
||||
if not metric_stats:
|
||||
return {}
|
||||
|
||||
requested_metrics = task_metrics[task.id]
|
||||
return {
|
||||
stats.metric: stats.event_stats_by_type[
|
||||
self.event_type.value
|
||||
].last_update
|
||||
for stats in metric_stats.values()
|
||||
if self.event_type.value in stats.event_stats_by_type
|
||||
and (not requested_metrics or stats.metric in requested_metrics)
|
||||
}
|
||||
|
||||
update_times = {
|
||||
task.id: get_last_update_times_for_task_metrics(task) for task in tasks
|
||||
}
|
||||
task_metric_states = {
|
||||
task_state.task: {
|
||||
metric_state.metric: metric_state for metric_state in task_state.metrics
|
||||
}
|
||||
for task_state in state.tasks
|
||||
}
|
||||
task_metrics_to_recalc = {}
|
||||
for task, metrics_times in update_times.items():
|
||||
old_metric_states = task_metric_states[task]
|
||||
metrics_to_recalc = {
|
||||
m: task_metrics[task].get(m)
|
||||
for m, t in metrics_times.items()
|
||||
if m not in old_metric_states or old_metric_states[m].timestamp < t
|
||||
}
|
||||
if metrics_to_recalc:
|
||||
task_metrics_to_recalc[task] = metrics_to_recalc
|
||||
|
||||
updated_task_states = self._init_task_states(companies, task_metrics_to_recalc)
|
||||
|
||||
def merge_with_updated_task_states(
|
||||
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
|
||||
) -> TaskScrollState:
|
||||
task = old_state.task
|
||||
updated_state = first(uts for uts in updates if uts.task == task)
|
||||
if not updated_state:
|
||||
old_state.reset()
|
||||
return old_state
|
||||
|
||||
updated_metrics = [m.metric for m in updated_state.metrics]
|
||||
return TaskScrollState(
|
||||
task=task,
|
||||
metrics=[
|
||||
*updated_state.metrics,
|
||||
*(
|
||||
old_metric
|
||||
for old_metric in old_state.metrics
|
||||
if old_metric.metric not in updated_metrics
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
state.tasks = [
|
||||
merge_with_updated_task_states(task_state, updated_task_states)
|
||||
for task_state in state.tasks
|
||||
]
|
||||
|
||||
def _init_task_states(
|
||||
self, companies: Mapping[str, str], task_metrics: Mapping[str, dict]
|
||||
) -> Sequence[TaskScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
"""
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
task_metric_states = pool.map(
|
||||
partial(self._init_metric_states_for_task, companies=companies),
|
||||
task_metrics.items(),
|
||||
)
|
||||
|
||||
return [
|
||||
TaskScrollState(task=task, metrics=metric_states,)
|
||||
for task, metric_states in zip(task_metrics, task_metric_states)
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_variant_state_aggs(
|
||||
self,
|
||||
) -> Tuple[dict, Callable[[dict, VariantState], None]]:
|
||||
pass
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, dict], companies: Mapping[str, str]
|
||||
) -> Sequence[MetricState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
for the variants that reported any event_type events
|
||||
"""
|
||||
task, metrics = task_metrics
|
||||
company_id = companies[task]
|
||||
must = [{"term": {"task": task}}, *self._get_extra_conditions()]
|
||||
if metrics:
|
||||
must.append(get_metric_variants_condition(metrics))
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=self.event_type
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
variant_state_aggs, fill_variant_state_data = self._get_variant_state_aggs()
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_event_timestamp": {"max": {"field": "timestamp"}},
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
**(
|
||||
{"aggs": variant_state_aggs}
|
||||
if variant_state_aggs
|
||||
else {}
|
||||
),
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
def init_variant_state(variant: dict):
|
||||
"""
|
||||
Return new variant state for the passed variant bucket
|
||||
"""
|
||||
state = VariantState(variant=variant["key"])
|
||||
if fill_variant_state_data:
|
||||
fill_variant_state_data(variant, state)
|
||||
|
||||
return state
|
||||
|
||||
return [
|
||||
MetricState(
|
||||
metric=metric["key"],
|
||||
timestamp=dpath.get(metric, "last_event_timestamp/value"),
|
||||
variants=[
|
||||
init_variant_state(variant)
|
||||
for variant in dpath.get(metric, "variants/buckets")
|
||||
],
|
||||
)
|
||||
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_same_variant_events_order(self) -> dict:
|
||||
pass
|
||||
|
||||
def _get_task_metric_events(
|
||||
self,
|
||||
task_state: TaskScrollState,
|
||||
companies: Mapping[str, str],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool,
|
||||
specific_variants_requested: bool,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Return task metric events grouped by iterations
|
||||
Update task scroll state
|
||||
"""
|
||||
if not task_state.metrics:
|
||||
return task_state.task, []
|
||||
|
||||
if task_state.last_max_iter is None:
|
||||
# the first fetch is always from the latest iteration to the earlier ones
|
||||
navigate_earlier = True
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task_state.task}},
|
||||
{"terms": {"metric": [m.metric for m in task_state.metrics]}},
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
|
||||
range_condition = None
|
||||
if navigate_earlier and task_state.last_min_iter is not None:
|
||||
range_condition = {"lt": task_state.last_min_iter}
|
||||
elif not navigate_earlier and task_state.last_max_iter is not None:
|
||||
range_condition = {"gt": task_state.last_max_iter}
|
||||
if range_condition:
|
||||
must_conditions.append({"range": {"iter": range_condition}})
|
||||
|
||||
metrics_count = len(task_state.metrics)
|
||||
max_variants = int(EventSettings.max_es_buckets / (metrics_count * iter_count))
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iter_count,
|
||||
"order": {"_key": "desc" if navigate_earlier else "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {
|
||||
"sort": self._get_same_variant_events_order()
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=companies[task_state.task],
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return task_state.task, []
|
||||
|
||||
invalid_iterations = {
|
||||
(m.metric, v.variant): v.last_invalid_iteration
|
||||
for m in task_state.metrics
|
||||
for v in m.variants
|
||||
}
|
||||
allow_uninitialized = (
|
||||
False
|
||||
if specific_variants_requested
|
||||
else config.get(
|
||||
"services.events.events_retrieval.debug_images.allow_uninitialized_variants",
|
||||
False,
|
||||
)
|
||||
)
|
||||
|
||||
def is_valid_event(event: dict) -> bool:
|
||||
key = event.get("metric"), event.get("variant")
|
||||
if key not in invalid_iterations:
|
||||
return allow_uninitialized
|
||||
|
||||
max_invalid = invalid_iterations[key]
|
||||
return max_invalid is None or event.get("iter") > max_invalid
|
||||
|
||||
def get_iteration_events(it_: dict) -> Sequence:
|
||||
return [
|
||||
self._process_event(ev["_source"])
|
||||
for m in dpath.get(it_, "metrics/buckets")
|
||||
for v in dpath.get(m, "variants/buckets")
|
||||
for ev in dpath.get(v, "events/hits/hits")
|
||||
if is_valid_event(ev["_source"])
|
||||
]
|
||||
|
||||
iterations = []
|
||||
for it in dpath.get(es_res, "aggregations/iters/buckets"):
|
||||
events = get_iteration_events(it)
|
||||
if events:
|
||||
iterations.append({"iter": it["key"], "events": events})
|
||||
|
||||
if not navigate_earlier:
|
||||
iterations.sort(key=itemgetter("iter"), reverse=True)
|
||||
if iterations:
|
||||
task_state.last_max_iter = iterations[0]["iter"]
|
||||
task_state.last_min_iter = iterations[-1]["iter"]
|
||||
|
||||
return task_state.task, iterations
|
||||
25
apiserver/bll/event/metric_plots_iterator.py
Normal file
25
apiserver/bll/event/metric_plots_iterator.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Sequence
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from .event_common import EventType, uncompress_plot
|
||||
from .metric_events_iterator import MetricEventsIterator
|
||||
|
||||
|
||||
class MetricPlotsIterator(MetricEventsIterator):
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
super().__init__(redis, es, EventType.metrics_plot)
|
||||
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
return []
|
||||
|
||||
def _get_variant_state_aggs(self):
|
||||
return None, None
|
||||
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
uncompress_plot(event)
|
||||
return event
|
||||
|
||||
def _get_same_variant_events_order(self) -> dict:
|
||||
return {"timestamp": {"order": "desc"}}
|
||||
@@ -4,8 +4,10 @@ Module for polymorphism over different types of X axes in scalar aggregations
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import auto
|
||||
|
||||
from typing import Any
|
||||
|
||||
from apiserver.utilities import extract_properties_to_lists
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
from apiserver.bll.util import extract_properties_to_lists
|
||||
from apiserver.config_repo import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
@@ -96,6 +98,10 @@ class ScalarKey(ABC):
|
||||
"""
|
||||
return int(iter_data[self.bucket_key_key]), iter_data["avg_val"]["value"]
|
||||
|
||||
def cast_value(self, value: Any) -> Any:
|
||||
"""Cast value to appropriate type"""
|
||||
return value
|
||||
|
||||
|
||||
class TimestampKey(ScalarKey):
|
||||
"""
|
||||
@@ -117,6 +123,9 @@ class TimestampKey(ScalarKey):
|
||||
}
|
||||
}
|
||||
|
||||
def cast_value(self, value: Any) -> int:
|
||||
return int(value)
|
||||
|
||||
|
||||
class IterKey(ScalarKey):
|
||||
"""
|
||||
@@ -134,6 +143,9 @@ class IterKey(ScalarKey):
|
||||
}
|
||||
}
|
||||
|
||||
def cast_value(self, value: Any) -> int:
|
||||
return int(value)
|
||||
|
||||
|
||||
class ISOTimeKey(ScalarKey):
|
||||
"""
|
||||
|
||||
@@ -1,18 +1,268 @@
|
||||
from typing import Optional, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Callable, Tuple, Sequence, Dict, Optional
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.models import ModelTaskPublishResponse
|
||||
from apiserver.bll.task.utils import deleted_prefix, get_last_metric_updates
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.utils import get_company_or_none_constraint
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from .metadata import Metadata
|
||||
|
||||
|
||||
class ModelBLL:
|
||||
def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
|
||||
"""
|
||||
Return the list of unique frameworks used by company and public models
|
||||
If project ids passed then only models from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
return Model.objects(query).distinct(field="framework")
|
||||
event_bll = None
|
||||
|
||||
@classmethod
|
||||
def get_company_model_by_id(
|
||||
cls, company_id: str, model_id: str, only_fields=None
|
||||
) -> Model:
|
||||
query = dict(company=company_id, id=model_id)
|
||||
qs = Model.objects(**query)
|
||||
if only_fields:
|
||||
qs = qs.only(*only_fields)
|
||||
model = qs.first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def assert_exists(
|
||||
company_id, model_ids, only=None, allow_public=False, return_models=True,
|
||||
) -> Optional[Sequence[Model]]:
|
||||
model_ids = [model_ids] if isinstance(model_ids, str) else model_ids
|
||||
ids = set(model_ids)
|
||||
query = Q(id__in=ids)
|
||||
|
||||
q = Model.get_many(
|
||||
company=company_id,
|
||||
query=query,
|
||||
allow_public=allow_public,
|
||||
return_dicts=False,
|
||||
)
|
||||
if only:
|
||||
q = q.only(*only)
|
||||
|
||||
if q.count() != len(ids):
|
||||
raise errors.bad_request.InvalidModelId(ids=model_ids)
|
||||
|
||||
if return_models:
|
||||
return list(q)
|
||||
|
||||
@classmethod
|
||||
def publish_model(
|
||||
cls,
|
||||
model_id: str,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
force_publish_task: bool = False,
|
||||
publish_task_func: Callable[[str, str, Identity, bool], dict] = None,
|
||||
) -> Tuple[int, ModelTaskPublishResponse]:
|
||||
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
if model.ready:
|
||||
raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
|
||||
|
||||
user_id = identity.user
|
||||
published_task = None
|
||||
if model.task and publish_task_func:
|
||||
task = (
|
||||
Task.objects(id=model.task, company=company_id)
|
||||
.only("id", "status")
|
||||
.first()
|
||||
)
|
||||
if task and task.status != TaskStatus.published:
|
||||
task_publish_res = publish_task_func(
|
||||
model.task, company_id, identity, force_publish_task
|
||||
)
|
||||
published_task = ModelTaskPublishResponse(
|
||||
id=model.task, data=task_publish_res
|
||||
)
|
||||
|
||||
now = datetime.utcnow()
|
||||
updated = model.update(
|
||||
upsert=False,
|
||||
ready=True,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by=user_id,
|
||||
)
|
||||
return updated, published_task
|
||||
|
||||
@classmethod
|
||||
def delete_model(
|
||||
cls, model_id: str, company_id: str, user_id: str, force: bool, delete_external_artifacts: bool = True,
|
||||
) -> Tuple[int, Model]:
|
||||
model = cls.get_company_model_by_id(
|
||||
company_id=company_id,
|
||||
model_id=model_id,
|
||||
only_fields=("id", "task", "project", "uri"),
|
||||
)
|
||||
deleted_model_id = f"{deleted_prefix}{model_id}"
|
||||
|
||||
using_tasks = Task.objects(models__input__model=model_id).only("id")
|
||||
if using_tasks:
|
||||
if not force:
|
||||
raise errors.bad_request.ModelInUse(
|
||||
"as execution model, use force=True to delete",
|
||||
num_tasks=len(using_tasks),
|
||||
)
|
||||
# update deleted model id in using tasks
|
||||
Task._get_collection().update_many(
|
||||
filter={"_id": {"$in": [t.id for t in using_tasks]}},
|
||||
update={"$set": {"models.input.$[elem].model": deleted_model_id}},
|
||||
array_filters=[{"elem.model": model_id}],
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
if model.task:
|
||||
task = Task.objects(id=model.task).first()
|
||||
if task:
|
||||
now = datetime.utcnow()
|
||||
if task.status == TaskStatus.published:
|
||||
if not force:
|
||||
raise errors.bad_request.ModelCreatingTaskExists(
|
||||
"and published, use force=True to delete", task=model.task
|
||||
)
|
||||
Task._get_collection().update_one(
|
||||
filter={"_id": model.task, "models.output.model": model_id},
|
||||
update={
|
||||
"$set": {
|
||||
"models.output.$[elem].model": deleted_model_id,
|
||||
"output.error": f"model deleted on {now.isoformat()}",
|
||||
"last_change": now,
|
||||
"last_changed_by": user_id,
|
||||
},
|
||||
},
|
||||
array_filters=[{"elem.model": model_id}],
|
||||
upsert=False,
|
||||
)
|
||||
else:
|
||||
task.update(
|
||||
pull__models__output__model=model_id,
|
||||
set__last_change=now,
|
||||
set__last_changed_by=user_id,
|
||||
)
|
||||
|
||||
delete_external_artifacts = delete_external_artifacts and config.get(
|
||||
"services.async_urls_delete.enabled", True
|
||||
)
|
||||
if delete_external_artifacts:
|
||||
from apiserver.bll.task.task_cleanup import (
|
||||
collect_debug_image_urls,
|
||||
collect_plot_image_urls,
|
||||
_schedule_for_delete,
|
||||
)
|
||||
urls = set()
|
||||
urls.update(collect_debug_image_urls(company_id, model_id))
|
||||
urls.update(collect_plot_image_urls(company_id, model_id))
|
||||
if model.uri:
|
||||
urls.add(model.uri)
|
||||
if urls:
|
||||
_schedule_for_delete(
|
||||
task_id=model_id,
|
||||
company=company_id,
|
||||
user=user_id,
|
||||
urls=urls,
|
||||
can_delete_folders=False,
|
||||
)
|
||||
|
||||
if not cls.event_bll:
|
||||
from apiserver.bll.event import EventBLL
|
||||
cls.event_bll = EventBLL()
|
||||
|
||||
cls.event_bll.delete_task_events(company_id, model_id, allow_locked=True, model=True)
|
||||
del_count = Model.objects(id=model_id, company=company_id).delete()
|
||||
return del_count, model
|
||||
|
||||
@classmethod
|
||||
def archive_model(cls, model_id: str, company_id: str, user_id: str):
|
||||
cls.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id, only_fields=("id",)
|
||||
)
|
||||
now = datetime.utcnow()
|
||||
archived = Model.objects(company=company_id, id=model_id).update(
|
||||
add_to_set__system_tags=EntityVisibility.archived.value,
|
||||
last_change=now,
|
||||
last_changed_by=user_id,
|
||||
)
|
||||
|
||||
return archived
|
||||
|
||||
@classmethod
|
||||
def unarchive_model(cls, model_id: str, company_id: str, user_id: str):
|
||||
cls.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id, only_fields=("id",)
|
||||
)
|
||||
now = datetime.utcnow()
|
||||
unarchived = Model.objects(company=company_id, id=model_id).update(
|
||||
pull__system_tags=EntityVisibility.archived.value,
|
||||
last_change=now,
|
||||
last_changed_by=user_id,
|
||||
)
|
||||
|
||||
return unarchived
|
||||
|
||||
@classmethod
|
||||
def get_model_stats(
|
||||
cls, company: str, model_ids: Sequence[str],
|
||||
) -> Dict[str, dict]:
|
||||
if not model_ids:
|
||||
return {}
|
||||
|
||||
result = Model.aggregate(
|
||||
[
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"_id": {"$in": model_ids},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$addFields": {
|
||||
"labels_count": {"$size": {"$objectToArray": "$labels"}}
|
||||
}
|
||||
},
|
||||
{"$project": {"labels_count": 1}},
|
||||
]
|
||||
)
|
||||
return {r.pop("_id"): r for r in result}
|
||||
|
||||
@staticmethod
|
||||
def update_statistics(
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
model_id: str,
|
||||
last_update: datetime = None,
|
||||
last_iteration_max: int = None,
|
||||
last_scalar_events: Dict[str, Dict[str, dict]] = None,
|
||||
):
|
||||
last_update = last_update or datetime.utcnow()
|
||||
updates = {
|
||||
"last_update": datetime.utcnow(),
|
||||
"last_change": last_update,
|
||||
"last_changed_by": user_id,
|
||||
}
|
||||
if last_iteration_max is not None:
|
||||
updates.update(max__last_iteration=last_iteration_max)
|
||||
|
||||
raw_updates = {}
|
||||
if last_scalar_events is not None:
|
||||
raw_updates = {}
|
||||
if last_scalar_events is not None:
|
||||
get_last_metric_updates(
|
||||
task_id=model_id,
|
||||
last_scalar_events=last_scalar_events,
|
||||
raw_updates=raw_updates,
|
||||
extra_updates=updates,
|
||||
model_events=True,
|
||||
)
|
||||
|
||||
ret = Model.objects(id=model_id).update_one(**updates)
|
||||
if ret and raw_updates:
|
||||
Model.objects(id=model_id).update_one(__raw__=[{"$set": raw_updates}])
|
||||
|
||||
return ret
|
||||
|
||||
107
apiserver/bll/model/metadata.py
Normal file
107
apiserver/bll/model/metadata.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from typing import Sequence, Union, Mapping
|
||||
|
||||
from mongoengine import Document
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.metadata import MetadataItem
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class Metadata:
|
||||
@staticmethod
|
||||
def metadata_from_api(
|
||||
api_data: Union[Mapping[str, MetadataItem], Sequence[MetadataItem]]
|
||||
) -> dict:
|
||||
if not api_data:
|
||||
return {}
|
||||
|
||||
if isinstance(api_data, dict):
|
||||
return {
|
||||
ParameterKeyEscaper.escape(k): v.to_struct()
|
||||
for k, v in api_data.items()
|
||||
}
|
||||
|
||||
return {
|
||||
ParameterKeyEscaper.escape(item.key): item.to_struct() for item in api_data
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_metadata(
|
||||
cls,
|
||||
obj: Document,
|
||||
items: Sequence[MetadataItem],
|
||||
replace_metadata: bool,
|
||||
**more_updates,
|
||||
) -> int:
|
||||
update_cmds = dict()
|
||||
metadata = cls.metadata_from_api(items)
|
||||
if replace_metadata:
|
||||
update_cmds["set__metadata"] = metadata
|
||||
else:
|
||||
for key, value in metadata.items():
|
||||
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
|
||||
|
||||
return obj.update(**update_cmds, **more_updates)
|
||||
|
||||
@classmethod
|
||||
def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
|
||||
return obj.update(
|
||||
**{
|
||||
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
|
||||
for key in set(keys)
|
||||
},
|
||||
**more_updates,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_path(path: str):
|
||||
"""
|
||||
Frontend does a partial escaping on the path so the all '.' in key names are escaped
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
raise errors.bad_request.ValidationError("invalid field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def escape_paths(cls, paths: Sequence[str]) -> Sequence[str]:
|
||||
for prefix in (
|
||||
"metadata.",
|
||||
"-metadata.",
|
||||
):
|
||||
paths = [
|
||||
cls._process_path(path) if path.startswith(prefix) else path
|
||||
for path in paths
|
||||
]
|
||||
return paths
|
||||
|
||||
@classmethod
|
||||
def escape_query_parameters(cls, call_data: dict) -> dict:
|
||||
if not call_data:
|
||||
return call_data
|
||||
|
||||
keys = list(call_data)
|
||||
call_data = {
|
||||
safe_key: call_data[key]
|
||||
for key, safe_key in zip(keys, Metadata.escape_paths(keys))
|
||||
}
|
||||
|
||||
projection = GetMixin.get_projection(call_data)
|
||||
if projection:
|
||||
GetMixin.set_projection(call_data, Metadata.escape_paths(projection))
|
||||
|
||||
ordering = GetMixin.get_ordering(call_data)
|
||||
if ordering:
|
||||
GetMixin.set_ordering(call_data, Metadata.escape_paths(ordering))
|
||||
|
||||
return call_data
|
||||
@@ -1,12 +1,11 @@
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Dict, Optional
|
||||
|
||||
from mongoengine import Q
|
||||
from typing import Sequence, Dict, Type
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import AttributedDocument
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
@@ -26,6 +25,51 @@ class OrgBLL:
|
||||
self._task_tags = _TagsCache(Task, self.redis)
|
||||
self._model_tags = _TagsCache(Model, self.redis)
|
||||
|
||||
def edit_entity_tags(
|
||||
self,
|
||||
company_id,
|
||||
entity_cls: Type[AttributedDocument],
|
||||
entity_ids: Sequence[str],
|
||||
add_tags: Sequence[str],
|
||||
remove_tags: Sequence[str],
|
||||
) -> int:
|
||||
if entity_cls not in (Task, Model):
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Tags editing can be called on tasks or models only"
|
||||
)
|
||||
if not entity_ids:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"No entity ids provided for editing tags"
|
||||
)
|
||||
if not (add_tags or remove_tags):
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Either add tags or remove tags should be provided"
|
||||
)
|
||||
|
||||
updated = 0
|
||||
if add_tags:
|
||||
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
|
||||
add_to_set__tags=add_tags
|
||||
)
|
||||
if remove_tags:
|
||||
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
|
||||
pull_all__tags=remove_tags
|
||||
)
|
||||
if not updated:
|
||||
return 0
|
||||
|
||||
projects = entity_cls.objects(company=company_id, id__in=entity_ids).distinct(
|
||||
"project"
|
||||
)
|
||||
update_project_time(project_ids=projects)
|
||||
self.update_tags(
|
||||
company_id,
|
||||
entity=Tags.Task if entity_cls is Task else Tags.Model,
|
||||
projects=projects,
|
||||
tags=add_tags or remove_tags
|
||||
)
|
||||
return updated
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company_id: str,
|
||||
@@ -54,10 +98,10 @@ class OrgBLL:
|
||||
return ret
|
||||
|
||||
def update_tags(
|
||||
self, company_id: str, entity: Tags, project: str, tags=None, system_tags=None,
|
||||
self, company_id: str, entity: Tags, projects: Sequence[str], tags=None, system_tags=None,
|
||||
):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.update_tags(company_id, project, tags, system_tags)
|
||||
tags_cache.update_tags(company_id, projects, tags, system_tags)
|
||||
|
||||
def reset_tags(self, company_id: str, entity: Tags, projects: Sequence[str]):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
@@ -65,34 +109,3 @@ class OrgBLL:
|
||||
|
||||
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
|
||||
return self._task_tags if entity == Tags.Task else self._model_tags
|
||||
|
||||
@classmethod
|
||||
def get_parent_tasks(
|
||||
cls,
|
||||
company_id: str,
|
||||
projects: Sequence[str],
|
||||
state: Optional[EntityVisibility] = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Get list of unique parent tasks sorted by task name for the passed company projects
|
||||
If projects is None or empty then get parents for all the company tasks
|
||||
"""
|
||||
query = Q(company=company_id)
|
||||
if projects:
|
||||
query &= Q(project__in=projects)
|
||||
if state == EntityVisibility.archived:
|
||||
query &= Q(system_tags__in=[EntityVisibility.archived.value])
|
||||
elif state == EntityVisibility.active:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.archived.value])
|
||||
|
||||
parent_ids = set(Task.objects(query).distinct("parent"))
|
||||
if not parent_ids:
|
||||
return []
|
||||
|
||||
parents = Task.get_many_with_join(
|
||||
company_id,
|
||||
query=Q(id__in=parent_ids),
|
||||
allow_public=True,
|
||||
override_projection=("id", "name", "project.name"),
|
||||
)
|
||||
return sorted(parents, key=itemgetter("name"))
|
||||
|
||||
@@ -5,6 +5,8 @@ from mongoengine import Q
|
||||
from redis import Redis
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.bll.project import project_ids_with_children
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
@@ -40,7 +42,9 @@ class _TagsCache:
|
||||
if vals:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
if project:
|
||||
query &= Q(project=project)
|
||||
query &= Q(project__in=project_ids_with_children([project]))
|
||||
else:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
|
||||
|
||||
return self.db_cls.objects(query).distinct(field)
|
||||
|
||||
@@ -103,7 +107,7 @@ class _TagsCache:
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(self, company_id: str, project: str, tags=None, system_tags=None):
|
||||
def update_tags(self, company_id: str, projects: Sequence[str], tags=None, system_tags=None):
|
||||
"""
|
||||
Updates tags. If reset is set then both tags and system_tags
|
||||
are recalculated. Otherwise only those that are not 'None'
|
||||
@@ -119,7 +123,7 @@ class _TagsCache:
|
||||
if not fields:
|
||||
return
|
||||
|
||||
self._delete_redis_keys(company_id, projects=[project], fields=fields)
|
||||
self._delete_redis_keys(company_id, projects=projects, fields=fields)
|
||||
|
||||
def reset_tags(self, company_id: str, projects: Sequence[str]):
|
||||
self._delete_redis_keys(
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
from .project_bll import ProjectBLL
|
||||
from .project_queries import ProjectQueries
|
||||
from .sub_projects import _ids_with_children as project_ids_with_children
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
327
apiserver/bll/project/project_cleanup.py
Normal file
327
apiserver/bll/project/project_cleanup.py
Normal file
@@ -0,0 +1,327 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Tuple, Set, Sequence
|
||||
|
||||
import attr
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.task.task_cleanup import (
|
||||
collect_debug_image_urls,
|
||||
collect_plot_image_urls,
|
||||
TaskUrls,
|
||||
_schedule_for_delete,
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType, TaskStatus
|
||||
from .project_bll import (
|
||||
ProjectBLL,
|
||||
pipeline_tag,
|
||||
pipelines_project_name,
|
||||
dataset_tag,
|
||||
datasets_project_name,
|
||||
reports_tag,
|
||||
)
|
||||
from .sub_projects import _ids_with_children
|
||||
|
||||
log = config.logger(__file__)
|
||||
event_bll = EventBLL()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class DeleteProjectResult:
|
||||
deleted: int = 0
|
||||
disassociated_tasks: int = 0
|
||||
deleted_models: int = 0
|
||||
deleted_tasks: int = 0
|
||||
urls: TaskUrls = None
|
||||
|
||||
|
||||
def _get_child_project_ids(
|
||||
project_id: str,
|
||||
) -> Tuple[Sequence[str], Sequence[str], Sequence[str]]:
|
||||
project_ids = _ids_with_children([project_id])
|
||||
pipeline_ids = list(
|
||||
Project.objects(
|
||||
id__in=project_ids,
|
||||
system_tags__in=[pipeline_tag],
|
||||
basename__ne=pipelines_project_name,
|
||||
).scalar("id")
|
||||
)
|
||||
dataset_ids = list(
|
||||
Project.objects(
|
||||
id__in=project_ids,
|
||||
system_tags__in=[dataset_tag],
|
||||
basename__ne=datasets_project_name,
|
||||
).scalar("id")
|
||||
)
|
||||
return project_ids, pipeline_ids, dataset_ids
|
||||
|
||||
|
||||
def validate_project_delete(company: str, project_id: str):
|
||||
project = Project.get_for_writing(
|
||||
company=company, id=project_id, _only=("id", "path", "system_tags")
|
||||
)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
project_ids, pipeline_ids, dataset_ids = _get_child_project_ids(project_id)
|
||||
ret = {}
|
||||
if pipeline_ids:
|
||||
pipelines_with_active_controllers = Task.objects(
|
||||
project__in=pipeline_ids,
|
||||
type=TaskType.controller,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).distinct("project")
|
||||
ret["pipelines"] = len(pipelines_with_active_controllers)
|
||||
else:
|
||||
ret["pipelines"] = 0
|
||||
if dataset_ids:
|
||||
datasets_with_data = Task.objects(
|
||||
project__in=dataset_ids,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).distinct("project")
|
||||
ret["datasets"] = len(datasets_with_data)
|
||||
else:
|
||||
ret["datasets"] = 0
|
||||
|
||||
project_ids = list(set(project_ids) - set(pipeline_ids) - set(dataset_ids))
|
||||
if project_ids:
|
||||
in_project_query = Q(project__in=project_ids)
|
||||
for cls in (Task, Model):
|
||||
query = (
|
||||
in_project_query & Q(system_tags__nin=[reports_tag])
|
||||
if cls is Task
|
||||
else in_project_query
|
||||
)
|
||||
ret[f"{cls.__name__.lower()}s"] = cls.objects(query).count()
|
||||
ret[f"non_archived_{cls.__name__.lower()}s"] = cls.objects(
|
||||
query & Q(system_tags__nin=[EntityVisibility.archived.value])
|
||||
).count()
|
||||
ret["reports"] = Task.objects(
|
||||
in_project_query & Q(system_tags__in=[reports_tag])
|
||||
).count()
|
||||
ret["non_archived_reports"] = Task.objects(
|
||||
in_project_query
|
||||
& Q(
|
||||
system_tags__in=[reports_tag],
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
)
|
||||
).count()
|
||||
else:
|
||||
for cls in (Task, Model):
|
||||
ret[f"{cls.__name__.lower()}s"] = 0
|
||||
ret[f"non_archived_{cls.__name__.lower()}s"] = 0
|
||||
ret["reports"] = 0
|
||||
ret["non_archived_reports"] = 0
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def delete_project(
|
||||
company: str,
|
||||
user: str,
|
||||
project_id: str,
|
||||
force: bool,
|
||||
delete_contents: bool,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[DeleteProjectResult, Set[str]]:
|
||||
project = Project.get_for_writing(
|
||||
company=company, id=project_id, _only=("id", "path", "system_tags")
|
||||
)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
delete_external_artifacts = delete_external_artifacts and config.get(
|
||||
"services.async_urls_delete.enabled", True
|
||||
)
|
||||
project_ids, pipeline_ids, dataset_ids = _get_child_project_ids(project_id)
|
||||
if not force:
|
||||
if pipeline_ids:
|
||||
active_controllers = Task.objects(
|
||||
project__in=pipeline_ids,
|
||||
type=TaskType.controller,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).only("id")
|
||||
if active_controllers:
|
||||
raise errors.bad_request.ProjectHasPipelines(
|
||||
"please archive all the controllers or use force=true",
|
||||
id=project_id,
|
||||
)
|
||||
if dataset_ids:
|
||||
datasets_with_data = Task.objects(
|
||||
project__in=dataset_ids,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).only("id")
|
||||
if datasets_with_data:
|
||||
raise errors.bad_request.ProjectHasDatasets(
|
||||
"please delete all the dataset versions or use force=true",
|
||||
id=project_id,
|
||||
)
|
||||
|
||||
regular_projects = list(set(project_ids) - set(pipeline_ids) - set(dataset_ids))
|
||||
if regular_projects:
|
||||
for cls, error in (
|
||||
(Task, errors.bad_request.ProjectHasTasks),
|
||||
(Model, errors.bad_request.ProjectHasModels),
|
||||
):
|
||||
non_archived = cls.objects(
|
||||
project__in=regular_projects,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).only("id")
|
||||
if non_archived:
|
||||
raise error("use force=true", id=project_id)
|
||||
|
||||
if not delete_contents:
|
||||
disassociated = defaultdict(int)
|
||||
for cls in ProjectBLL.child_classes:
|
||||
disassociated[cls] = cls.objects(project__in=project_ids).update(
|
||||
project=None
|
||||
)
|
||||
res = DeleteProjectResult(disassociated_tasks=disassociated[Task])
|
||||
else:
|
||||
deleted_models, model_event_urls, model_urls = _delete_models(
|
||||
company=company, user=user, projects=project_ids
|
||||
)
|
||||
deleted_tasks, task_event_urls, artifact_urls = _delete_tasks(
|
||||
company=company, user=user, projects=project_ids
|
||||
)
|
||||
event_urls = task_event_urls | model_event_urls
|
||||
if delete_external_artifacts:
|
||||
scheduled = _schedule_for_delete(
|
||||
task_id=project_id,
|
||||
company=company,
|
||||
user=user,
|
||||
urls=event_urls | model_urls | artifact_urls,
|
||||
can_delete_folders=True,
|
||||
)
|
||||
for urls in (event_urls, model_urls, artifact_urls):
|
||||
urls.difference_update(scheduled)
|
||||
res = DeleteProjectResult(
|
||||
deleted_tasks=deleted_tasks,
|
||||
deleted_models=deleted_models,
|
||||
urls=TaskUrls(
|
||||
model_urls=list(model_urls),
|
||||
event_urls=list(event_urls),
|
||||
artifact_urls=list(artifact_urls),
|
||||
),
|
||||
)
|
||||
|
||||
affected = {*project_ids, *(project.path or [])}
|
||||
res.deleted = Project.objects(id__in=project_ids).delete()
|
||||
|
||||
return res, affected
|
||||
|
||||
|
||||
def _delete_tasks(
|
||||
company: str, user: str, projects: Sequence[str]
|
||||
) -> Tuple[int, Set, Set]:
|
||||
"""
|
||||
Delete only the task themselves and their non published version.
|
||||
Child models under the same project are deleted separately.
|
||||
Children tasks should be deleted in the same api call.
|
||||
If any child entities are left in another projects then updated their parent task to None
|
||||
"""
|
||||
tasks = Task.objects(project__in=projects).only("id", "execution__artifacts")
|
||||
if not tasks:
|
||||
return 0, set(), set()
|
||||
|
||||
task_ids = list({t.id for t in tasks})
|
||||
now = datetime.utcnow()
|
||||
Task.objects(parent__in=task_ids, project__nin=projects).update(
|
||||
parent=None,
|
||||
last_change=now,
|
||||
last_changed_by=user,
|
||||
)
|
||||
Model.objects(task__in=task_ids, project__nin=projects).update(
|
||||
task=None,
|
||||
last_change=now,
|
||||
last_changed_by=user,
|
||||
)
|
||||
|
||||
event_urls = collect_debug_image_urls(company, task_ids) | collect_plot_image_urls(
|
||||
company, task_ids
|
||||
)
|
||||
artifact_urls = set()
|
||||
for task in tasks:
|
||||
if task.execution and task.execution.artifacts:
|
||||
artifact_urls.update(
|
||||
{
|
||||
a.uri
|
||||
for a in task.execution.artifacts.values()
|
||||
if a.mode == ArtifactModes.output and a.uri
|
||||
}
|
||||
)
|
||||
|
||||
event_bll.delete_multi_task_events(company, task_ids)
|
||||
deleted = tasks.delete()
|
||||
return deleted, event_urls, artifact_urls
|
||||
|
||||
|
||||
def _delete_models(
|
||||
company: str, user: str, projects: Sequence[str]
|
||||
) -> Tuple[int, Set[str], Set[str]]:
|
||||
"""
|
||||
Delete project models and update the tasks from other projects
|
||||
that reference them to reference None.
|
||||
"""
|
||||
models = Model.objects(project__in=projects).only("task", "id", "uri")
|
||||
if not models:
|
||||
return 0, set(), set()
|
||||
|
||||
model_ids = list({m.id for m in models})
|
||||
deleted = "__DELETED__"
|
||||
Task._get_collection().update_many(
|
||||
filter={
|
||||
"project": {"$nin": projects},
|
||||
"models.input.model": {"$in": model_ids},
|
||||
},
|
||||
update={"$set": {"models.input.$[elem].model": deleted}},
|
||||
array_filters=[{"elem.model": {"$in": model_ids}}],
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
model_tasks = list({m.task for m in models if m.task})
|
||||
if model_tasks:
|
||||
now = datetime.utcnow()
|
||||
# update published tasks
|
||||
Task._get_collection().update_many(
|
||||
filter={
|
||||
"_id": {"$in": model_tasks},
|
||||
"project": {"$nin": projects},
|
||||
"models.output.model": {"$in": model_ids},
|
||||
"status": TaskStatus.published,
|
||||
},
|
||||
update={
|
||||
"$set": {
|
||||
"models.output.$[elem].model": deleted,
|
||||
"last_change": now,
|
||||
"last_changed_by": user,
|
||||
}
|
||||
},
|
||||
array_filters=[{"elem.model": {"$in": model_ids}}],
|
||||
upsert=False,
|
||||
)
|
||||
# update unpublished tasks
|
||||
Task.objects(
|
||||
id__in=model_tasks,
|
||||
project__nin=projects,
|
||||
status__ne=TaskStatus.published,
|
||||
).update(
|
||||
pull__models__output__model__in=model_ids,
|
||||
set__last_change=now,
|
||||
set__last_changed_by=user,
|
||||
)
|
||||
|
||||
event_urls = collect_debug_image_urls(company, model_ids) | collect_plot_image_urls(
|
||||
company, model_ids
|
||||
)
|
||||
model_urls = {m.uri for m in models if m.uri}
|
||||
|
||||
event_bll.delete_multi_task_events(company, model_ids, model=True)
|
||||
deleted = models.delete()
|
||||
return deleted, event_urls, model_urls
|
||||
407
apiserver/bll/project/project_queries.py
Normal file
407
apiserver/bll/project/project_queries.py
Normal file
@@ -0,0 +1,407 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Sequence,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .sub_projects import _ids_with_children
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ProjectQueries:
|
||||
def __init__(self, redis=None):
|
||||
self.redis: StrictRedis = redis or redman.connection("apiserver")
|
||||
|
||||
@staticmethod
|
||||
def _get_project_constraint(
|
||||
project_ids: Sequence[str], include_subprojects: bool
|
||||
) -> dict:
|
||||
"""
|
||||
If passed projects is None means top level projects
|
||||
If passed projects is empty means no project filtering
|
||||
"""
|
||||
if include_subprojects:
|
||||
if not project_ids:
|
||||
return {}
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
|
||||
if project_ids is None:
|
||||
project_ids = [None]
|
||||
if not project_ids:
|
||||
return {}
|
||||
|
||||
return {"project": {"$in": project_ids}}
|
||||
|
||||
@staticmethod
|
||||
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
|
||||
if allow_public:
|
||||
return {"company": {"$in": [None, "", company_id]}}
|
||||
|
||||
return {"company": company_id}
|
||||
|
||||
@classmethod
|
||||
def get_aggregated_project_parameters(
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
}
|
||||
},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
nested_get(r, ("_id", "section"))
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(
|
||||
nested_get(r, ("_id", "name"))
|
||||
),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
ParamValues = Tuple[int, Sequence[str]]
|
||||
|
||||
def _get_cached_param_values(
|
||||
self, key: str, last_update: datetime, allowed_delta_sec=0
|
||||
) -> Optional[ParamValues]:
|
||||
try:
|
||||
cached = self.redis.get(key)
|
||||
if not cached:
|
||||
return
|
||||
|
||||
data = json.loads(cached)
|
||||
cached_last_update = datetime.fromtimestamp(data["last_update"])
|
||||
if (last_update - cached_last_update).total_seconds() <= allowed_delta_sec:
|
||||
return data["total"], data["values"]
|
||||
except Exception as ex:
|
||||
log.error(f"Error retrieving params cached values: {str(ex)}")
|
||||
|
||||
def get_task_hyperparam_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
section: str,
|
||||
name: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
pattern: str = None,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> ParamValues:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
company_constraint = self._get_company_constraint(company_id, allow_public)
|
||||
project_constraint = self._get_project_constraint(
|
||||
project_ids, include_subprojects
|
||||
)
|
||||
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
|
||||
last_updated_task = (
|
||||
Task.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_task:
|
||||
return 0, []
|
||||
|
||||
redis_key = "_".join(
|
||||
str(part)
|
||||
for part in (
|
||||
"hyperparam_values",
|
||||
company_id,
|
||||
"_".join(project_ids),
|
||||
section,
|
||||
name,
|
||||
allow_public,
|
||||
pattern,
|
||||
page,
|
||||
page_size,
|
||||
)
|
||||
)
|
||||
last_update = last_updated_task.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_param_values(
|
||||
key=redis_key,
|
||||
last_update=last_update,
|
||||
allowed_delta_sec=config.get(
|
||||
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
|
||||
),
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
match_condition = {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
if pattern:
|
||||
match_condition["$expr"] = {
|
||||
"$regexMatch": {
|
||||
"input": f"${key_path}.value",
|
||||
"regex": pattern,
|
||||
"options": "i",
|
||||
}
|
||||
}
|
||||
|
||||
pipeline = [
|
||||
{"$match": match_condition},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
|
||||
@classmethod
|
||||
def get_unique_metric_variants(
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
ids: Sequence[str],
|
||||
model_metrics: bool = False,
|
||||
):
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
**({"_id": {"$in": ids}} if ids else {}),
|
||||
}
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
{"$unwind": "$metrics"},
|
||||
{
|
||||
"$project": {
|
||||
"metric": "$metrics.k",
|
||||
"variants": {"$objectToArray": "$metrics.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$variants"},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"metric": "$variants.v.metric",
|
||||
"variant": "$variants.v.variant",
|
||||
},
|
||||
"metrics": {
|
||||
"$addToSet": {
|
||||
"metric": "$variants.v.metric",
|
||||
"metric_hash": "$metric",
|
||||
"variant": "$variants.v.variant",
|
||||
"variant_hash": "$variants.k",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
|
||||
]
|
||||
|
||||
entity_cls = Model if model_metrics else Task
|
||||
result = entity_cls.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@classmethod
|
||||
def get_model_metadata_keys(
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
"metadata": {"$exists": True, "$gt": {}},
|
||||
}
|
||||
},
|
||||
{"$project": {"metadata": {"$objectToArray": "$metadata"}}},
|
||||
{"$unwind": "$metadata"},
|
||||
{"$group": {"_id": "$metadata.k"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Model.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
ParameterKeyEscaper.unescape(r.get("_id"))
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
def get_model_metadata_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
key: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> ParamValues:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
company_constraint = self._get_company_constraint(company_id, allow_public)
|
||||
project_constraint = self._get_project_constraint(
|
||||
project_ids, include_subprojects
|
||||
)
|
||||
key_path = f"metadata.{ParameterKeyEscaper.escape(key)}"
|
||||
last_updated_model = (
|
||||
Model.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_model:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}_{page}_{page_size}"
|
||||
last_update = last_updated_model.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_param_values(
|
||||
key=redis_key, last_update=last_update
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Model.aggregate(pipeline, collation=Model._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.models.metadata_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
198
apiserver/bll/project/sub_projects.py
Normal file
198
apiserver/bll/project/sub_projects.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import itertools
|
||||
from datetime import datetime
|
||||
from typing import Tuple, Optional, Sequence, Mapping
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.project import Project
|
||||
|
||||
name_separator = "/"
|
||||
|
||||
|
||||
def _get_project_depth(project_name: str) -> int:
|
||||
return len(list(filter(None, project_name.split(name_separator))))
|
||||
|
||||
|
||||
def _validate_project_name(project_name: str, raise_if_empty=True) -> Tuple[str, str]:
|
||||
"""
|
||||
Remove redundant '/' characters. Ensure that the project name is not empty
|
||||
Return the cleaned up project name and location
|
||||
"""
|
||||
name_parts = [p.strip() for p in project_name.split(name_separator) if p]
|
||||
if not name_parts:
|
||||
if raise_if_empty:
|
||||
raise errors.bad_request.InvalidProjectName(name=project_name)
|
||||
return "", ""
|
||||
|
||||
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
|
||||
|
||||
|
||||
def _ensure_project(
|
||||
company: str, user: str, name: str, creation_params: dict = None
|
||||
) -> Optional[Project]:
|
||||
"""
|
||||
Makes sure that the project with the given name exists
|
||||
If needed auto-create the project and all the missing projects in the path to it
|
||||
Return the project
|
||||
"""
|
||||
name, location = _validate_project_name(name, raise_if_empty=False)
|
||||
if not name:
|
||||
return None
|
||||
|
||||
project = _get_writable_project_from_name(company, name)
|
||||
if project:
|
||||
return project
|
||||
|
||||
now = datetime.utcnow()
|
||||
project = Project(
|
||||
id=database.utils.id(),
|
||||
user=user,
|
||||
company=company,
|
||||
created=now,
|
||||
last_update=now,
|
||||
name=name,
|
||||
basename=name.split("/")[-1],
|
||||
**(creation_params or dict(description="")),
|
||||
)
|
||||
parent = _ensure_project(company, user, location, creation_params=creation_params)
|
||||
_save_under_parent(project=project, parent=parent)
|
||||
if parent:
|
||||
parent.update(last_update=now)
|
||||
|
||||
return project
|
||||
|
||||
|
||||
def _save_under_parent(project: Project, parent: Optional[Project]):
|
||||
"""
|
||||
Save the project under the given parent project or top level (parent=None)
|
||||
Check that the project location matches the parent name
|
||||
"""
|
||||
location, _, _ = project.name.rpartition(name_separator)
|
||||
if not parent:
|
||||
if location:
|
||||
raise ValueError(
|
||||
f"Project location {location} does not match empty parent name"
|
||||
)
|
||||
project.parent = None
|
||||
project.path = []
|
||||
project.save()
|
||||
return
|
||||
|
||||
if location != parent.name:
|
||||
raise ValueError(
|
||||
f"Project location {location} does not match parent name {parent.name}"
|
||||
)
|
||||
project.parent = parent.id
|
||||
project.path = [*(parent.path or []), parent.id]
|
||||
project.save()
|
||||
|
||||
|
||||
def _get_writable_project_from_name(
|
||||
company,
|
||||
name,
|
||||
_only: Optional[Sequence[str]] = ("id", "name", "path", "company", "parent"),
|
||||
) -> Optional[Project]:
|
||||
"""
|
||||
Return a project from name. If the project not found then return None
|
||||
"""
|
||||
qs = Project.objects(company=company, name=name)
|
||||
if _only:
|
||||
qs = qs.only(*_only)
|
||||
return qs.first()
|
||||
|
||||
|
||||
ProjectsChildren = Mapping[str, Sequence[Project]]
|
||||
|
||||
|
||||
def _get_sub_projects(
|
||||
project_ids: Sequence[str],
|
||||
_only: Sequence[str] = ("id", "path"),
|
||||
search_hidden=True,
|
||||
allowed_ids: Sequence[str] = None,
|
||||
) -> ProjectsChildren:
|
||||
"""
|
||||
Return the list of child projects of all the levels for the parent project ids
|
||||
"""
|
||||
query = dict(path__in=project_ids)
|
||||
if not search_hidden:
|
||||
query["system_tags__nin"] = [EntityVisibility.hidden.value]
|
||||
if allowed_ids:
|
||||
query["id__in"] = allowed_ids
|
||||
|
||||
qs = Project.objects(**query)
|
||||
if _only:
|
||||
_only = set(_only) | {"path"}
|
||||
qs = qs.only(*_only)
|
||||
subprojects = list(qs)
|
||||
|
||||
return {
|
||||
pid: [s for s in subprojects if pid in (s.path or [])] for pid in project_ids
|
||||
}
|
||||
|
||||
|
||||
def _ids_with_parents(project_ids: Sequence[str]) -> Sequence[str]:
|
||||
"""
|
||||
Return project ids with all the parent projects
|
||||
"""
|
||||
projects = Project.objects(id__in=project_ids).only("id", "path")
|
||||
parent_ids = set(itertools.chain.from_iterable(p.path for p in projects if p.path))
|
||||
return list({*(p.id for p in projects), *parent_ids})
|
||||
|
||||
|
||||
def _ids_with_children(project_ids: Sequence[str]) -> Sequence[str]:
|
||||
"""
|
||||
Return project ids with the ids of all the subprojects
|
||||
"""
|
||||
children_ids = Project.objects(path__in=project_ids).scalar("id")
|
||||
return list({*project_ids, *children_ids})
|
||||
|
||||
|
||||
def _update_subproject_names(
|
||||
project: Project,
|
||||
children: Sequence[Project],
|
||||
old_name: str,
|
||||
update_path: bool = False,
|
||||
old_path: Sequence[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Update sub project names when the base project name changes
|
||||
Optionally update the paths
|
||||
"""
|
||||
updated = 0
|
||||
now = datetime.utcnow()
|
||||
for child in children:
|
||||
child_suffix = name_separator.join(
|
||||
child.name.split(name_separator)[len(old_name.split(name_separator)):]
|
||||
)
|
||||
updates = {
|
||||
"name": name_separator.join((project.name, child_suffix)),
|
||||
"last_update": now,
|
||||
}
|
||||
if update_path:
|
||||
updates["path"] = project.path + child.path[len(old_path):]
|
||||
updated += child.update(upsert=False, **updates)
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
def _reposition_project_with_children(
|
||||
project: Project, children: Sequence[Project], parent: Project
|
||||
) -> int:
|
||||
new_location = parent.name if parent else None
|
||||
old_name = project.name
|
||||
old_path = project.path
|
||||
project.name = name_separator.join(
|
||||
filter(None, (new_location, project.name.split(name_separator)[-1]))
|
||||
)
|
||||
project.last_update = datetime.utcnow()
|
||||
_save_under_parent(project, parent=parent)
|
||||
|
||||
moved = 1 + _update_subproject_names(
|
||||
project=project,
|
||||
children=children,
|
||||
old_name=old_name,
|
||||
update_path=True,
|
||||
old_path=old_path,
|
||||
)
|
||||
return moved
|
||||
@@ -1,10 +1,12 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Callable, Sequence, Optional, Tuple
|
||||
from typing import Sequence, Optional, Tuple, Union
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.queue.queue_metrics import QueueMetrics
|
||||
@@ -14,6 +16,8 @@ from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.queue import Queue, Entry
|
||||
|
||||
log = config.logger(__file__)
|
||||
MOVE_FIRST = "first"
|
||||
MOVE_LAST = "last"
|
||||
|
||||
|
||||
class QueueBLL(object):
|
||||
@@ -32,6 +36,7 @@ class QueueBLL(object):
|
||||
name: str,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> Queue:
|
||||
"""Creates a queue"""
|
||||
with translate_errors_context():
|
||||
@@ -43,13 +48,31 @@ class QueueBLL(object):
|
||||
name=name,
|
||||
tags=tags or [],
|
||||
system_tags=system_tags or [],
|
||||
metadata=metadata,
|
||||
last_update=now,
|
||||
)
|
||||
queue.save()
|
||||
return queue
|
||||
|
||||
def get_by_name(
|
||||
self, company_id: str, queue_name: str, only: Optional[Sequence[str]] = None,
|
||||
) -> Queue:
|
||||
qs = Queue.objects(name=queue_name, company=company_id)
|
||||
if only:
|
||||
qs = qs.only(*only)
|
||||
|
||||
return qs.first()
|
||||
|
||||
@staticmethod
|
||||
def _get_task_entries_projection(max_task_entries: int) -> dict:
|
||||
return dict(slice__entries=max_task_entries)
|
||||
|
||||
def get_by_id(
|
||||
self, company_id: str, queue_id: str, only: Optional[Sequence[str]] = None
|
||||
self,
|
||||
company_id: str,
|
||||
queue_id: str,
|
||||
only: Optional[Sequence[str]] = None,
|
||||
max_task_entries: int = None,
|
||||
) -> Queue:
|
||||
"""
|
||||
Get queue by id
|
||||
@@ -60,6 +83,8 @@ class QueueBLL(object):
|
||||
qs = Queue.objects(**query)
|
||||
if only:
|
||||
qs = qs.only(*only)
|
||||
if max_task_entries:
|
||||
qs = qs.fields(**self._get_task_entries_projection(max_task_entries))
|
||||
queue = qs.first()
|
||||
if not queue:
|
||||
raise errors.bad_request.InvalidQueueId(**query)
|
||||
@@ -110,7 +135,7 @@ class QueueBLL(object):
|
||||
self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
return Queue.safe_update(company_id, queue_id, update_fields)
|
||||
|
||||
def delete(self, company_id: str, queue_id: str, force: bool) -> None:
|
||||
def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> None:
|
||||
"""
|
||||
Delete the queue
|
||||
:raise errors.bad_request.InvalidQueueId: if the queue is not found
|
||||
@@ -118,20 +143,80 @@ class QueueBLL(object):
|
||||
"""
|
||||
with translate_errors_context():
|
||||
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
|
||||
if queue.entries and not force:
|
||||
raise errors.bad_request.QueueNotEmpty(
|
||||
"use force=true to delete", id=queue_id
|
||||
)
|
||||
if queue.entries:
|
||||
if not force:
|
||||
raise errors.bad_request.QueueNotEmpty(
|
||||
"use force=true to delete", id=queue_id
|
||||
)
|
||||
from apiserver.bll.task import ChangeStatusRequest
|
||||
|
||||
for item in queue.entries:
|
||||
try:
|
||||
task = Task.get(
|
||||
company=company_id,
|
||||
id=item.task,
|
||||
_only=[
|
||||
"id",
|
||||
"company",
|
||||
"status",
|
||||
"enqueue_status",
|
||||
"project",
|
||||
],
|
||||
)
|
||||
if not task:
|
||||
continue
|
||||
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=task.enqueue_status or TaskStatus.created,
|
||||
status_reason="Queue was deleted",
|
||||
status_message="",
|
||||
user_id=user_id,
|
||||
force=True,
|
||||
).execute(enqueue_status=None)
|
||||
except Exception as ex:
|
||||
log.exception(
|
||||
f"Failed dequeuing task {item.task} from queue: {queue_id}"
|
||||
)
|
||||
|
||||
queue.delete()
|
||||
|
||||
def get_all(self, company_id: str, query_dict: dict) -> Sequence[dict]:
|
||||
def get_all(
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
query: Q = None,
|
||||
max_task_entries: int = None,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""Get all the queues according to the query"""
|
||||
with translate_errors_context():
|
||||
return Queue.get_many(
|
||||
company=company_id, parameters=query_dict, query_dict=query_dict
|
||||
company=company_id,
|
||||
parameters=query_dict,
|
||||
query_dict=query_dict,
|
||||
query=query,
|
||||
projection_fields=self._get_task_entries_projection(max_task_entries)
|
||||
if max_task_entries
|
||||
else None,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
def get_queue_infos(self, company_id: str, query_dict: dict) -> Sequence[dict]:
|
||||
def check_for_workers(self, company_id: str, queue_id: str) -> bool:
|
||||
for worker in self.worker_bll.get_all(company_id):
|
||||
if queue_id in worker.queues:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_queue_infos(
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
query: Q = None,
|
||||
max_task_entries: int = None,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Get infos on all the company queues, including queue tasks and workers
|
||||
"""
|
||||
@@ -140,7 +225,12 @@ class QueueBLL(object):
|
||||
res = Queue.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=query_dict,
|
||||
query=query,
|
||||
override_projection=projection,
|
||||
projection_fields=self._get_task_entries_projection(max_task_entries)
|
||||
if max_task_entries
|
||||
else None,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
queue_workers = defaultdict(list)
|
||||
@@ -153,6 +243,7 @@ class QueueBLL(object):
|
||||
{
|
||||
"name": w.id,
|
||||
"ip": w.ip,
|
||||
"key": w.key,
|
||||
"task": w.task.to_struct() if w.task else None,
|
||||
}
|
||||
for w in queue_workers.get(item["id"], [])
|
||||
@@ -171,13 +262,15 @@ class QueueBLL(object):
|
||||
if any(e.task == task_id for e in queue.entries):
|
||||
raise errors.bad_request.TaskAlreadyQueued(task=task_id)
|
||||
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
entry = Entry(added=datetime.utcnow(), task=task_id)
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
|
||||
push__entries=entry, last_update=datetime.utcnow(), upsert=False
|
||||
)
|
||||
|
||||
queue.reload()
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
if not res:
|
||||
raise errors.bad_request.InvalidQueueOrTaskNotQueued(
|
||||
task=task_id, **query
|
||||
@@ -185,16 +278,22 @@ class QueueBLL(object):
|
||||
|
||||
return res
|
||||
|
||||
def get_next_task(self, company_id: str, queue_id: str) -> Optional[Entry]:
|
||||
def get_next_task(
|
||||
self, company_id: str, queue_id: str, task_id: str = None
|
||||
) -> Optional[Entry]:
|
||||
"""
|
||||
Atomically pop and return the first task from the queue (or None)
|
||||
:raise errors.bad_request.InvalidQueueId: if the queue does not exist
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
queue = Queue.objects(**query).modify(pop__entries=-1, upsert=False)
|
||||
queue = Queue.objects(
|
||||
**query, **({"entries__0__task": task_id} if task_id else {})
|
||||
).modify(pop__entries=-1, upsert=False)
|
||||
if not queue:
|
||||
raise errors.bad_request.InvalidQueueId(**query)
|
||||
if not task_id or not Queue.objects(**query).first():
|
||||
raise errors.bad_request.InvalidQueueId(**query)
|
||||
return
|
||||
|
||||
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
|
||||
|
||||
@@ -217,7 +316,6 @@ class QueueBLL(object):
|
||||
queue = self.get_queue_with_task(
|
||||
company_id=company_id, queue_id=queue_id, task_id=task_id
|
||||
)
|
||||
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
|
||||
|
||||
entries_to_remove = [e for e in queue.entries if e.task == task_id]
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
@@ -225,46 +323,153 @@ class QueueBLL(object):
|
||||
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
|
||||
)
|
||||
|
||||
queue.reload()
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
return len(entries_to_remove) if res else 0
|
||||
|
||||
def reposition_task(
|
||||
self,
|
||||
company_id: str,
|
||||
queue_id: str,
|
||||
task_id: str,
|
||||
pos_func: Callable[[int], int],
|
||||
self, company_id: str, queue_id: str, task_id: str, move_count: Union[int, str],
|
||||
) -> int:
|
||||
"""
|
||||
Moves the task in the queue to the position calculated by pos_func
|
||||
Returns the updated task position in the queue
|
||||
"""
|
||||
with translate_errors_context():
|
||||
queue = self.get_queue_with_task(
|
||||
|
||||
def get_queue_and_task_position():
|
||||
q = self.get_queue_with_task(
|
||||
company_id=company_id, queue_id=queue_id, task_id=task_id
|
||||
)
|
||||
return q, next(i for i, e in enumerate(q.entries) if e.task == task_id)
|
||||
|
||||
position = next(i for i, e in enumerate(queue.entries) if e.task == task_id)
|
||||
new_position = pos_func(position)
|
||||
with translate_errors_context():
|
||||
queue, position = get_queue_and_task_position()
|
||||
if move_count == MOVE_FIRST:
|
||||
new_position = 0
|
||||
elif move_count == MOVE_LAST:
|
||||
new_position = len(queue.entries) - 1
|
||||
else:
|
||||
new_position = position + move_count
|
||||
if new_position == position:
|
||||
return new_position
|
||||
|
||||
if new_position != position:
|
||||
entry = queue.entries[position]
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
updated = Queue.objects(entries__task=task_id, **query).update_one(
|
||||
pull__entries=entry, last_update=datetime.utcnow()
|
||||
)
|
||||
if not updated:
|
||||
raise errors.bad_request.RemovedDuringReposition(
|
||||
task=task_id, **query
|
||||
)
|
||||
inst = {"$push": {"entries": {"$each": [entry.to_proper_dict()]}}}
|
||||
if new_position >= 0:
|
||||
inst["$push"]["entries"]["$position"] = new_position
|
||||
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
|
||||
__raw__=inst
|
||||
)
|
||||
if not res:
|
||||
raise errors.bad_request.FailedAddingDuringReposition(
|
||||
task=task_id, **query
|
||||
)
|
||||
without_entry = {
|
||||
"$filter": {
|
||||
"input": "$entries",
|
||||
"as": "entry",
|
||||
"cond": {"$ne": ["$$entry.task", task_id]},
|
||||
}
|
||||
}
|
||||
task_entry = {
|
||||
"$filter": {
|
||||
"input": "$entries",
|
||||
"as": "entry",
|
||||
"cond": {"$eq": ["$$entry.task", task_id]},
|
||||
}
|
||||
}
|
||||
if move_count == MOVE_FIRST:
|
||||
operations = [
|
||||
{
|
||||
"$set": {
|
||||
"entries": {"$concatArrays": [task_entry, without_entry]}
|
||||
}
|
||||
}
|
||||
]
|
||||
elif move_count == MOVE_LAST:
|
||||
operations = [
|
||||
{
|
||||
"$set": {
|
||||
"entries": {"$concatArrays": [without_entry, task_entry]}
|
||||
}
|
||||
}
|
||||
]
|
||||
else:
|
||||
operations = [
|
||||
{
|
||||
"$set": {
|
||||
"new_pos": {
|
||||
"$add": [
|
||||
{"$indexOfArray": ["$entries.task", task_id]},
|
||||
move_count,
|
||||
]
|
||||
},
|
||||
"without_entry": without_entry,
|
||||
"task_entry": task_entry,
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"entries": {
|
||||
"$switch": {
|
||||
"branches": [
|
||||
{
|
||||
"case": {"$lte": ["$new_pos", 0]},
|
||||
"then": {
|
||||
"$concatArrays": [
|
||||
"$task_entry",
|
||||
"$without_entry",
|
||||
]
|
||||
},
|
||||
},
|
||||
{
|
||||
"case": {
|
||||
"$gte": [
|
||||
"$new_pos",
|
||||
{"$size": "$without_entry"},
|
||||
]
|
||||
},
|
||||
"then": {
|
||||
"$concatArrays": [
|
||||
"$without_entry",
|
||||
"$task_entry",
|
||||
]
|
||||
},
|
||||
},
|
||||
],
|
||||
"default": {
|
||||
"$concatArrays": [
|
||||
{"$slice": ["$without_entry", "$new_pos"]},
|
||||
"$task_entry",
|
||||
{
|
||||
"$slice": [
|
||||
"$without_entry",
|
||||
"$new_pos",
|
||||
{"$size": "$without_entry"},
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{"$unset": ["new_pos", "without_entry", "task_entry"]},
|
||||
]
|
||||
|
||||
return new_position
|
||||
updated = Queue.objects(
|
||||
id=queue_id, company=company_id, entries__task=task_id
|
||||
).update_one(__raw__=operations)
|
||||
|
||||
if not updated:
|
||||
raise errors.bad_request.FailedAddingDuringReposition(task=task_id)
|
||||
|
||||
return get_queue_and_task_position()[1]
|
||||
|
||||
def count_entries(self, company: str, queue_id: str) -> Optional[int]:
|
||||
res = next(
|
||||
Queue.aggregate(
|
||||
[
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"_id": queue_id,
|
||||
}
|
||||
},
|
||||
{"$project": {"count": {"$size": "$entries"}}},
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if res is None:
|
||||
raise errors.bad_request.InvalidQueueId(queue_id=queue_id)
|
||||
return int(res.get("count"))
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from time import sleep
|
||||
from typing import Sequence
|
||||
|
||||
import elasticsearch.helpers
|
||||
from boltons.typeutils import classproperty
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.es_factory import es_factory
|
||||
@@ -11,25 +13,30 @@ from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.queue import Queue, Entry
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__file__)
|
||||
_conf = config.get("services.queues")
|
||||
_queue_metrics_key_pattern = "queue_metrics_{queue}"
|
||||
redis = redman.connection("apiserver")
|
||||
|
||||
|
||||
class EsKeys:
|
||||
WAITING_TIME_FIELD = "average_waiting_time"
|
||||
QUEUE_LENGTH_FIELD = "queue_length"
|
||||
TIMESTAMP_FIELD = "timestamp"
|
||||
QUEUE_FIELD = "queue"
|
||||
|
||||
|
||||
class QueueMetrics:
|
||||
class EsKeys:
|
||||
WAITING_TIME_FIELD = "average_waiting_time"
|
||||
QUEUE_LENGTH_FIELD = "queue_length"
|
||||
TIMESTAMP_FIELD = "timestamp"
|
||||
QUEUE_FIELD = "queue"
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
@staticmethod
|
||||
def _queue_metrics_prefix_for_company(company_id: str) -> str:
|
||||
"""Returns the es index prefix for the company"""
|
||||
return f"queue_metrics_{company_id}_"
|
||||
return f"queue_metrics_{company_id.lower()}_"
|
||||
|
||||
@staticmethod
|
||||
def _get_es_index_suffix():
|
||||
@@ -49,7 +56,7 @@ class QueueMetrics:
|
||||
total_waiting_in_secs = sum((now - e.added).total_seconds() for e in entries)
|
||||
return total_waiting_in_secs / len(entries)
|
||||
|
||||
def log_queue_metrics_to_es(self, company_id: str, queues: Sequence[Queue]) -> bool:
|
||||
def log_queue_metrics_to_es(self, company_id: str, queues: Sequence[Queue]) -> int:
|
||||
"""
|
||||
Calculate and write queue statistics (avg waiting time and queue length) to Elastic
|
||||
:return: True if the write to es was successful, false otherwise
|
||||
@@ -63,23 +70,22 @@ class QueueMetrics:
|
||||
|
||||
def make_doc(queue: Queue) -> dict:
|
||||
entries = [e for e in queue.entries if e.added]
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_source={
|
||||
self.EsKeys.TIMESTAMP_FIELD: timestamp,
|
||||
self.EsKeys.QUEUE_FIELD: queue.id,
|
||||
self.EsKeys.WAITING_TIME_FIELD: self._calc_avg_waiting_time(
|
||||
entries
|
||||
),
|
||||
self.EsKeys.QUEUE_LENGTH_FIELD: len(entries),
|
||||
},
|
||||
)
|
||||
return {
|
||||
EsKeys.TIMESTAMP_FIELD: timestamp,
|
||||
EsKeys.QUEUE_FIELD: queue.id,
|
||||
EsKeys.WAITING_TIME_FIELD: self._calc_avg_waiting_time(entries),
|
||||
EsKeys.QUEUE_LENGTH_FIELD: len(entries),
|
||||
}
|
||||
|
||||
actions = list(map(make_doc, queues))
|
||||
logged = 0
|
||||
for q in queues:
|
||||
queue_doc = make_doc(q)
|
||||
self.es.index(index=es_index, document=queue_doc)
|
||||
redis_key = _queue_metrics_key_pattern.format(queue=q.id)
|
||||
redis.set(redis_key, json.dumps(queue_doc))
|
||||
logged += 1
|
||||
|
||||
es_res = elasticsearch.helpers.bulk(self.es, actions)
|
||||
added, errors = es_res[:2]
|
||||
return (added == len(actions)) and not errors
|
||||
return logged
|
||||
|
||||
def _log_current_metrics(self, company_id: str, queue_ids=Sequence[str]):
|
||||
query = dict(company=company_id)
|
||||
@@ -90,8 +96,7 @@ class QueueMetrics:
|
||||
|
||||
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
|
||||
body=es_req,
|
||||
index=f"{self._queue_metrics_prefix_for_company(company_id)}*", body=es_req,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -105,13 +110,13 @@ class QueueMetrics:
|
||||
return {
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": cls.EsKeys.TIMESTAMP_FIELD,
|
||||
"field": EsKeys.TIMESTAMP_FIELD,
|
||||
"fixed_interval": f"{interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
"queues": {
|
||||
"terms": {"field": cls.EsKeys.QUEUE_FIELD},
|
||||
"terms": {"field": EsKeys.QUEUE_FIELD},
|
||||
"aggs": cls._get_top_waiting_agg(),
|
||||
}
|
||||
},
|
||||
@@ -128,13 +133,13 @@ class QueueMetrics:
|
||||
"top_avg_waiting": {
|
||||
"top_hits": {
|
||||
"sort": [
|
||||
{cls.EsKeys.WAITING_TIME_FIELD: {"order": "desc"}},
|
||||
{cls.EsKeys.QUEUE_LENGTH_FIELD: {"order": "desc"}},
|
||||
{EsKeys.WAITING_TIME_FIELD: {"order": "desc"}},
|
||||
{EsKeys.QUEUE_LENGTH_FIELD: {"order": "desc"}},
|
||||
],
|
||||
"_source": {
|
||||
"includes": [
|
||||
cls.EsKeys.WAITING_TIME_FIELD,
|
||||
cls.EsKeys.QUEUE_LENGTH_FIELD,
|
||||
EsKeys.WAITING_TIME_FIELD,
|
||||
EsKeys.QUEUE_LENGTH_FIELD,
|
||||
]
|
||||
},
|
||||
"size": 1,
|
||||
@@ -149,6 +154,7 @@ class QueueMetrics:
|
||||
to_date: float,
|
||||
interval: int,
|
||||
queue_ids: Sequence[str],
|
||||
refresh: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Get the company queue metrics in the specified time range.
|
||||
@@ -158,7 +164,8 @@ class QueueMetrics:
|
||||
In case no queue ids are specified the avg across all the
|
||||
company queues is calculated for each metric
|
||||
"""
|
||||
# self._log_current_metrics(company, queue_ids=queue_ids)
|
||||
if refresh:
|
||||
self._log_current_metrics(company_id, queue_ids=queue_ids)
|
||||
|
||||
if from_date >= to_date:
|
||||
raise bad_request.FieldsValueError("from_date must be less than to_date")
|
||||
@@ -174,7 +181,7 @@ class QueueMetrics:
|
||||
"aggs": self._get_dates_agg(interval),
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_queue_metrics"):
|
||||
with translate_errors_context():
|
||||
res = self._search_company_metrics(company_id, es_req)
|
||||
|
||||
if "aggregations" not in res:
|
||||
@@ -256,7 +263,52 @@ class QueueMetrics:
|
||||
continue
|
||||
res = queue_data["top_avg_waiting"]["hits"]["hits"][0]["_source"]
|
||||
queue_metrics[queue_data["key"]] = {
|
||||
"queue_length": res[cls.EsKeys.QUEUE_LENGTH_FIELD],
|
||||
"avg_waiting_time": res[cls.EsKeys.WAITING_TIME_FIELD],
|
||||
"queue_length": res[EsKeys.QUEUE_LENGTH_FIELD],
|
||||
"avg_waiting_time": res[EsKeys.WAITING_TIME_FIELD],
|
||||
}
|
||||
return queue_metrics
|
||||
|
||||
|
||||
class MetricsRefresher:
|
||||
threads = ThreadsManager()
|
||||
|
||||
@classproperty
|
||||
def watch_interval_sec(self):
|
||||
return _conf.get("metrics_refresh_interval_sec", 300)
|
||||
|
||||
@classmethod
|
||||
@threads.register("queue_metrics_refresh_watchdog", daemon=True)
|
||||
def start(cls, queue_metrics: QueueMetrics = None):
|
||||
if not cls.watch_interval_sec:
|
||||
return
|
||||
|
||||
if not queue_metrics:
|
||||
from .queue_bll import QueueBLL
|
||||
|
||||
queue_metrics = QueueBLL().metrics
|
||||
|
||||
sleep(10)
|
||||
while True:
|
||||
try:
|
||||
for queue in Queue.objects():
|
||||
timestamp = es_factory.get_timestamp_millis()
|
||||
doc_time = 0
|
||||
try:
|
||||
redis_key = _queue_metrics_key_pattern.format(queue=queue.id)
|
||||
data = redis.get(redis_key)
|
||||
if data:
|
||||
queue_doc = json.loads(data)
|
||||
doc_time = int(queue_doc.get(EsKeys.TIMESTAMP_FIELD))
|
||||
except Exception as ex:
|
||||
log.exception(
|
||||
f"Error reading queue metrics data for queue {queue.id}: {str(ex)}"
|
||||
)
|
||||
|
||||
if (
|
||||
not doc_time
|
||||
or (timestamp - doc_time) > cls.watch_interval_sec * 1000
|
||||
):
|
||||
queue_metrics.log_queue_metrics_to_es(queue.company, [queue])
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed collecting queue metrics: {str(ex)}")
|
||||
sleep(60)
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Optional, TypeVar, Generic, Type, Callable
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -31,24 +30,36 @@ class RedisCacheManager(Generic[T]):
|
||||
|
||||
def set_state(self, state: T) -> None:
|
||||
redis_key = self._get_redis_key(state.id)
|
||||
with TimingContext("redis", "cache_set_state"):
|
||||
self.redis.set(redis_key, state.to_json())
|
||||
self.redis.expire(redis_key, self.expiration_interval)
|
||||
self.redis.set(redis_key, state.to_json())
|
||||
self.redis.expire(redis_key, self.expiration_interval)
|
||||
|
||||
def get_state(self, state_id) -> Optional[T]:
|
||||
redis_key = self._get_redis_key(state_id)
|
||||
with TimingContext("redis", "cache_get_state"):
|
||||
response = self.redis.get(redis_key)
|
||||
response = self.redis.get(redis_key)
|
||||
if response:
|
||||
return self.state_class.from_json(response)
|
||||
|
||||
def delete_state(self, state_id) -> None:
|
||||
with TimingContext("redis", "cache_delete_state"):
|
||||
self.redis.delete(self._get_redis_key(state_id))
|
||||
self.redis.delete(self._get_redis_key(state_id))
|
||||
|
||||
def _get_redis_key(self, state_id):
|
||||
return f"{self.state_class}/{state_id}"
|
||||
|
||||
def get_or_create_state_core(
|
||||
self,
|
||||
state_id=None,
|
||||
init_state: Callable[[T], None] = _do_nothing,
|
||||
validate_state: Callable[[T], None] = _do_nothing,
|
||||
) -> T:
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
|
||||
return state
|
||||
|
||||
@contextmanager
|
||||
def get_or_create_state(
|
||||
self,
|
||||
@@ -66,12 +77,9 @@ class RedisCacheManager(Generic[T]):
|
||||
:param validate_state: user callback to validate the state if retrieved from cache
|
||||
Should throw an exception if the state is not valid. If not passed then no validation is done
|
||||
"""
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
state = self.get_or_create_state_core(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
)
|
||||
|
||||
try:
|
||||
yield state
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
import operator
|
||||
from threading import Thread, Lock
|
||||
from threading import Lock
|
||||
from time import sleep
|
||||
|
||||
import attr
|
||||
@@ -9,76 +9,83 @@ import psutil
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
|
||||
class ResourceMonitor(Thread):
|
||||
@attr.s(auto_attribs=True)
|
||||
class Sample:
|
||||
cpu_usage: float = 0.0
|
||||
mem_used_gb: float = 0
|
||||
mem_free_gb: float = 0
|
||||
stat_threads = ThreadsManager("Statistics")
|
||||
|
||||
@classmethod
|
||||
def _apply(cls, op, *samples):
|
||||
return cls(
|
||||
**{
|
||||
field: op(*(getattr(sample, field) for sample in samples))
|
||||
for field in attr.fields_dict(cls)
|
||||
}
|
||||
)
|
||||
|
||||
def min(self, sample):
|
||||
return self._apply(min, self, sample)
|
||||
|
||||
def max(self, sample):
|
||||
return self._apply(max, self, sample)
|
||||
|
||||
def avg(self, sample, count):
|
||||
res = self._apply(lambda x: x * count, self)
|
||||
res = self._apply(operator.add, res, sample)
|
||||
res = self._apply(lambda x: x / (count + 1), res)
|
||||
return res
|
||||
|
||||
def __init__(self, sample_interval_sec=5):
|
||||
super(ResourceMonitor, self).__init__(daemon=True)
|
||||
self.sample_interval_sec = sample_interval_sec
|
||||
self._lock = Lock()
|
||||
self._clear()
|
||||
|
||||
def _clear(self):
|
||||
sample = self._get_sample()
|
||||
self._avg = sample
|
||||
self._min = sample
|
||||
self._max = sample
|
||||
self._clear_time = datetime.utcnow()
|
||||
self._count = 1
|
||||
@attr.s(auto_attribs=True)
|
||||
class Sample:
|
||||
cpu_usage: float = 0.0
|
||||
mem_used_gb: float = 0
|
||||
mem_free_gb: float = 0
|
||||
|
||||
@classmethod
|
||||
def _get_sample(cls) -> Sample:
|
||||
return cls.Sample(
|
||||
def _apply(cls, op, *samples):
|
||||
return cls(
|
||||
**{
|
||||
field: op(*(getattr(sample, field) for sample in samples))
|
||||
for field in attr.fields_dict(cls)
|
||||
}
|
||||
)
|
||||
|
||||
def min(self, sample):
|
||||
return self._apply(min, self, sample)
|
||||
|
||||
def max(self, sample):
|
||||
return self._apply(max, self, sample)
|
||||
|
||||
def avg(self, sample, count):
|
||||
res = self._apply(lambda x: x * count, self)
|
||||
res = self._apply(operator.add, res, sample)
|
||||
res = self._apply(lambda x: x / (count + 1), res)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_current_sample(cls) -> "Sample":
|
||||
return cls(
|
||||
cpu_usage=psutil.cpu_percent(),
|
||||
mem_used_gb=psutil.virtual_memory().used / (1024 ** 3),
|
||||
mem_free_gb=psutil.virtual_memory().free / (1024 ** 3),
|
||||
)
|
||||
|
||||
def run(self):
|
||||
while not ThreadsManager.terminating:
|
||||
sleep(self.sample_interval_sec)
|
||||
|
||||
sample = self._get_sample()
|
||||
class ResourceMonitor:
|
||||
class Accumulator:
|
||||
def __init__(self):
|
||||
sample = Sample.get_current_sample()
|
||||
self.avg = sample
|
||||
self.min = sample
|
||||
self.max = sample
|
||||
self.time = datetime.utcnow()
|
||||
self.count = 1
|
||||
|
||||
with self._lock:
|
||||
self._min = self._min.min(sample)
|
||||
self._max = self._max.max(sample)
|
||||
self._avg = self._avg.avg(sample, self._count)
|
||||
self._count += 1
|
||||
def add_sample(self, sample: Sample):
|
||||
self.min = self.min.min(sample)
|
||||
self.max = self.max.max(sample)
|
||||
self.avg = self.avg.avg(sample, self.count)
|
||||
self.count += 1
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
sample_interval_sec = 5
|
||||
_lock = Lock()
|
||||
accumulator = Accumulator()
|
||||
|
||||
@classmethod
|
||||
@stat_threads.register("resource_monitor", daemon=True)
|
||||
def start(cls):
|
||||
while True:
|
||||
sleep(cls.sample_interval_sec)
|
||||
sample = Sample.get_current_sample()
|
||||
with cls._lock:
|
||||
cls.accumulator.add_sample(sample)
|
||||
|
||||
@classmethod
|
||||
def get_stats(cls) -> dict:
|
||||
""" Returns current resource statistics and clears internal resource statistics """
|
||||
with self._lock:
|
||||
min_ = attr.asdict(self._min)
|
||||
max_ = attr.asdict(self._max)
|
||||
avg = attr.asdict(self._avg)
|
||||
interval = datetime.utcnow() - self._clear_time
|
||||
self._clear()
|
||||
with cls._lock:
|
||||
min_ = attr.asdict(cls.accumulator.min)
|
||||
max_ = attr.asdict(cls.accumulator.max)
|
||||
avg = attr.asdict(cls.accumulator.avg)
|
||||
interval = datetime.utcnow() - cls.accumulator.time
|
||||
cls.accumulator = cls.Accumulator()
|
||||
|
||||
return {
|
||||
"interval_sec": interval.total_seconds(),
|
||||
|
||||
@@ -8,8 +8,7 @@ from typing import Sequence, Optional
|
||||
|
||||
import dpath
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
from requests.adapters import HTTPAdapter, Retry
|
||||
|
||||
from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.bll.util import get_server_uuid
|
||||
@@ -21,9 +20,8 @@ from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.json import dumps
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
from apiserver.version import __version__ as current_version
|
||||
from .resource_monitor import ResourceMonitor
|
||||
from .resource_monitor import ResourceMonitor, stat_threads
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -31,21 +29,23 @@ worker_bll = WorkerBLL()
|
||||
|
||||
|
||||
class StatisticsReporter:
|
||||
threads = ThreadsManager("Statistics", resource_monitor=ResourceMonitor)
|
||||
send_queue = queue.Queue()
|
||||
supported = config.get("apiserver.statistics.supported", True)
|
||||
|
||||
@classmethod
|
||||
def start(cls):
|
||||
if not cls.supported:
|
||||
return
|
||||
ResourceMonitor.start()
|
||||
cls.start_sender()
|
||||
cls.start_reporter()
|
||||
|
||||
@classmethod
|
||||
@threads.register("reporter", daemon=True)
|
||||
@stat_threads.register("reporter", daemon=True)
|
||||
def start_reporter(cls):
|
||||
"""
|
||||
Periodically send statistics reports for companies who have opted in.
|
||||
Note: in trains we usually have only a single company
|
||||
Note: in clearml we usually have only a single company
|
||||
"""
|
||||
if not cls.supported:
|
||||
return
|
||||
@@ -54,7 +54,7 @@ class StatisticsReporter:
|
||||
hours=config.get("apiserver.statistics.report_interval_hours", 24)
|
||||
)
|
||||
sleep(report_interval.total_seconds())
|
||||
while not ThreadsManager.terminating:
|
||||
while True:
|
||||
try:
|
||||
for company in Company.objects(
|
||||
defaults__stats_option__enabled=True
|
||||
@@ -68,7 +68,7 @@ class StatisticsReporter:
|
||||
sleep(report_interval.total_seconds())
|
||||
|
||||
@classmethod
|
||||
@threads.register("sender", daemon=True)
|
||||
@stat_threads.register("sender", daemon=True)
|
||||
def start_sender(cls):
|
||||
if not cls.supported:
|
||||
return
|
||||
@@ -85,7 +85,7 @@ class StatisticsReporter:
|
||||
|
||||
WarningFilter.attach()
|
||||
|
||||
while not ThreadsManager.terminating:
|
||||
while True:
|
||||
try:
|
||||
report = cls.send_queue.get()
|
||||
|
||||
@@ -111,7 +111,7 @@ class StatisticsReporter:
|
||||
"uuid": get_server_uuid(),
|
||||
"queues": {"count": Queue.objects(company=company_id).count()},
|
||||
"users": {"count": User.objects(company=company_id).count()},
|
||||
"resources": cls.threads.resource_monitor.get_stats(),
|
||||
"resources": ResourceMonitor.get_stats(),
|
||||
"experiments": next(
|
||||
iter(cls._get_experiments_stats(company_id).values()), {}
|
||||
),
|
||||
@@ -254,6 +254,14 @@ class StatisticsReporter:
|
||||
**({"last_worker": {"$in": workers}} if workers else {}),
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"last_worker": 1,
|
||||
"last_update": 1,
|
||||
"started": 1,
|
||||
"last_iteration": 1,
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$last_worker" if workers else None,
|
||||
|
||||
48
apiserver/bll/storage/__init__.py
Normal file
48
apiserver/bll/storage/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from copy import copy
|
||||
|
||||
from boltons.cacheutils import cachedproperty
|
||||
from clearml.backend_config.bucket_config import (
|
||||
S3BucketConfigurations,
|
||||
AzureContainerConfigurations,
|
||||
GSBucketConfigurations,
|
||||
)
|
||||
|
||||
from apiserver.config_repo import config
|
||||
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class StorageBLL:
|
||||
default_aws_configs: S3BucketConfigurations = None
|
||||
conf = config.get("services.storage_credentials")
|
||||
|
||||
@cachedproperty
|
||||
def _default_aws_configs(self) -> S3BucketConfigurations:
|
||||
return S3BucketConfigurations.from_config(self.conf.get("aws.s3"))
|
||||
|
||||
@cachedproperty
|
||||
def _default_azure_configs(self) -> AzureContainerConfigurations:
|
||||
return AzureContainerConfigurations.from_config(self.conf.get("azure.storage"))
|
||||
|
||||
@cachedproperty
|
||||
def _default_gs_configs(self) -> GSBucketConfigurations:
|
||||
return GSBucketConfigurations.from_config(self.conf.get("google.storage"))
|
||||
|
||||
def get_azure_settings_for_company(
|
||||
self,
|
||||
company_id: str,
|
||||
) -> AzureContainerConfigurations:
|
||||
return copy(self._default_azure_configs)
|
||||
|
||||
def get_gs_settings_for_company(
|
||||
self,
|
||||
company_id: str,
|
||||
) -> GSBucketConfigurations:
|
||||
return copy(self._default_gs_configs)
|
||||
|
||||
def get_aws_settings_for_company(
|
||||
self,
|
||||
company_id: str,
|
||||
) -> S3BucketConfigurations:
|
||||
return copy(self._default_aws_configs)
|
||||
@@ -1,7 +1,5 @@
|
||||
from .task_bll import TaskBLL
|
||||
from .utils import (
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
validate_status_change,
|
||||
split_by,
|
||||
)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from hashlib import md5
|
||||
from operator import itemgetter
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.database.utils import hash_field_name
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
|
||||
|
||||
@@ -15,7 +15,7 @@ def get_artifact_id(artifact: dict):
|
||||
Calculate id from 'key' and 'mode' fields
|
||||
Return hash on on the id so that it will not contain mongo illegal characters
|
||||
"""
|
||||
key_hash: str = md5(artifact["key"].encode()).hexdigest()
|
||||
key_hash: str = hash_field_name(artifact["key"])
|
||||
mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
|
||||
return f"{key_hash}_{mode}"
|
||||
|
||||
@@ -40,7 +40,7 @@ def artifacts_unprepare_from_saved(fields):
|
||||
nested_set(
|
||||
fields,
|
||||
artifacts_field,
|
||||
value=sorted(artifacts.values(), key=itemgetter("key", "mode")),
|
||||
value=sorted(artifacts.values(), key=itemgetter("key")),
|
||||
)
|
||||
|
||||
|
||||
@@ -49,49 +49,45 @@ class Artifacts:
|
||||
def add_or_update_artifacts(
|
||||
cls,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
artifacts: Sequence[ApiArtifact],
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "update_artifacts"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
force=force,
|
||||
)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
artifacts = {
|
||||
get_artifact_id(a): Artifact(**a)
|
||||
for a in (api_artifact.to_struct() for api_artifact in artifacts)
|
||||
}
|
||||
artifacts = {
|
||||
get_artifact_id(a): Artifact(**a)
|
||||
for a in (api_artifact.to_struct() for api_artifact in artifacts)
|
||||
}
|
||||
|
||||
update_cmds = {
|
||||
f"set__execution__artifacts__{mongoengine_safe(name)}": value
|
||||
for name, value in artifacts.items()
|
||||
}
|
||||
return update_task(task, update_cmds=update_cmds)
|
||||
update_cmds = {
|
||||
f"set__execution__artifacts__{mongoengine_safe(name)}": value
|
||||
for name, value in artifacts.items()
|
||||
}
|
||||
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_artifacts(
|
||||
cls,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
artifact_ids: Sequence[ArtifactId],
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "delete_artifacts"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
force=force,
|
||||
)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
artifact_ids = [
|
||||
get_artifact_id(a)
|
||||
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
|
||||
]
|
||||
delete_cmds = {
|
||||
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
|
||||
}
|
||||
artifact_ids = [
|
||||
get_artifact_id(a)
|
||||
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
|
||||
]
|
||||
delete_cmds = {
|
||||
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
|
||||
}
|
||||
|
||||
return update_task(task, update_cmds=delete_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)
|
||||
|
||||
@@ -15,7 +15,7 @@ from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
@@ -32,7 +32,10 @@ class HyperParams:
|
||||
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
only = ("id", "hyperparams")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
company_id=company_id,
|
||||
task_ids=task_ids,
|
||||
only=only,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -64,78 +67,84 @@ class HyperParams:
|
||||
def delete_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamKey],
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "delete_hyperparams"):
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
)
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
identity=identity,
|
||||
)
|
||||
|
||||
with_param, without_param = iterutils.partition(
|
||||
hyperparams, key=lambda p: bool(p.name)
|
||||
)
|
||||
sections_to_delete = {p.section for p in without_param}
|
||||
delete_cmds = {
|
||||
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
|
||||
for section in sections_to_delete
|
||||
}
|
||||
with_param, without_param = iterutils.partition(
|
||||
hyperparams, key=lambda p: bool(p.name)
|
||||
)
|
||||
sections_to_delete = {p.section for p in without_param}
|
||||
delete_cmds = {
|
||||
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
|
||||
for section in sections_to_delete
|
||||
}
|
||||
|
||||
for item in with_param:
|
||||
section = ParameterKeyEscaper.escape(item.section)
|
||||
if item.section in sections_to_delete:
|
||||
raise errors.bad_request.FieldsConflict(
|
||||
"Cannot delete section field if the whole section was scheduled for deletion"
|
||||
)
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
for item in with_param:
|
||||
section = ParameterKeyEscaper.escape(item.section)
|
||||
if item.section in sections_to_delete:
|
||||
raise errors.bad_request.FieldsConflict(
|
||||
"Cannot delete section field if the whole section was scheduled for deletion"
|
||||
)
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
|
||||
return update_task(
|
||||
task, update_cmds=delete_cmds, set_last_update=not properties_only
|
||||
)
|
||||
return update_task(
|
||||
task,
|
||||
user_id=identity.user,
|
||||
update_cmds=delete_cmds,
|
||||
set_last_update=not properties_only,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def edit_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamItem],
|
||||
replace_hyperparams: str,
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "edit_hyperparams"):
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
)
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
identity=identity,
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
hyperparams = cls._db_dicts_from_list(hyperparams)
|
||||
if replace_hyperparams == ReplaceHyperparams.all:
|
||||
update_cmds["set__hyperparams"] = hyperparams
|
||||
elif replace_hyperparams == ReplaceHyperparams.section:
|
||||
for section, value in hyperparams.items():
|
||||
update_cmds = dict()
|
||||
hyperparams = cls._db_dicts_from_list(hyperparams)
|
||||
if replace_hyperparams == ReplaceHyperparams.all:
|
||||
update_cmds["set__hyperparams"] = hyperparams
|
||||
elif replace_hyperparams == ReplaceHyperparams.section:
|
||||
for section, value in hyperparams.items():
|
||||
update_cmds[f"set__hyperparams__{mongoengine_safe(section)}"] = value
|
||||
else:
|
||||
for section, section_params in hyperparams.items():
|
||||
for name, value in section_params.items():
|
||||
update_cmds[
|
||||
f"set__hyperparams__{mongoengine_safe(section)}"
|
||||
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
|
||||
] = value
|
||||
else:
|
||||
for section, section_params in hyperparams.items():
|
||||
for name, value in section_params.items():
|
||||
update_cmds[
|
||||
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
|
||||
] = value
|
||||
|
||||
return update_task(
|
||||
task, update_cmds=update_cmds, set_last_update=not properties_only
|
||||
)
|
||||
return update_task(
|
||||
task,
|
||||
user_id=identity.user,
|
||||
update_cmds=update_cmds,
|
||||
set_last_update=not properties_only,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
|
||||
@@ -160,7 +169,10 @@ class HyperParams:
|
||||
else:
|
||||
only.append("configuration")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
company_id=company_id,
|
||||
task_ids=task_ids,
|
||||
only=only,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -175,71 +187,76 @@ class HyperParams:
|
||||
|
||||
@classmethod
|
||||
def get_configuration_names(
|
||||
cls, company_id: str, task_ids: Sequence[str]
|
||||
cls, company_id: str, task_ids: Sequence[str], skip_empty: bool
|
||||
) -> Dict[str, list]:
|
||||
with TimingContext("mongo", "get_configuration_names"):
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"_id": {"$in": task_ids},
|
||||
}
|
||||
},
|
||||
{"$project": {"items": {"$objectToArray": "$configuration"}}},
|
||||
{"$unwind": "$items"},
|
||||
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
|
||||
]
|
||||
|
||||
tasks = Task.aggregate(pipeline)
|
||||
|
||||
return {
|
||||
task["_id"]: {
|
||||
"names": sorted(
|
||||
ParameterKeyEscaper.unescape(name) for name in task["names"]
|
||||
)
|
||||
skip_empty_condition = {"$match": {"items.v.value": {"$nin": [None, ""]}}}
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"_id": {"$in": task_ids},
|
||||
}
|
||||
for task in tasks
|
||||
},
|
||||
{"$project": {"items": {"$objectToArray": "$configuration"}}},
|
||||
{"$unwind": "$items"},
|
||||
*([skip_empty_condition] if skip_empty else []),
|
||||
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
|
||||
]
|
||||
|
||||
tasks = Task.aggregate(pipeline)
|
||||
|
||||
return {
|
||||
task["_id"]: {
|
||||
"names": sorted(
|
||||
ParameterKeyEscaper.unescape(name) for name in task["names"]
|
||||
)
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_configuration(
|
||||
cls,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
configuration: Sequence[Configuration],
|
||||
replace_configuration: bool,
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "edit_configuration"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force
|
||||
)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
|
||||
for c in configuration
|
||||
}
|
||||
if replace_configuration:
|
||||
update_cmds["set__configuration"] = configuration
|
||||
else:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
|
||||
for c in configuration
|
||||
}
|
||||
if replace_configuration:
|
||||
update_cmds["set__configuration"] = configuration
|
||||
else:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
|
||||
|
||||
return update_task(task, update_cmds=update_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_configuration(
|
||||
cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
|
||||
cls,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
configuration: Sequence[str],
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "delete_configuration"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force
|
||||
)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
|
||||
return update_task(task, update_cmds=delete_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from datetime import timedelta, datetime
|
||||
from time import sleep
|
||||
|
||||
from apiserver.bll.task import update_project_time
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import TaskStatus, Task
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
@@ -39,7 +39,7 @@ class NonResponsiveTasksWatchdog:
|
||||
@threads.register("non_responsive_tasks_watchdog", daemon=True)
|
||||
def start(cls):
|
||||
sleep(cls.settings.watch_interval_sec)
|
||||
while not ThreadsManager.terminating:
|
||||
while True:
|
||||
watch_interval = cls.settings.watch_interval_sec
|
||||
if cls.settings.enabled:
|
||||
try:
|
||||
@@ -85,6 +85,7 @@ class NonResponsiveTasksWatchdog:
|
||||
status_changed=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by="__apiserver__",
|
||||
)
|
||||
if updated:
|
||||
project_ids.add(task.project)
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import itertools
|
||||
from typing import Sequence, Tuple
|
||||
from typing import Sequence, Tuple, Optional
|
||||
|
||||
import dpath
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import nested_get, nested_delete, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
|
||||
@@ -14,7 +13,7 @@ hyperparams_legacy_type = "legacy"
|
||||
tf_define_section = "TF_DEFINE"
|
||||
|
||||
|
||||
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
|
||||
def split_param_name(full_name: str, default_section: str) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Return parameter section and name. The section is either TF_DEFINE or the default one
|
||||
"""
|
||||
@@ -62,7 +61,7 @@ def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
|
||||
return removed
|
||||
|
||||
|
||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
|
||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[dict]:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
@@ -71,8 +70,10 @@ def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]
|
||||
return []
|
||||
|
||||
if with_sections:
|
||||
return itertools.chain.from_iterable(
|
||||
_get_legacy_params(section_data) for section_data in data.values()
|
||||
return list(
|
||||
itertools.chain.from_iterable(
|
||||
_get_legacy_params(section_data) for section_data in data.values()
|
||||
)
|
||||
)
|
||||
|
||||
return [
|
||||
@@ -86,15 +87,15 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
Escape all the section and param names for hyper params and configuration to make it mongo sage
|
||||
"""
|
||||
for old_params_field, new_params_field, default_section in (
|
||||
("execution/parameters", "hyperparams", hyperparams_default_section),
|
||||
("execution/model_desc", "configuration", None),
|
||||
(("execution", "parameters"), "hyperparams", hyperparams_default_section),
|
||||
(("execution", "model_desc"), "configuration", None),
|
||||
):
|
||||
legacy_params = safe_get(fields, old_params_field)
|
||||
legacy_params = nested_get(fields, old_params_field)
|
||||
if legacy_params is None:
|
||||
continue
|
||||
|
||||
if (
|
||||
not safe_get(fields, new_params_field)
|
||||
not fields.get(new_params_field)
|
||||
and previous_task
|
||||
and previous_task[new_params_field]
|
||||
):
|
||||
@@ -117,21 +118,34 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
dpath.new(fields, new_path, new_param)
|
||||
dpath.delete(fields, old_params_field)
|
||||
nested_set(fields, new_path, new_param)
|
||||
nested_delete(fields, old_params_field)
|
||||
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(key): {
|
||||
ParameterKeyEscaper.escape(k): v for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
for key, value in params.items()
|
||||
def ensure_non_empty(k: str, desc: str) -> str:
|
||||
if not k:
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"Empty {desc} name is not allowed"
|
||||
)
|
||||
return k
|
||||
|
||||
params = fields.get("hyperparams")
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(ensure_non_empty(key, "section")): {
|
||||
ParameterKeyEscaper.escape(ensure_non_empty(k, "parameter")): v
|
||||
for k, v in value.items()
|
||||
}
|
||||
dpath.set(fields, param_field, escaped_params)
|
||||
for key, value in params.items()
|
||||
}
|
||||
fields["hyperparams"] = escaped_params
|
||||
|
||||
params = fields.get("configuration")
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(ensure_non_empty(key, "configuration")): value
|
||||
for key, value in params.items()
|
||||
}
|
||||
fields["configuration"] = escaped_params
|
||||
|
||||
|
||||
def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
@@ -140,7 +154,7 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
|
||||
"""
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
params = fields.get(param_field)
|
||||
if params:
|
||||
unescaped_params = {
|
||||
ParameterKeyEscaper.unescape(key): {
|
||||
@@ -150,18 +164,18 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, unescaped_params)
|
||||
fields[param_field] = unescaped_params
|
||||
|
||||
if copy_to_legacy:
|
||||
for new_params_field, old_params_field, use_sections in (
|
||||
(f"hyperparams", "execution/parameters", True),
|
||||
(f"configuration", "execution/model_desc", False),
|
||||
("hyperparams", ("execution", "parameters"), True),
|
||||
("configuration", ("execution", "model_desc"), False),
|
||||
):
|
||||
legacy_params = _get_legacy_params(
|
||||
safe_get(fields, new_params_field), with_sections=use_sections
|
||||
fields.get(new_params_field), with_sections=use_sections
|
||||
)
|
||||
if legacy_params:
|
||||
dpath.new(
|
||||
nested_set(
|
||||
fields,
|
||||
old_params_field,
|
||||
{_get_full_param_name(p): p["value"] for p in legacy_params},
|
||||
@@ -174,7 +188,7 @@ def _process_path(path: str):
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
if len(parts) < 2 or len(parts) > 4:
|
||||
raise errors.bad_request.ValidationError("invalid task field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
@@ -184,7 +198,8 @@ def _process_path(path: str):
|
||||
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
|
||||
for old_prefix, new_prefix in (
|
||||
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
|
||||
("execution.model_desc", f"configuration"),
|
||||
("execution.model_desc", "configuration"),
|
||||
("execution.docker_cmd", "container"),
|
||||
):
|
||||
path: str
|
||||
paths = [path.replace(old_prefix, new_prefix) for path in paths]
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
|
||||
from typing import Collection, Sequence, Tuple, Optional, Dict
|
||||
|
||||
import dpath
|
||||
import six
|
||||
from mongoengine import Q
|
||||
from redis import StrictRedis
|
||||
from six import string_types
|
||||
|
||||
import apiserver.database.utils as dbutils
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apierrors import errors, APIError
|
||||
from apiserver.apimodels.tasks import TaskInputModel
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
@@ -21,21 +22,30 @@ from apiserver.database.model.task.output import Output
|
||||
from apiserver.database.model.task.task import (
|
||||
Task,
|
||||
TaskStatus,
|
||||
TaskStatusMessage,
|
||||
TaskSystemTags,
|
||||
ArtifactModes,
|
||||
external_task_types,
|
||||
ModelItem,
|
||||
Models,
|
||||
DEFAULT_ARTIFACT_MODE,
|
||||
TaskModelNames,
|
||||
TaskModelTypes,
|
||||
)
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.utils import (
|
||||
get_company_or_none_constraint,
|
||||
id as create_id,
|
||||
)
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.services.utils import validate_tags
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import ChangeStatusRequest, validate_status_change, update_project_time
|
||||
from .utils import (
|
||||
ChangeStatusRequest,
|
||||
deleted_prefix,
|
||||
get_last_metric_updates,
|
||||
)
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
@@ -44,48 +54,17 @@ project_bll = ProjectBLL()
|
||||
|
||||
|
||||
class TaskBLL:
|
||||
def __init__(self, events_es=None):
|
||||
self.events_es = (
|
||||
events_es if events_es is not None else es_factory.connect("events")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
|
||||
"""
|
||||
Return the list of unique task types used by company and public tasks
|
||||
If project ids passed then only tasks from these projects are considered
|
||||
"""
|
||||
query = get_company_or_none_constraint(company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
res = Task.objects(query).distinct(field="type")
|
||||
return set(res).intersection(external_task_types)
|
||||
|
||||
@staticmethod
|
||||
def get_task_with_access(
|
||||
task_id, company_id, only=None, allow_public=False, requires_write_access=False
|
||||
) -> Task:
|
||||
"""
|
||||
Gets a task that has a required write access
|
||||
:except errors.bad_request.InvalidTaskId: if the task is not found
|
||||
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
with TimingContext("mongo", "task_with_access"):
|
||||
if requires_write_access:
|
||||
task = Task.get_for_writing(_only=only, **query)
|
||||
else:
|
||||
task = Task.get(_only=only, **query, include_public=allow_public)
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
return task
|
||||
def __init__(self, events_es=None, redis=None):
|
||||
self.events_es = events_es or es_factory.connect("events")
|
||||
self.redis: StrictRedis = redis or redman.connection("apiserver")
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(
|
||||
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
|
||||
company_id,
|
||||
task_id,
|
||||
required_status=None,
|
||||
only_fields=None,
|
||||
allow_public=False,
|
||||
):
|
||||
if only_fields:
|
||||
if isinstance(only_fields, string_types):
|
||||
@@ -94,15 +73,14 @@ class TaskBLL:
|
||||
only_fields = list(only_fields)
|
||||
only_fields = only_fields + ["status"]
|
||||
|
||||
with TimingContext("mongo", "task_by_id_all"):
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id=task_id),
|
||||
allow_public=allow_public,
|
||||
override_projection=only_fields,
|
||||
return_dicts=False,
|
||||
)
|
||||
task = None if not tasks else tasks[0]
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id=task_id),
|
||||
allow_public=allow_public,
|
||||
override_projection=only_fields,
|
||||
return_dicts=False,
|
||||
)
|
||||
task = None if not tasks else tasks[0]
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
@@ -117,7 +95,7 @@ class TaskBLL:
|
||||
company_id, task_ids, only=None, allow_public=False, return_tasks=True
|
||||
) -> Optional[Sequence[Task]]:
|
||||
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
|
||||
with translate_errors_context(), TimingContext("mongo", "task_exists"):
|
||||
with translate_errors_context():
|
||||
ids = set(task_ids)
|
||||
q = Task.get_many(
|
||||
company=company_id,
|
||||
@@ -137,33 +115,34 @@ class TaskBLL:
|
||||
return list(q)
|
||||
|
||||
@staticmethod
|
||||
def create(call: APICall, fields: dict):
|
||||
identity = call.identity
|
||||
def create(company: str, user: str, fields: dict):
|
||||
now = datetime.utcnow()
|
||||
return Task(
|
||||
id=create_id(),
|
||||
user=identity.user,
|
||||
company=identity.company,
|
||||
user=user,
|
||||
company=company,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by=user,
|
||||
**fields,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_execution_model(task, allow_only_public=False):
|
||||
if not task.execution or not task.execution.model:
|
||||
def validate_input_models(task, allow_only_public=False):
|
||||
if not task.models.input:
|
||||
return
|
||||
|
||||
company = None if allow_only_public else task.company
|
||||
model_id = task.execution.model
|
||||
model = Model.objects(
|
||||
Q(id=model_id) & get_company_or_none_constraint(company)
|
||||
).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(model=model_id)
|
||||
model_ids = set(m.model for m in task.models.input)
|
||||
models = Model.objects(
|
||||
Q(id__in=model_ids) & get_company_or_none_constraint(company)
|
||||
).only("id")
|
||||
missing = model_ids - {m.id for m in models}
|
||||
if missing:
|
||||
raise errors.bad_request.InvalidModelId(models=missing)
|
||||
|
||||
return model
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def clone_task(
|
||||
@@ -179,7 +158,9 @@ class TaskBLL:
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
hyperparams: Optional[dict] = None,
|
||||
configuration: Optional[dict] = None,
|
||||
container: Optional[dict] = None,
|
||||
execution_overrides: Optional[dict] = None,
|
||||
input_models: Optional[Sequence[TaskInputModel]] = None,
|
||||
validate_references: bool = False,
|
||||
new_project_name: str = None,
|
||||
) -> Tuple[Task, dict]:
|
||||
@@ -195,10 +176,29 @@ class TaskBLL:
|
||||
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||
|
||||
now = datetime.utcnow()
|
||||
if input_models:
|
||||
input_models = [
|
||||
ModelItem(model=m.model, name=m.name, updated=now) for m in input_models
|
||||
]
|
||||
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
execution_model_overriden = False
|
||||
if execution_overrides:
|
||||
execution_model_overriden = execution_overrides.get("model") is not None
|
||||
execution_model = execution_overrides.pop("model", None)
|
||||
if not input_models and execution_model:
|
||||
input_models = [
|
||||
ModelItem(
|
||||
model=execution_model,
|
||||
name=TaskModelNames[TaskModelTypes.input],
|
||||
updated=now,
|
||||
)
|
||||
]
|
||||
|
||||
docker_cmd = execution_overrides.pop("docker_cmd", None)
|
||||
if not container and docker_cmd:
|
||||
image, _, arguments = docker_cmd.partition(" ")
|
||||
container = {"image": image, "arguments": arguments}
|
||||
|
||||
artifacts_prepare_for_save({"execution": execution_overrides})
|
||||
|
||||
params_dict["execution"] = {}
|
||||
@@ -207,6 +207,8 @@ class TaskBLL:
|
||||
if legacy_value is not None:
|
||||
params_dict["execution"] = legacy_value
|
||||
|
||||
escape_dict_field(execution_overrides, "model_labels")
|
||||
|
||||
execution_dict.update(execution_overrides)
|
||||
|
||||
params_prepare_for_save(params_dict, previous_task=task)
|
||||
@@ -216,7 +218,7 @@ class TaskBLL:
|
||||
execution_dict["artifacts"] = {
|
||||
k: a
|
||||
for k, a in artifacts.items()
|
||||
if a.get("mode") != ArtifactModes.output
|
||||
if a.get("mode", DEFAULT_ARTIFACT_MODE) != ArtifactModes.output
|
||||
}
|
||||
execution_dict.pop("queue", None)
|
||||
|
||||
@@ -227,12 +229,10 @@ class TaskBLL:
|
||||
project_name=new_project_name,
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
description="Auto-generated while cloning",
|
||||
description="",
|
||||
)
|
||||
new_project_data = {"id": project, "name": new_project_name}
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
def clean_system_tags(input_tags: Sequence[str]) -> Sequence[str]:
|
||||
if not input_tags:
|
||||
return input_tags
|
||||
@@ -240,54 +240,70 @@ class TaskBLL:
|
||||
return [
|
||||
tag
|
||||
for tag in input_tags
|
||||
if tag not in [TaskSystemTags.development, EntityVisibility.archived.value]
|
||||
if tag
|
||||
not in [TaskSystemTags.development, EntityVisibility.archived.value]
|
||||
]
|
||||
|
||||
with TimingContext("mongo", "clone task"):
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or task.parent,
|
||||
project=project or task.project,
|
||||
tags=tags or task.tags,
|
||||
system_tags=system_tags or clean_system_tags(task.system_tags),
|
||||
type=task.type,
|
||||
script=task.script,
|
||||
output=Output(destination=task.output.destination)
|
||||
if task.output
|
||||
else None,
|
||||
execution=execution_dict,
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_model=validate_references or execution_model_overriden,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
new_task.save()
|
||||
def ensure_int_labels(execution: dict) -> dict:
|
||||
if not execution:
|
||||
return execution
|
||||
|
||||
if task.project == new_task.project:
|
||||
updated_tags = tags
|
||||
updated_system_tags = system_tags
|
||||
else:
|
||||
updated_tags = new_task.tags
|
||||
updated_system_tags = new_task.system_tags
|
||||
org_bll.update_tags(
|
||||
company_id,
|
||||
Tags.Task,
|
||||
project=new_task.project,
|
||||
tags=updated_tags,
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
update_project_time(new_task.project)
|
||||
model_labels = execution.get("model_labels")
|
||||
if model_labels:
|
||||
execution["model_labels"] = {k: int(v) for k, v in model_labels.items()}
|
||||
|
||||
return execution
|
||||
|
||||
parent_task = (
|
||||
task.parent
|
||||
if task.parent and not task.parent.startswith(deleted_prefix)
|
||||
else task.id
|
||||
)
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by=user_id,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or parent_task,
|
||||
project=project or task.project,
|
||||
tags=tags or task.tags,
|
||||
system_tags=system_tags or clean_system_tags(task.system_tags),
|
||||
type=task.type,
|
||||
script=task.script,
|
||||
output=Output(destination=task.output.destination) if task.output else None,
|
||||
models=Models(input=input_models or task.models.input),
|
||||
container=escape_dict(container) or task.container,
|
||||
execution=ensure_int_labels(execution_dict),
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_models=validate_references or input_models,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
new_task.save()
|
||||
|
||||
if task.project == new_task.project:
|
||||
updated_tags = tags
|
||||
updated_system_tags = system_tags
|
||||
else:
|
||||
updated_tags = new_task.tags
|
||||
updated_system_tags = new_task.system_tags
|
||||
org_bll.update_tags(
|
||||
company_id,
|
||||
Tags.Task,
|
||||
projects=[new_task.project],
|
||||
tags=updated_tags,
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
update_project_time(new_task.project)
|
||||
|
||||
return new_task, new_project_data
|
||||
|
||||
@@ -295,7 +311,7 @@ class TaskBLL:
|
||||
def validate(
|
||||
cls,
|
||||
task: Task,
|
||||
validate_model=True,
|
||||
validate_models=True,
|
||||
validate_parent=True,
|
||||
validate_project=True,
|
||||
):
|
||||
@@ -307,6 +323,7 @@ class TaskBLL:
|
||||
if (
|
||||
validate_parent
|
||||
and task.parent
|
||||
and not task.parent.startswith(deleted_prefix)
|
||||
and not Task.get(
|
||||
company=task.company, id=task.parent, _only=("id",), include_public=True
|
||||
)
|
||||
@@ -318,60 +335,21 @@ class TaskBLL:
|
||||
if validate_project and not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
|
||||
if validate_model:
|
||||
cls.validate_execution_model(task)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(company_id, project_ids=None):
|
||||
pipeline = [
|
||||
{
|
||||
"$match": dict(
|
||||
company={"$in": [None, "", company_id]},
|
||||
**({"project": {"$in": project_ids}} if project_ids else {}),
|
||||
)
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
{"$unwind": "$metrics"},
|
||||
{
|
||||
"$project": {
|
||||
"metric": "$metrics.k",
|
||||
"variants": {"$objectToArray": "$metrics.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$variants"},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"metric": "$variants.v.metric",
|
||||
"variant": "$variants.v.variant",
|
||||
},
|
||||
"metrics": {
|
||||
"$addToSet": {
|
||||
"metric": "$variants.v.metric",
|
||||
"metric_hash": "$metric",
|
||||
"variant": "$variants.v.variant",
|
||||
"variant_hash": "$variants.k",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
if validate_models:
|
||||
cls.validate_input_models(task)
|
||||
|
||||
@staticmethod
|
||||
def set_last_update(
|
||||
task_ids: Collection[str],
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
last_update: datetime,
|
||||
**extra_updates,
|
||||
):
|
||||
tasks = Task.objects(id__in=task_ids, company=company_id).only(
|
||||
"status", "started"
|
||||
)
|
||||
count = 0
|
||||
for task in tasks:
|
||||
updates = extra_updates
|
||||
if task.status == TaskStatus.in_progress and task.started:
|
||||
@@ -381,21 +359,24 @@ class TaskBLL:
|
||||
).total_seconds(),
|
||||
**extra_updates,
|
||||
}
|
||||
Task.objects(id=task.id, company=company_id).update(
|
||||
count += Task.objects(id=task.id, company=company_id).update(
|
||||
upsert=False,
|
||||
last_update=last_update,
|
||||
last_change=last_update,
|
||||
last_changed_by=user_id,
|
||||
**updates,
|
||||
)
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def update_statistics(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
last_update: datetime = None,
|
||||
last_iteration: int = None,
|
||||
last_iteration_max: int = None,
|
||||
last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
|
||||
last_scalar_events: Dict[str, Dict[str, dict]] = None,
|
||||
last_events: Dict[str, Dict[str, dict]] = None,
|
||||
**extra_updates,
|
||||
):
|
||||
@@ -420,25 +401,21 @@ class TaskBLL:
|
||||
elif last_iteration_max is not None:
|
||||
extra_updates.update(max__last_iteration=last_iteration_max)
|
||||
|
||||
if last_scalar_values is not None:
|
||||
|
||||
def op_path(op, *path):
|
||||
return "__".join((op, "last_metrics") + path)
|
||||
|
||||
for path, value in last_scalar_values:
|
||||
if path[-1] == "min_value":
|
||||
extra_updates[op_path("min", *path[:-1], "min_value")] = value
|
||||
elif path[-1] == "max_value":
|
||||
extra_updates[op_path("max", *path[:-1], "max_value")] = value
|
||||
else:
|
||||
extra_updates[op_path("set", *path)] = value
|
||||
raw_updates = {}
|
||||
if last_scalar_events is not None:
|
||||
get_last_metric_updates(
|
||||
task_id=task_id,
|
||||
last_scalar_events=last_scalar_events,
|
||||
raw_updates=raw_updates,
|
||||
extra_updates=extra_updates,
|
||||
)
|
||||
|
||||
if last_events is not None:
|
||||
|
||||
def events_per_type(metric_data: Dict[str, dict]) -> Dict[str, EventStats]:
|
||||
def events_per_type(metric_data_: Dict[str, dict]) -> Dict[str, EventStats]:
|
||||
return {
|
||||
event_type: EventStats(last_update=event["timestamp"])
|
||||
for event_type, event in metric_data.items()
|
||||
for event_type, event in metric_data_.items()
|
||||
}
|
||||
|
||||
metric_stats = {
|
||||
@@ -449,238 +426,55 @@ class TaskBLL:
|
||||
}
|
||||
extra_updates["metric_stats"] = metric_stats
|
||||
|
||||
TaskBLL.set_last_update(
|
||||
ret = TaskBLL.set_last_update(
|
||||
task_ids=[task_id],
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
last_update=last_update,
|
||||
**extra_updates,
|
||||
)
|
||||
if ret and raw_updates:
|
||||
Task.objects(id=task_id).update_one(__raw__=[{"$set": raw_updates}])
|
||||
|
||||
@classmethod
|
||||
def model_set_ready(
|
||||
cls,
|
||||
model_id: str,
|
||||
company_id: str,
|
||||
publish_task: bool,
|
||||
force_publish_task: bool = False,
|
||||
) -> tuple:
|
||||
with translate_errors_context():
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
elif model.ready:
|
||||
raise errors.bad_request.ModelIsReady(**query)
|
||||
|
||||
published_task_data = {}
|
||||
if model.task and publish_task:
|
||||
task = (
|
||||
Task.objects(id=model.task, company=company_id)
|
||||
.only("id", "status")
|
||||
.first()
|
||||
)
|
||||
if task and task.status != TaskStatus.published:
|
||||
published_task_data["data"] = cls.publish_task(
|
||||
task_id=model.task,
|
||||
company_id=company_id,
|
||||
publish_model=False,
|
||||
force=force_publish_task,
|
||||
)
|
||||
published_task_data["id"] = model.task
|
||||
|
||||
updated = model.update(upsert=False, ready=True)
|
||||
return updated, published_task_data
|
||||
|
||||
@classmethod
|
||||
def publish_task(
|
||||
cls,
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
publish_model: bool,
|
||||
force: bool,
|
||||
status_reason: str = "",
|
||||
status_message: str = "",
|
||||
) -> dict:
|
||||
task = cls.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
)
|
||||
if not force:
|
||||
validate_status_change(task.status, TaskStatus.published)
|
||||
|
||||
previous_task_status = task.status
|
||||
output = task.output or Output()
|
||||
publish_failed = False
|
||||
|
||||
try:
|
||||
# set state to publishing
|
||||
task.status = TaskStatus.publishing
|
||||
task.save()
|
||||
|
||||
# publish task models
|
||||
if task.output.model and publish_model:
|
||||
output_model = (
|
||||
Model.objects(id=task.output.model)
|
||||
.only("id", "task", "ready")
|
||||
.first()
|
||||
)
|
||||
if output_model and not output_model.ready:
|
||||
cls.model_set_ready(
|
||||
model_id=task.output.model,
|
||||
company_id=company_id,
|
||||
publish_task=False,
|
||||
)
|
||||
|
||||
# set task status to published, and update (or set) it's new output (view and models)
|
||||
return ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.published,
|
||||
force=force,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
).execute(published=datetime.utcnow(), output=output)
|
||||
|
||||
except Exception as ex:
|
||||
publish_failed = True
|
||||
raise ex
|
||||
finally:
|
||||
if publish_failed:
|
||||
task.status = previous_task_status
|
||||
task.save()
|
||||
|
||||
@classmethod
|
||||
def stop_task(
|
||||
cls,
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_name: str,
|
||||
status_reason: str,
|
||||
force: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Stop a running task. Requires task status 'in_progress' and
|
||||
execution_progress 'running', or force=True. Development task or
|
||||
task that has no associated worker is stopped immediately.
|
||||
For a non-development task with worker only the status message
|
||||
is set to 'stopping' to allow the worker to stop the task and report by itself
|
||||
:return: updated task fields
|
||||
"""
|
||||
|
||||
task = cls.get_task_with_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
only=(
|
||||
"status",
|
||||
"project",
|
||||
"tags",
|
||||
"system_tags",
|
||||
"last_worker",
|
||||
"last_update",
|
||||
),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
def is_run_by_worker(t: Task) -> bool:
|
||||
"""Checks if there is an active worker running the task"""
|
||||
update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
|
||||
return (
|
||||
t.last_worker
|
||||
and t.last_update
|
||||
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
|
||||
)
|
||||
|
||||
if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
|
||||
new_status = TaskStatus.stopped
|
||||
status_message = f"Stopped by {user_name}"
|
||||
else:
|
||||
new_status = task.status
|
||||
status_message = TaskStatusMessage.stopping
|
||||
|
||||
return ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=new_status,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
force=force,
|
||||
).execute()
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids: Sequence[str] = None,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
**({"project": {"$in": project_ids}} if project_ids else {}),
|
||||
}
|
||||
},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"total": 1,
|
||||
"results": {"$slice": ["$results", page * page_size, page_size]},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
dpath.get(r, "_id/section")
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
def remove_task_from_all_queues(company_id: str, task_id: str) -> int:
|
||||
return Queue.objects(company=company_id, entries__task=task_id).update(
|
||||
pull__entries__task=task_id, last_update=datetime.utcnow()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def dequeue_and_change_status(
|
||||
cls, task: Task, company_id: str, status_message: str, status_reason: str,
|
||||
cls,
|
||||
task: Task,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
remove_from_all_queues=False,
|
||||
new_status=None,
|
||||
):
|
||||
cls.dequeue(task, company_id)
|
||||
try:
|
||||
cls.dequeue(task, company_id, silent_fail=True)
|
||||
except APIError:
|
||||
# dequeue may fail if the queue was deleted
|
||||
pass
|
||||
|
||||
if remove_from_all_queues:
|
||||
cls.remove_task_from_all_queues(company_id=company_id, task_id=task.id)
|
||||
|
||||
if task.status not in [TaskStatus.queued, TaskStatus.in_progress]:
|
||||
return {"updated": 0}
|
||||
|
||||
return ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.created,
|
||||
new_status=new_status or task.enqueue_status or TaskStatus.created,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
).execute(unset__execution__queue=1)
|
||||
user_id=user_id,
|
||||
force=True,
|
||||
).execute(enqueue_status=None)
|
||||
|
||||
@classmethod
|
||||
def dequeue(cls, task: Task, company_id: str, silent_fail=False):
|
||||
|
||||
366
apiserver/bll/task/task_cleanup.py
Normal file
366
apiserver/bll/task/task_cleanup.py
Normal file
@@ -0,0 +1,366 @@
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple, Union
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import partition, bucketize, first, chunked_iter
|
||||
from furl import furl
|
||||
from mongoengine import NotUniqueError
|
||||
from pymongo.errors import DuplicateKeyError
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_bll import PlotFields
|
||||
from apiserver.bll.task.utils import deleted_prefix
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
|
||||
from apiserver.database.model.url_to_delete import (
|
||||
StorageType,
|
||||
UrlToDelete,
|
||||
FileType,
|
||||
DeletionStatus,
|
||||
)
|
||||
from apiserver.database.utils import id as db_id
|
||||
|
||||
log = config.logger(__file__)
|
||||
event_bll = EventBLL()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskUrls:
|
||||
model_urls: Sequence[str]
|
||||
event_urls: Sequence[str]
|
||||
artifact_urls: Sequence[str]
|
||||
|
||||
def __add__(self, other: "TaskUrls"):
|
||||
if not other:
|
||||
return self
|
||||
|
||||
return TaskUrls(
|
||||
model_urls=list(set(self.model_urls) | set(other.model_urls)),
|
||||
event_urls=list(set(self.event_urls) | set(other.event_urls)),
|
||||
artifact_urls=list(set(self.artifact_urls) | set(other.artifact_urls)),
|
||||
)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class CleanupResult:
|
||||
"""
|
||||
Counts of objects modified in task cleanup operation
|
||||
"""
|
||||
|
||||
updated_children: int
|
||||
updated_models: int
|
||||
deleted_models: int
|
||||
urls: TaskUrls = None
|
||||
|
||||
def __add__(self, other: "CleanupResult"):
|
||||
if not other:
|
||||
return self
|
||||
|
||||
return CleanupResult(
|
||||
updated_children=self.updated_children + other.updated_children,
|
||||
updated_models=self.updated_models + other.updated_models,
|
||||
deleted_models=self.deleted_models + other.deleted_models,
|
||||
urls=self.urls + other.urls if self.urls else other.urls,
|
||||
)
|
||||
|
||||
|
||||
def collect_plot_image_urls(
|
||||
company: str, task_or_model: Union[str, Sequence[str]]
|
||||
) -> Set[str]:
|
||||
urls = set()
|
||||
task_ids = task_or_model if isinstance(task_or_model, list) else [task_or_model]
|
||||
for tasks in chunked_iter(task_ids, 100):
|
||||
next_scroll_id = None
|
||||
while True:
|
||||
events, next_scroll_id = event_bll.get_plot_image_urls(
|
||||
company_id=company, task_ids=tasks, scroll_id=next_scroll_id
|
||||
)
|
||||
if not events:
|
||||
break
|
||||
for event in events:
|
||||
event_urls = event.get(PlotFields.source_urls)
|
||||
if event_urls:
|
||||
urls.update(set(event_urls))
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def collect_debug_image_urls(
|
||||
company: str, task_or_model: Union[str, Sequence[str]]
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Return the set of unique image urls
|
||||
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
|
||||
"""
|
||||
urls = set()
|
||||
task_ids = task_or_model if isinstance(task_or_model, list) else [task_or_model]
|
||||
for tasks in chunked_iter(task_ids, 100):
|
||||
after_key = None
|
||||
while True:
|
||||
res, after_key = event_bll.get_debug_image_urls(
|
||||
company_id=company,
|
||||
task_ids=tasks,
|
||||
after_key=after_key,
|
||||
)
|
||||
urls.update(res)
|
||||
if not after_key:
|
||||
break
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
supported_storage_types = {
|
||||
"s3://": StorageType.s3,
|
||||
"azure://": StorageType.azure,
|
||||
"gs://": StorageType.gs,
|
||||
}
|
||||
|
||||
supported_storage_types.update(
|
||||
{
|
||||
p: StorageType.fileserver
|
||||
for p in config.get(
|
||||
"services.async_urls_delete.fileserver.url_prefixes",
|
||||
["https://", "http://"],
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _schedule_for_delete(
|
||||
company: str,
|
||||
user: str,
|
||||
task_id: str,
|
||||
urls: Set[str],
|
||||
can_delete_folders: bool,
|
||||
) -> Set[str]:
|
||||
urls_per_storage = bucketize(
|
||||
urls,
|
||||
key=lambda u: first(
|
||||
type_
|
||||
for prefix, type_ in supported_storage_types.items()
|
||||
if u.startswith(prefix)
|
||||
),
|
||||
)
|
||||
urls_per_storage.pop(None, None)
|
||||
|
||||
processed_urls = set()
|
||||
for storage_type, storage_urls in urls_per_storage.items():
|
||||
delete_folders = (storage_type == StorageType.fileserver) and can_delete_folders
|
||||
scheduled_to_delete = set()
|
||||
for url in storage_urls:
|
||||
folder = None
|
||||
if delete_folders:
|
||||
try:
|
||||
parsed = furl(url)
|
||||
if parsed.path and len(parsed.path.segments) > 1:
|
||||
folder = parsed.remove(
|
||||
args=True, fragment=True, path=parsed.path.segments[-1]
|
||||
).url.rstrip("/")
|
||||
except Exception as ex:
|
||||
pass
|
||||
|
||||
to_delete = folder or url
|
||||
if to_delete in scheduled_to_delete:
|
||||
processed_urls.add(url)
|
||||
continue
|
||||
|
||||
try:
|
||||
UrlToDelete(
|
||||
id=db_id(),
|
||||
company=company,
|
||||
user=user,
|
||||
url=to_delete,
|
||||
task=task_id,
|
||||
created=datetime.utcnow(),
|
||||
storage_type=storage_type,
|
||||
type=FileType.folder if folder else FileType.file,
|
||||
).save()
|
||||
except (DuplicateKeyError, NotUniqueError):
|
||||
existing = UrlToDelete.objects(company=company, url=to_delete).first()
|
||||
if existing:
|
||||
existing.update(
|
||||
user=user,
|
||||
task=task_id,
|
||||
created=datetime.utcnow(),
|
||||
retry_count=0,
|
||||
unset__last_failure_time=1,
|
||||
unset__last_failure_reason=1,
|
||||
status=DeletionStatus.created,
|
||||
)
|
||||
processed_urls.add(url)
|
||||
scheduled_to_delete.add(to_delete)
|
||||
|
||||
return processed_urls
|
||||
|
||||
|
||||
def cleanup_task(
|
||||
company: str,
|
||||
user: str,
|
||||
task: Task,
|
||||
force: bool = False,
|
||||
update_children=True,
|
||||
return_file_urls=False,
|
||||
delete_output_models=True,
|
||||
delete_external_artifacts=True,
|
||||
) -> CleanupResult:
|
||||
"""
|
||||
Validate task deletion and delete/modify all its output.
|
||||
:param task: task object
|
||||
:param force: whether to delete task with published outputs
|
||||
:return: count of delete and modified items
|
||||
"""
|
||||
published_models, draft_models, in_use_model_ids = verify_task_children_and_ouptuts(
|
||||
task, force
|
||||
)
|
||||
delete_external_artifacts = delete_external_artifacts and config.get(
|
||||
"services.async_urls_delete.enabled", True
|
||||
)
|
||||
event_urls, artifact_urls, model_urls = set(), set(), set()
|
||||
if return_file_urls or delete_external_artifacts:
|
||||
event_urls = collect_debug_image_urls(task.company, task.id)
|
||||
event_urls.update(collect_plot_image_urls(task.company, task.id))
|
||||
if task.execution and task.execution.artifacts:
|
||||
artifact_urls = {
|
||||
a.uri
|
||||
for a in task.execution.artifacts.values()
|
||||
if a.mode == ArtifactModes.output and a.uri
|
||||
}
|
||||
model_urls = {
|
||||
m.uri for m in draft_models if m.uri and m.id not in in_use_model_ids
|
||||
}
|
||||
|
||||
deleted_task_id = f"{deleted_prefix}{task.id}"
|
||||
updated_children = 0
|
||||
now = datetime.utcnow()
|
||||
if update_children:
|
||||
updated_children = Task.objects(parent=task.id).update(
|
||||
parent=deleted_task_id,
|
||||
last_change=now,
|
||||
last_changed_by=user,
|
||||
)
|
||||
|
||||
deleted_models = 0
|
||||
updated_models = 0
|
||||
for models, allow_delete in ((draft_models, True), (published_models, False)):
|
||||
if not models:
|
||||
continue
|
||||
if delete_output_models and allow_delete:
|
||||
model_ids = list({m.id for m in models if m.id not in in_use_model_ids})
|
||||
if model_ids:
|
||||
if return_file_urls or delete_external_artifacts:
|
||||
event_urls.update(collect_debug_image_urls(task.company, model_ids))
|
||||
event_urls.update(collect_plot_image_urls(task.company, model_ids))
|
||||
|
||||
event_bll.delete_multi_task_events(
|
||||
task.company,
|
||||
model_ids,
|
||||
model=True,
|
||||
)
|
||||
deleted_models += Model.objects(id__in=model_ids).delete()
|
||||
|
||||
if in_use_model_ids:
|
||||
Model.objects(id__in=list(in_use_model_ids)).update(
|
||||
unset__task=1,
|
||||
set__last_change=now,
|
||||
set__last_changed_by=user,
|
||||
)
|
||||
continue
|
||||
|
||||
if update_children:
|
||||
updated_models += Model.objects(id__in=[m.id for m in models]).update(
|
||||
task=deleted_task_id,
|
||||
last_change=now,
|
||||
last_changed_by=user,
|
||||
)
|
||||
else:
|
||||
Model.objects(id__in=[m.id for m in models]).update(
|
||||
unset__task=1,
|
||||
set__last_change=now,
|
||||
set__last_changed_by=user,
|
||||
)
|
||||
|
||||
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
|
||||
|
||||
if delete_external_artifacts:
|
||||
scheduled = _schedule_for_delete(
|
||||
task_id=task.id,
|
||||
company=company,
|
||||
user=user,
|
||||
urls=event_urls | model_urls | artifact_urls,
|
||||
can_delete_folders=not in_use_model_ids and not published_models,
|
||||
)
|
||||
for urls in (event_urls, model_urls, artifact_urls):
|
||||
urls.difference_update(scheduled)
|
||||
|
||||
return CleanupResult(
|
||||
deleted_models=deleted_models,
|
||||
updated_children=updated_children,
|
||||
updated_models=updated_models,
|
||||
urls=TaskUrls(
|
||||
event_urls=list(event_urls),
|
||||
artifact_urls=list(artifact_urls),
|
||||
model_urls=list(model_urls),
|
||||
)
|
||||
if return_file_urls
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
def verify_task_children_and_ouptuts(
|
||||
task, force: bool
|
||||
) -> Tuple[Sequence[Model], Sequence[Model], Set[str]]:
|
||||
if not force:
|
||||
published_children_count = Task.objects(
|
||||
parent=task.id, status=TaskStatus.published
|
||||
).count()
|
||||
if published_children_count:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has children, use force=True",
|
||||
task=task.id,
|
||||
children=published_children_count,
|
||||
)
|
||||
|
||||
model_fields = ["id", "ready", "uri"]
|
||||
published_models, draft_models = partition(
|
||||
Model.objects(task=task.id).only(*model_fields),
|
||||
key=attrgetter("ready"),
|
||||
)
|
||||
if not force and published_models:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output models, use force=True",
|
||||
task=task.id,
|
||||
models=len(published_models),
|
||||
)
|
||||
|
||||
if task.models and task.models.output:
|
||||
model_ids = [m.model for m in task.models.output]
|
||||
for output_model in Model.objects(id__in=model_ids).only(*model_fields):
|
||||
if output_model.ready:
|
||||
if not force:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output model, use force=True",
|
||||
task=task.id,
|
||||
model=output_model.id,
|
||||
)
|
||||
published_models.append(output_model)
|
||||
else:
|
||||
draft_models.append(output_model)
|
||||
|
||||
in_use_model_ids = {}
|
||||
if draft_models:
|
||||
model_ids = {m.id for m in draft_models}
|
||||
dependent_tasks = Task.objects(models__input__model__in=list(model_ids)).only(
|
||||
"id", "models"
|
||||
)
|
||||
in_use_model_ids = model_ids & {
|
||||
m.model
|
||||
for m in chain.from_iterable(
|
||||
t.models.input for t in dependent_tasks if t.models
|
||||
)
|
||||
}
|
||||
|
||||
return published_models, draft_models, in_use_model_ids
|
||||
530
apiserver/bll/task/task_operations.py
Normal file
530
apiserver/bll/task/task_operations.py
Normal file
@@ -0,0 +1,530 @@
|
||||
from datetime import datetime
|
||||
from typing import Callable, Any, Tuple, Union, Sequence
|
||||
|
||||
from apiserver.apierrors import errors, APIError
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.task import (
|
||||
TaskBLL,
|
||||
validate_status_change,
|
||||
ChangeStatusRequest,
|
||||
)
|
||||
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
|
||||
from apiserver.bll.task.utils import get_task_with_write_access
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.output import Output
|
||||
from apiserver.database.model.task.task import (
|
||||
TaskStatus,
|
||||
Task,
|
||||
TaskSystemTags,
|
||||
TaskStatusMessage,
|
||||
ArtifactModes,
|
||||
Execution,
|
||||
DEFAULT_LAST_ITERATION,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.dicts import nested_set
|
||||
|
||||
log = config.logger(__file__)
|
||||
queue_bll = QueueBLL()
|
||||
|
||||
|
||||
def archive_task(
|
||||
task: Union[str, Task],
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
) -> int:
|
||||
"""
|
||||
Deque and archive task
|
||||
Return 1 if successful
|
||||
"""
|
||||
if isinstance(task, str):
|
||||
task = get_task_with_write_access(
|
||||
task,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=(
|
||||
"id",
|
||||
"company",
|
||||
"execution",
|
||||
"status",
|
||||
"project",
|
||||
"system_tags",
|
||||
"enqueue_status",
|
||||
),
|
||||
)
|
||||
|
||||
user_id = identity.user
|
||||
try:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
remove_from_all_queues=True,
|
||||
)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
|
||||
return task.update(
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
add_to_set__system_tags=EntityVisibility.archived.value,
|
||||
last_change=datetime.utcnow(),
|
||||
last_changed_by=user_id,
|
||||
)
|
||||
|
||||
|
||||
def unarchive_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
) -> int:
|
||||
"""
|
||||
Unarchive task. Return 1 if successful
|
||||
"""
|
||||
task = get_task_with_write_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=("id",),
|
||||
)
|
||||
return task.update(
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
pull__system_tags=EntityVisibility.archived.value,
|
||||
last_change=datetime.utcnow(),
|
||||
last_changed_by=identity.user,
|
||||
)
|
||||
|
||||
|
||||
def dequeue_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
remove_from_all_queues: bool = False,
|
||||
new_status=None,
|
||||
) -> Tuple[int, dict]:
|
||||
if new_status and new_status not in get_options(TaskStatus):
|
||||
raise errors.bad_request.ValidationError(f"Invalid task status: {new_status}")
|
||||
|
||||
# get the task without write access to make sure that it actually exists
|
||||
task = Task.get(
|
||||
id=task_id,
|
||||
company=company_id,
|
||||
_only=("id",),
|
||||
include_public=True,
|
||||
)
|
||||
if not task:
|
||||
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
|
||||
return 1, {"updated": 0}
|
||||
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=(
|
||||
"id",
|
||||
"company",
|
||||
"execution",
|
||||
"status",
|
||||
"project",
|
||||
"enqueue_status",
|
||||
),
|
||||
)
|
||||
|
||||
res = TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
remove_from_all_queues=remove_from_all_queues,
|
||||
new_status=new_status,
|
||||
)
|
||||
return 1, res
|
||||
|
||||
|
||||
def enqueue_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
queue_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
queue_name: str = None,
|
||||
validate: bool = False,
|
||||
force: bool = False,
|
||||
) -> Tuple[int, dict]:
|
||||
if queue_id and queue_name:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Either queue id or queue name should be provided"
|
||||
)
|
||||
|
||||
if queue_name:
|
||||
queue = queue_bll.get_by_name(
|
||||
company_id=company_id, queue_name=queue_name, only=("id",)
|
||||
)
|
||||
if not queue:
|
||||
queue = queue_bll.create(company_id=company_id, name=queue_name)
|
||||
queue_id = queue.id
|
||||
|
||||
if not queue_id:
|
||||
# try to get default queue
|
||||
queue_id = queue_bll.get_default(company_id).id
|
||||
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
|
||||
user_id = identity.user
|
||||
if validate:
|
||||
TaskBLL.validate(task)
|
||||
|
||||
res = ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.queued,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
allow_same_state_transition=False,
|
||||
force=force,
|
||||
user_id=user_id,
|
||||
).execute(enqueue_status=task.status)
|
||||
|
||||
try:
|
||||
queue_bll.add_task(company_id=company_id, queue_id=queue_id, task_id=task.id)
|
||||
except Exception:
|
||||
# failed enqueueing, revert to previous state
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
current_status_override=TaskStatus.queued,
|
||||
new_status=task.status,
|
||||
force=True,
|
||||
status_reason="failed enqueueing",
|
||||
user_id=user_id,
|
||||
).execute(enqueue_status=None)
|
||||
raise
|
||||
|
||||
# set the current queue ID in the task
|
||||
if task.execution:
|
||||
Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
|
||||
else:
|
||||
Task.objects(id=task_id).update(execution=Execution(queue=queue_id), multi=False)
|
||||
|
||||
nested_set(res, ("fields", "execution.queue"), queue_id)
|
||||
return 1, res
|
||||
|
||||
|
||||
def move_tasks_to_trash(tasks: Sequence[str]) -> int:
|
||||
try:
|
||||
collection_name = Task._get_collection_name()
|
||||
trash_collection_name = f"{collection_name}__trash"
|
||||
Task.aggregate(
|
||||
[
|
||||
{"$match": {"_id": {"$in": tasks}}},
|
||||
{
|
||||
"$merge": {
|
||||
"into": trash_collection_name,
|
||||
"on": "_id",
|
||||
"whenMatched": "replace",
|
||||
"whenNotMatched": "insert",
|
||||
}
|
||||
},
|
||||
],
|
||||
allow_disk_use=True,
|
||||
)
|
||||
except Exception as ex:
|
||||
log.error(f"Error copying tasks to trash {str(ex)}")
|
||||
|
||||
return Task.objects(id__in=tasks).delete()
|
||||
|
||||
|
||||
def delete_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
move_to_trash: bool,
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
delete_output_models: bool,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[int, Task, CleanupResult]:
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
|
||||
if (
|
||||
task.status != TaskStatus.created
|
||||
and EntityVisibility.archived.value not in task.system_tags
|
||||
and not force
|
||||
):
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"due to status, use force=True",
|
||||
task=task.id,
|
||||
expected=TaskStatus.created,
|
||||
current=task.status,
|
||||
)
|
||||
|
||||
try:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
remove_from_all_queues=True,
|
||||
)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
|
||||
cleanup_res = cleanup_task(
|
||||
company=company_id,
|
||||
user=user_id,
|
||||
task=task,
|
||||
force=force,
|
||||
return_file_urls=return_file_urls,
|
||||
delete_output_models=delete_output_models,
|
||||
delete_external_artifacts=delete_external_artifacts,
|
||||
)
|
||||
|
||||
if move_to_trash:
|
||||
# make sure that whatever changes were done to the task are saved
|
||||
# the task itself will be deleted later in the move_tasks_to_trash operation
|
||||
task.last_update = datetime.utcnow()
|
||||
task.save()
|
||||
else:
|
||||
task.delete()
|
||||
|
||||
update_project_time(task.project)
|
||||
return 1, task, cleanup_res
|
||||
|
||||
|
||||
def reset_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
delete_output_models: bool,
|
||||
clear_all: bool,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[dict, CleanupResult, dict]:
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
|
||||
if not force and task.status == TaskStatus.published:
|
||||
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
|
||||
|
||||
dequeued = {}
|
||||
updates = {}
|
||||
|
||||
try:
|
||||
dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
|
||||
TaskBLL.remove_task_from_all_queues(company_id=company_id, task_id=task.id)
|
||||
|
||||
cleaned_up = cleanup_task(
|
||||
company=company_id,
|
||||
user=user_id,
|
||||
task=task,
|
||||
force=force,
|
||||
update_children=False,
|
||||
return_file_urls=return_file_urls,
|
||||
delete_output_models=delete_output_models,
|
||||
delete_external_artifacts=delete_external_artifacts,
|
||||
)
|
||||
|
||||
updates.update(
|
||||
set__last_iteration=DEFAULT_LAST_ITERATION,
|
||||
set__last_metrics={},
|
||||
set__unique_metrics=[],
|
||||
set__metric_stats={},
|
||||
set__models__output=[],
|
||||
set__runtime={},
|
||||
unset__output__result=1,
|
||||
unset__output__error=1,
|
||||
unset__last_worker=1,
|
||||
unset__last_worker_report=1,
|
||||
unset__started=1,
|
||||
unset__completed=1,
|
||||
unset__published=1,
|
||||
unset__active_duration=1,
|
||||
unset__enqueue_status=1,
|
||||
)
|
||||
|
||||
if clear_all:
|
||||
updates.update(
|
||||
set__execution=Execution(),
|
||||
unset__script=1,
|
||||
)
|
||||
else:
|
||||
updates.update(unset__execution__queue=1)
|
||||
if task.execution and task.execution.artifacts:
|
||||
updates.update(
|
||||
set__execution__artifacts={
|
||||
key: artifact
|
||||
for key, artifact in task.execution.artifacts.items()
|
||||
if artifact.mode == ArtifactModes.input
|
||||
}
|
||||
)
|
||||
|
||||
res = ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.created,
|
||||
force=force,
|
||||
status_reason="reset",
|
||||
status_message="reset",
|
||||
user_id=user_id,
|
||||
).execute(
|
||||
**updates,
|
||||
)
|
||||
|
||||
return dequeued, cleaned_up, res
|
||||
|
||||
|
||||
def publish_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
force: bool,
|
||||
publish_model_func: Callable[[str, str, Identity], Any] = None,
|
||||
status_message: str = "",
|
||||
status_reason: str = "",
|
||||
) -> dict:
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
if not force:
|
||||
validate_status_change(task.status, TaskStatus.published)
|
||||
|
||||
previous_task_status = task.status
|
||||
output = task.output or Output()
|
||||
publish_failed = False
|
||||
|
||||
try:
|
||||
# set state to publishing
|
||||
task.status = TaskStatus.publishing
|
||||
task.save()
|
||||
|
||||
# publish task models
|
||||
if task.models and task.models.output and publish_model_func:
|
||||
model_id = task.models.output[-1].model
|
||||
model = (
|
||||
Model.objects(id=model_id, company=company_id)
|
||||
.only("id", "ready")
|
||||
.first()
|
||||
)
|
||||
if model and not model.ready:
|
||||
publish_model_func(model.id, company_id, identity)
|
||||
|
||||
# set task status to published, and update (or set) it's new output (view and models)
|
||||
return ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.published,
|
||||
force=force,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
user_id=user_id,
|
||||
).execute(published=datetime.utcnow(), output=output)
|
||||
|
||||
except Exception as ex:
|
||||
publish_failed = True
|
||||
raise ex
|
||||
finally:
|
||||
if publish_failed:
|
||||
task.status = previous_task_status
|
||||
task.save()
|
||||
|
||||
|
||||
def stop_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
user_name: str,
|
||||
status_reason: str,
|
||||
force: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Stop a running task. Requires task status 'in_progress' and
|
||||
execution_progress 'running', or force=True. Development task or
|
||||
task that has no associated worker is stopped immediately.
|
||||
For a non-development task with worker only the status message
|
||||
is set to 'stopping' to allow the worker to stop the task and report by itself
|
||||
:return: updated task fields
|
||||
"""
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=(
|
||||
"status",
|
||||
"project",
|
||||
"tags",
|
||||
"system_tags",
|
||||
"last_worker",
|
||||
"last_update",
|
||||
"execution.queue",
|
||||
),
|
||||
)
|
||||
|
||||
def is_run_by_worker(t: Task) -> bool:
|
||||
"""Checks if there is an active worker running the task"""
|
||||
update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
|
||||
return (
|
||||
t.last_worker
|
||||
and t.last_update
|
||||
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
|
||||
)
|
||||
|
||||
is_queued = task.status == TaskStatus.queued
|
||||
set_stopped = (
|
||||
is_queued
|
||||
or TaskSystemTags.development in task.system_tags
|
||||
or not is_run_by_worker(task)
|
||||
)
|
||||
|
||||
if set_stopped:
|
||||
if is_queued:
|
||||
try:
|
||||
TaskBLL.dequeue(task, company_id=company_id, silent_fail=True)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
|
||||
new_status = TaskStatus.stopped
|
||||
status_message = f"Stopped by {user_name}"
|
||||
else:
|
||||
new_status = task.status
|
||||
status_message = TaskStatusMessage.stopping
|
||||
|
||||
return ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=new_status,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
force=force,
|
||||
user_id=user_id,
|
||||
).execute()
|
||||
@@ -1,18 +1,22 @@
|
||||
from datetime import datetime
|
||||
from typing import TypeVar, Callable, Tuple, Sequence, Union
|
||||
from typing import Sequence
|
||||
|
||||
import attr
|
||||
import six
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.attrs import typed_attrs
|
||||
|
||||
valid_statuses = get_options(TaskStatus)
|
||||
deleted_prefix = "__DELETED__"
|
||||
|
||||
|
||||
@typed_attrs
|
||||
@@ -26,6 +30,7 @@ class ChangeStatusRequest(object):
|
||||
force = attr.ib(type=bool, default=False)
|
||||
allow_same_state_transition = attr.ib(type=bool, default=True)
|
||||
current_status_override = attr.ib(default=None)
|
||||
user_id = attr.ib(type=str, default=None)
|
||||
|
||||
def execute(self, **kwargs):
|
||||
current_status = self.current_status_override or self.task.status
|
||||
@@ -44,6 +49,7 @@ class ChangeStatusRequest(object):
|
||||
status_changed=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by=self.user_id,
|
||||
)
|
||||
|
||||
if self.new_status == TaskStatus.queued:
|
||||
@@ -54,7 +60,7 @@ class ChangeStatusRequest(object):
|
||||
|
||||
fields.update({safe_mongoengine_key(k): v for k, v in kwargs.items()})
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "task_status"):
|
||||
with translate_errors_context():
|
||||
# atomic change of task status by querying the task with the EXPECTED status before modifying it
|
||||
params = fields.copy()
|
||||
params.update(control)
|
||||
@@ -105,7 +111,7 @@ def validate_status_change(current_status, new_status):
|
||||
|
||||
state_machine = {
|
||||
TaskStatus.created: {TaskStatus.queued, TaskStatus.in_progress},
|
||||
TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress},
|
||||
TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress, TaskStatus.stopped},
|
||||
TaskStatus.in_progress: {
|
||||
TaskStatus.stopped,
|
||||
TaskStatus.failed,
|
||||
@@ -116,6 +122,7 @@ state_machine = {
|
||||
TaskStatus.closed,
|
||||
TaskStatus.created,
|
||||
TaskStatus.failed,
|
||||
TaskStatus.queued,
|
||||
TaskStatus.in_progress,
|
||||
TaskStatus.published,
|
||||
TaskStatus.publishing,
|
||||
@@ -153,41 +160,75 @@ def get_possible_status_changes(current_status):
|
||||
return possible
|
||||
|
||||
|
||||
def update_project_time(project_ids: Union[str, Sequence[str]]):
|
||||
if not project_ids:
|
||||
return
|
||||
def get_many_tasks_for_writing(
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
query: Q = None,
|
||||
only: Sequence = None,
|
||||
throw_on_forbidden: bool = True,
|
||||
) -> Sequence[Task]:
|
||||
if only:
|
||||
missing = [f for f in ("company", ) if f not in only]
|
||||
if missing:
|
||||
only = [*only, *missing]
|
||||
|
||||
if isinstance(project_ids, str):
|
||||
project_ids = [project_ids]
|
||||
|
||||
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def split_by(
|
||||
condition: Callable[[T], bool], items: Sequence[T]
|
||||
) -> Tuple[Sequence[T], Sequence[T]]:
|
||||
"""
|
||||
split "items" to two lists by "condition"
|
||||
"""
|
||||
applied = zip(map(condition, items), items)
|
||||
return (
|
||||
[item for cond, item in applied if cond],
|
||||
[item for cond, item in applied if not cond],
|
||||
result = list(
|
||||
Task.get_many(
|
||||
company=company_id,
|
||||
query=query,
|
||||
override_projection=only,
|
||||
allow_public=True,
|
||||
return_dicts=False,
|
||||
)
|
||||
)
|
||||
|
||||
forbidden_tasks = {task.id for task in result if not task.company}
|
||||
if forbidden_tasks:
|
||||
if throw_on_forbidden:
|
||||
raise errors.forbidden.NoWritePermission(
|
||||
f"cannot modify public task(s), ids={tuple(forbidden_tasks)}"
|
||||
)
|
||||
result = [task for task in result if task.id not in forbidden_tasks]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_task_with_write_access(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
only=None,
|
||||
) -> Task:
|
||||
"""
|
||||
Gets a task that has a required write access
|
||||
:except errors.bad_request.InvalidTaskId: if the task is not found
|
||||
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
|
||||
"""
|
||||
query = dict(id=task_id, company=company_id)
|
||||
|
||||
task = Task.get_for_writing(_only=only, **query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
def get_task_for_update(
|
||||
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
identity: Identity,
|
||||
allow_all_statuses: bool = False,
|
||||
force: bool = False
|
||||
) -> Task:
|
||||
"""
|
||||
Loads only task id and return the task only if it is updatable (status == 'created')
|
||||
"""
|
||||
task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
only=("id", "status"),
|
||||
identity=identity,
|
||||
)
|
||||
|
||||
if allow_all_statuses:
|
||||
return task
|
||||
@@ -202,9 +243,88 @@ def get_task_for_update(
|
||||
return task
|
||||
|
||||
|
||||
def update_task(task: Task, update_cmds: dict, set_last_update: bool = True):
|
||||
def update_task(
|
||||
task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True
|
||||
):
|
||||
now = datetime.utcnow()
|
||||
last_updates = dict(last_change=now)
|
||||
last_updates = dict(last_change=now, last_changed_by=user_id)
|
||||
if set_last_update:
|
||||
last_updates.update(last_update=now)
|
||||
return task.update(**update_cmds, **last_updates)
|
||||
|
||||
|
||||
def get_last_metric_updates(
|
||||
task_id: str,
|
||||
last_scalar_events: dict,
|
||||
raw_updates: dict,
|
||||
extra_updates: dict,
|
||||
model_events: bool = False,
|
||||
):
|
||||
max_values = config.get("services.tasks.max_last_metrics", 2000)
|
||||
total_metrics = set()
|
||||
if max_values:
|
||||
query = dict(id=task_id)
|
||||
to_add = sum(len(v) for m, v in last_scalar_events.items())
|
||||
if to_add <= max_values:
|
||||
query[f"unique_metrics__{max_values - to_add}__exists"] = True
|
||||
db_cls = Model if model_events else Task
|
||||
task = db_cls.objects(**query).only("unique_metrics").first()
|
||||
if task and task.unique_metrics:
|
||||
total_metrics = set(task.unique_metrics)
|
||||
|
||||
new_metrics = []
|
||||
|
||||
def add_last_metric_conditional_update(
|
||||
metric_path: str, metric_value, iter_value: int, is_min: bool
|
||||
):
|
||||
"""
|
||||
Build an aggregation for an atomic update of the min or max value and the corresponding iteration
|
||||
"""
|
||||
if is_min:
|
||||
field_prefix = "min"
|
||||
op = "$gt"
|
||||
else:
|
||||
field_prefix = "max"
|
||||
op = "$lt"
|
||||
|
||||
value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
|
||||
condition = {
|
||||
"$or": [
|
||||
{"$lte": [f"${value_field}", None]},
|
||||
{op: [f"${value_field}", metric_value]},
|
||||
]
|
||||
}
|
||||
raw_updates[value_field] = {
|
||||
"$cond": [condition, metric_value, f"${value_field}"]
|
||||
}
|
||||
|
||||
value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
|
||||
"__", "."
|
||||
)
|
||||
raw_updates[value_iteration_field] = {
|
||||
"$cond": [condition, iter_value, f"${value_iteration_field}"]
|
||||
}
|
||||
|
||||
for metric_key, metric_data in last_scalar_events.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
metric = f"{variant_data.get('metric')}/{variant_data.get('variant')}"
|
||||
if max_values:
|
||||
if len(total_metrics) >= max_values and metric not in total_metrics:
|
||||
continue
|
||||
total_metrics.add(metric)
|
||||
|
||||
new_metrics.append(metric)
|
||||
path = f"last_metrics__{metric_key}__{variant_key}"
|
||||
for key, value in variant_data.items():
|
||||
if key in ("min_value", "max_value"):
|
||||
add_last_metric_conditional_update(
|
||||
metric_path=path,
|
||||
metric_value=value,
|
||||
iter_value=variant_data.get(f"{key}_iter", 0),
|
||||
is_min=(key == "min_value"),
|
||||
)
|
||||
elif key in ("metric", "variant", "value"):
|
||||
extra_updates[f"set__{path}__{key}"] = value
|
||||
|
||||
if new_metrics:
|
||||
extra_updates["add_to_set__unique_metrics"] = new_metrics
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from datetime import datetime
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.users import CreateRequest
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
@@ -12,7 +14,7 @@ class UserBLL:
|
||||
if user_id and User.objects(id=user_id).only("id"):
|
||||
raise errors.bad_request.UserIdExists(id=user_id)
|
||||
|
||||
user = User(**request.to_struct())
|
||||
user = User(**request.to_struct(), created=datetime.utcnow())
|
||||
user.save(force_insert=True)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,84 +1,24 @@
|
||||
import functools
|
||||
import itertools
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set, Iterable
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Optional,
|
||||
Callable,
|
||||
Iterable,
|
||||
Tuple,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from boltons import iterutils
|
||||
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.apierrors import APIError
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.settings import Settings
|
||||
|
||||
|
||||
def extract_properties_to_lists(
|
||||
key_names: Sequence[str],
|
||||
data: Sequence[dict],
|
||||
extract_func: Optional[Callable[[dict], Tuple]] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Given a list of dictionaries and names of dictionary keys
|
||||
builds a dictionary with the requested keys and values lists
|
||||
:param key_names: names of the keys in the resulting dictionary
|
||||
:param data: sequence of dictionaries to extract values from
|
||||
:param extract_func: the optional callable that extracts properties
|
||||
from a dictionary and put them in a tuple in the order corresponding to
|
||||
key_names. If not specified then properties are extracted according to key_names
|
||||
"""
|
||||
value_sequences = zip(*map(extract_func or itemgetter(*key_names), data))
|
||||
return dict(zip(key_names, map(list, value_sequences)))
|
||||
|
||||
|
||||
class SetFieldsResolver:
|
||||
"""
|
||||
The class receives set fields dictionary
|
||||
and for the set fields that require 'min' or 'max'
|
||||
operation replace them with a simple set in case the
|
||||
DB document does not have these fields set
|
||||
"""
|
||||
|
||||
SET_MODIFIERS = ("min", "max")
|
||||
|
||||
def __init__(self, set_fields: Dict[str, Any]):
|
||||
self.orig_fields = {}
|
||||
self.fields = {}
|
||||
self.add_fields(**set_fields)
|
||||
|
||||
def add_fields(self, **set_fields: Any):
|
||||
self.orig_fields.update(set_fields)
|
||||
self.fields.update(
|
||||
{
|
||||
f: fname
|
||||
for f, modifier, dunder, fname in (
|
||||
(f,) + f.partition("__") for f in set_fields.keys()
|
||||
)
|
||||
if dunder and modifier in self.SET_MODIFIERS
|
||||
}
|
||||
)
|
||||
|
||||
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
|
||||
if name in self.fields and doc.get_field_value(self.fields[name]) is None:
|
||||
return self.fields[name]
|
||||
return name
|
||||
|
||||
def get_fields(self, doc: AttributedDocument):
|
||||
"""
|
||||
For the given document return the set fields instructions
|
||||
with min/max operations replaced with a single set in case
|
||||
the document does not have the field set
|
||||
"""
|
||||
return {
|
||||
self._get_updated_name(doc, name): value
|
||||
for name, value in self.orig_fields.items()
|
||||
}
|
||||
|
||||
def get_names(self) -> Set[str]:
|
||||
"""
|
||||
Returns the names of the fields that had min/max modifiers
|
||||
in the format suitable for projection (dot separated)
|
||||
"""
|
||||
return set(name.replace("__", ".") for name in self.fields.values())
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def get_server_uuid() -> Optional[str]:
|
||||
return Settings.get_by_key("server.uuid")
|
||||
@@ -115,3 +55,38 @@ def parallel_chunked_decorator(func: Callable = None, chunk_size: int = 100):
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def run_batch_operation(
|
||||
func: Callable[[str], T], ids: Sequence[str]
|
||||
) -> Tuple[Sequence[Tuple[str, T]], Sequence[dict]]:
|
||||
results = list()
|
||||
failures = list()
|
||||
for _id in ids:
|
||||
try:
|
||||
results.append((_id, func(_id)))
|
||||
except APIError as err:
|
||||
failures.append(
|
||||
{
|
||||
"id": _id,
|
||||
"error": {
|
||||
"codes": [err.code, err.subcode],
|
||||
"msg": err.msg,
|
||||
"data": err.error_data,
|
||||
},
|
||||
}
|
||||
)
|
||||
return results, failures
|
||||
|
||||
|
||||
def update_project_time(project_ids: Union[str, Sequence[str]]):
|
||||
if not project_ids:
|
||||
return
|
||||
|
||||
if isinstance(project_ids, str):
|
||||
project_ids = [project_ids]
|
||||
|
||||
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
import itertools
|
||||
from datetime import datetime, timedelta
|
||||
from time import time
|
||||
from typing import Sequence, Set, Optional
|
||||
|
||||
import attr
|
||||
import elasticsearch.helpers
|
||||
from boltons.iterutils import partition, chunked_iter
|
||||
from pyhocon import ConfigTree
|
||||
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import APIError
|
||||
from apiserver.apierrors.errors import bad_request, server_error
|
||||
from apiserver.apimodels.workers import (
|
||||
DEFAULT_TIMEOUT,
|
||||
IdNameEntry,
|
||||
WorkerEntry,
|
||||
StatusReportRequest,
|
||||
@@ -25,16 +27,17 @@ from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
from .stats import WorkerStats
|
||||
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class WorkerBLL:
|
||||
def __init__(self, es=None, redis=None):
|
||||
self.es_client = es or es_factory.connect("workers")
|
||||
self.config = config.get("services.workers", ConfigTree())
|
||||
self.redis = redis or redman.connection("workers")
|
||||
self._stats = WorkerStats(self.es_client)
|
||||
|
||||
@@ -51,6 +54,7 @@ class WorkerBLL:
|
||||
queues: Sequence[str] = None,
|
||||
timeout: int = 0,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> WorkerEntry:
|
||||
"""
|
||||
Register a worker
|
||||
@@ -66,7 +70,7 @@ class WorkerBLL:
|
||||
"""
|
||||
key = WorkerBLL._get_worker_key(company_id, user_id, worker)
|
||||
|
||||
timeout = timeout or DEFAULT_TIMEOUT
|
||||
timeout = timeout or int(self.config.get("default_worker_timeout_sec", 10 * 60))
|
||||
queues = queues or []
|
||||
|
||||
with translate_errors_context():
|
||||
@@ -76,7 +80,7 @@ class WorkerBLL:
|
||||
raise bad_request.InvalidUserId(**query)
|
||||
company = Company.objects(id=company_id).only("id", "name").first()
|
||||
if not company:
|
||||
raise server_error.InternalError("invalid company", company=company_id)
|
||||
raise bad_request.InvalidId("invalid company", company=company_id)
|
||||
|
||||
queue_objs = Queue.objects(company=company_id, id__in=queues).only("id")
|
||||
if len(queue_objs) < len(queues):
|
||||
@@ -95,9 +99,10 @@ class WorkerBLL:
|
||||
register_timeout=timeout,
|
||||
last_activity_time=now,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
)
|
||||
|
||||
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
|
||||
self._save_worker_data(entry)
|
||||
|
||||
return entry
|
||||
|
||||
@@ -109,15 +114,20 @@ class WorkerBLL:
|
||||
:param worker: worker ID
|
||||
:raise bad_request.WorkerNotRegistered: the worker was not previously registered
|
||||
"""
|
||||
with TimingContext("redis", "workers_unregister"):
|
||||
res = self.redis.delete(
|
||||
company_id, self._get_worker_key(company_id, user_id, worker)
|
||||
)
|
||||
if not res:
|
||||
res = self.redis.delete(
|
||||
company_id, self._get_worker_key(company_id, user_id, worker)
|
||||
)
|
||||
if not res and not config.get("apiserver.workers.auto_unregister", False):
|
||||
raise bad_request.WorkerNotRegistered(worker=worker)
|
||||
|
||||
def status_report(
|
||||
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
|
||||
self,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
ip: str,
|
||||
report: StatusReportRequest,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Write worker status report
|
||||
@@ -133,22 +143,23 @@ class WorkerBLL:
|
||||
|
||||
try:
|
||||
entry.ip = ip
|
||||
now = datetime.utcnow()
|
||||
entry.last_activity_time = now
|
||||
|
||||
if tags is not None:
|
||||
entry.tags = tags
|
||||
if system_tags is not None:
|
||||
entry.system_tags = system_tags
|
||||
|
||||
if report.machine_stats:
|
||||
self._log_stats_to_es(
|
||||
self.log_stats_to_es(
|
||||
company_id=company_id,
|
||||
company_name=entry.company.name,
|
||||
worker=report.worker,
|
||||
worker_id=report.worker,
|
||||
timestamp=report.timestamp,
|
||||
task=report.task,
|
||||
machine_stats=report.machine_stats,
|
||||
)
|
||||
|
||||
now = datetime.utcnow()
|
||||
entry.last_activity_time = now
|
||||
entry.queue = report.queue
|
||||
|
||||
if report.queues:
|
||||
@@ -165,6 +176,7 @@ class WorkerBLL:
|
||||
last_worker_report=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by=user_id,
|
||||
)
|
||||
# modify(new=True, ...) returns the modified object
|
||||
task = Task.objects(**query).modify(new=True, **update)
|
||||
@@ -176,7 +188,9 @@ class WorkerBLL:
|
||||
if task.project:
|
||||
project = Project.objects(id=task.project).only("name").first()
|
||||
if project:
|
||||
entry.project = IdNameEntry(id=project.id, name=project.name)
|
||||
entry.project = IdNameEntry(
|
||||
id=project.id, name=project.name
|
||||
)
|
||||
|
||||
entry.last_report_time = now
|
||||
except APIError:
|
||||
@@ -188,8 +202,30 @@ class WorkerBLL:
|
||||
finally:
|
||||
self._save_worker(entry)
|
||||
|
||||
def get_count(
|
||||
self,
|
||||
company_id: str,
|
||||
last_seen: Optional[int] = None,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
):
|
||||
if not last_seen:
|
||||
return len(
|
||||
self._get_keys(company_id, user_tags=tags, system_tags=system_tags)
|
||||
)
|
||||
|
||||
return len(
|
||||
self.get_all(
|
||||
company_id, last_seen=last_seen, tags=tags, system_tags=system_tags
|
||||
)
|
||||
)
|
||||
|
||||
def get_all(
|
||||
self, company_id: str, last_seen: Optional[int] = None
|
||||
self,
|
||||
company_id: str,
|
||||
last_seen: Optional[int] = None,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerEntry]:
|
||||
"""
|
||||
Get all the company workers that were active during the last_seen period
|
||||
@@ -198,7 +234,7 @@ class WorkerBLL:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
workers = self._get(company_id)
|
||||
workers = self._get(company_id, user_tags=tags, system_tags=system_tags)
|
||||
except Exception as e:
|
||||
raise server_error.DataError("failed loading worker entries", err=e.args[0])
|
||||
|
||||
@@ -213,15 +249,21 @@ class WorkerBLL:
|
||||
return workers
|
||||
|
||||
def get_all_with_projection(
|
||||
self, company_id: str, last_seen: int
|
||||
self,
|
||||
company_id: str,
|
||||
last_seen: int,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerResponseEntry]:
|
||||
|
||||
helpers = list(
|
||||
map(
|
||||
WorkerConversionHelper.from_worker_entry,
|
||||
self.get_all(company_id=company_id, last_seen=last_seen),
|
||||
helpers = [
|
||||
WorkerConversionHelper.from_worker_entry(entry)
|
||||
for entry in self.get_all(
|
||||
company_id=company_id,
|
||||
last_seen=last_seen,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
task_ids = set(filter(None, (helper.task_id for helper in helpers)))
|
||||
all_queues = set(
|
||||
@@ -240,9 +282,7 @@ class WorkerBLL:
|
||||
}
|
||||
},
|
||||
]
|
||||
queues_info = {
|
||||
res["_id"]: res for res in Queue.objects.aggregate(projection)
|
||||
}
|
||||
queues_info = {res["_id"]: res for res in Queue.aggregate(projection)}
|
||||
task_ids = task_ids.union(
|
||||
filter(
|
||||
None,
|
||||
@@ -258,7 +298,7 @@ class WorkerBLL:
|
||||
tasks_info = {
|
||||
task.id: task
|
||||
for task in Task.objects(id__in=task_ids).only(
|
||||
"name", "started", "last_iteration"
|
||||
"name", "started", "last_iteration", "active_duration"
|
||||
)
|
||||
}
|
||||
|
||||
@@ -283,11 +323,7 @@ class WorkerBLL:
|
||||
if helper.task_id:
|
||||
task = tasks_info.get(helper.task_id, None)
|
||||
if task:
|
||||
worker.task.running_time = (
|
||||
int((datetime.utcnow() - task.started).total_seconds() * 1000)
|
||||
if task.started
|
||||
else 0
|
||||
)
|
||||
worker.task.running_time = (task.active_duration or 0) * 1000
|
||||
worker.task.last_iteration = task.last_iteration
|
||||
|
||||
update_queue_entries(worker.queue)
|
||||
@@ -314,8 +350,7 @@ class WorkerBLL:
|
||||
"""
|
||||
key = self._get_worker_key(company_id, user_id, worker)
|
||||
|
||||
with TimingContext("redis", "get_worker"):
|
||||
data = self.redis.get(key)
|
||||
data = self.redis.get(key)
|
||||
|
||||
if data:
|
||||
try:
|
||||
@@ -342,42 +377,149 @@ class WorkerBLL:
|
||||
|
||||
raise bad_request.InvalidWorkerId(worker=worker)
|
||||
|
||||
@staticmethod
|
||||
def _get_tagged_workers_key(company: str, tags_field: str, tag: str) -> str:
|
||||
"""Build redis key from company, user and worker_id"""
|
||||
return f"workers.{tags_field}_{company}_{tag}"
|
||||
|
||||
@staticmethod
|
||||
def _get_all_workers_key(company: str) -> str:
|
||||
"""Build redis key from company, user and worker_id"""
|
||||
return f"workers_{company}"
|
||||
|
||||
def _save_worker_data(self, entry: WorkerEntry):
|
||||
self.redis.setex(
|
||||
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
|
||||
)
|
||||
company_id = entry.company.id
|
||||
expiration = int(time()) + entry.register_timeout
|
||||
worker_item = {entry.key: expiration}
|
||||
self.redis.zadd(self._get_all_workers_key(company_id), worker_item)
|
||||
for tags, tags_field in (
|
||||
(entry.tags, "tags"),
|
||||
(entry.system_tags, "systemtags"),
|
||||
):
|
||||
for tag in tags:
|
||||
name = self._get_tagged_workers_key(company_id, tags_field, tag)
|
||||
self.redis.zadd(name, worker_item)
|
||||
|
||||
def _save_worker(self, entry: WorkerEntry) -> None:
|
||||
"""Save worker entry in Redis"""
|
||||
try:
|
||||
self.redis.setex(
|
||||
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
|
||||
)
|
||||
self._save_worker_data(entry)
|
||||
except Exception:
|
||||
msg = "Failed saving worker entry"
|
||||
log.exception(msg)
|
||||
|
||||
def _get_keys(
|
||||
self,
|
||||
company: str,
|
||||
user: str = "*",
|
||||
user_tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[bytes]:
|
||||
if not (user_tags or system_tags):
|
||||
match = self._get_worker_key(company, user, "*")
|
||||
return list(self.redis.scan_iter(match))
|
||||
|
||||
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
|
||||
if user == "*":
|
||||
return in_keys
|
||||
user_bytes = user.encode()
|
||||
return {k for k in in_keys if user_bytes in k}
|
||||
|
||||
worker_keys = set()
|
||||
for tags, tags_field in (
|
||||
(user_tags, "tags"),
|
||||
(system_tags, "systemtags"),
|
||||
):
|
||||
if not tags:
|
||||
continue
|
||||
|
||||
timestamp = int(time())
|
||||
include, exclude = partition(tags, key=lambda x: x[0] != "-")
|
||||
if include:
|
||||
tagged_workers = set()
|
||||
for tag in include:
|
||||
tagged_workers_key = self._get_tagged_workers_key(
|
||||
company, tags_field, tag
|
||||
)
|
||||
self.redis.zremrangebyscore(
|
||||
tagged_workers_key, min=0, max=timestamp
|
||||
)
|
||||
tagged_workers.update(self.redis.zrange(tagged_workers_key, 0, -1))
|
||||
|
||||
tagged_workers = filter_by_user(tagged_workers)
|
||||
worker_keys = (
|
||||
worker_keys.intersection(tagged_workers)
|
||||
if worker_keys
|
||||
else tagged_workers
|
||||
)
|
||||
if not worker_keys:
|
||||
return []
|
||||
|
||||
if exclude:
|
||||
if not worker_keys:
|
||||
all_workers_key = self._get_all_workers_key(company)
|
||||
self.redis.zremrangebyscore(all_workers_key, min=0, max=timestamp)
|
||||
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
|
||||
worker_keys = filter_by_user(worker_keys)
|
||||
if not worker_keys:
|
||||
return []
|
||||
|
||||
for tag in exclude:
|
||||
tagged_workers_key = self._get_tagged_workers_key(
|
||||
company, tags_field, tag[1:]
|
||||
)
|
||||
self.redis.zremrangebyscore(
|
||||
tagged_workers_key, min=0, max=timestamp
|
||||
)
|
||||
worker_keys.difference_update(
|
||||
self.redis.zrange(tagged_workers_key, 0, -1)
|
||||
)
|
||||
if not worker_keys:
|
||||
return []
|
||||
|
||||
return list(worker_keys)
|
||||
|
||||
def _get(
|
||||
self, company: str, user: str = "*", worker_id: str = "*"
|
||||
self,
|
||||
company: str,
|
||||
user: str = "*",
|
||||
user_tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerEntry]:
|
||||
"""Get worker entries matching the company and user, worker patterns"""
|
||||
match = self._get_worker_key(company, user, worker_id)
|
||||
with TimingContext("redis", "workers_get_all"):
|
||||
res = self.redis.scan_iter(match)
|
||||
return [WorkerEntry.from_json(self.redis.get(r)) for r in res]
|
||||
|
||||
entries = []
|
||||
for keys in chunked_iter(
|
||||
self._get_keys(
|
||||
company, user=user, user_tags=user_tags, system_tags=system_tags
|
||||
),
|
||||
1000,
|
||||
):
|
||||
data = self.redis.mget(keys)
|
||||
if data:
|
||||
entries.extend(WorkerEntry.from_json(d) for d in data if d)
|
||||
|
||||
return entries
|
||||
|
||||
@staticmethod
|
||||
def _get_es_index_suffix():
|
||||
"""Get the index name suffix for storing current month data"""
|
||||
return datetime.utcnow().strftime("%Y-%m")
|
||||
|
||||
def _log_stats_to_es(
|
||||
def log_stats_to_es(
|
||||
self,
|
||||
company_id: str,
|
||||
company_name: str,
|
||||
worker: str,
|
||||
worker_id: str,
|
||||
timestamp: int,
|
||||
task: str,
|
||||
machine_stats: MachineStats,
|
||||
) -> bool:
|
||||
) -> int:
|
||||
"""
|
||||
Actually writing the worker statistics to Elastic
|
||||
:return: True if successful, False otherwise
|
||||
:return: The amount of logged documents
|
||||
"""
|
||||
es_index = (
|
||||
f"{self._stats.worker_stats_prefix_for_company(company_id)}"
|
||||
@@ -389,8 +531,7 @@ class WorkerBLL:
|
||||
_index=es_index,
|
||||
_source=dict(
|
||||
timestamp=timestamp,
|
||||
worker=worker,
|
||||
company=company_name,
|
||||
worker=worker_id,
|
||||
task=task,
|
||||
category=category,
|
||||
metric=metric,
|
||||
@@ -415,7 +556,7 @@ class WorkerBLL:
|
||||
|
||||
es_res = elasticsearch.helpers.bulk(self.es_client, actions)
|
||||
added, errors = es_res[:2]
|
||||
return (added == len(actions)) and not errors
|
||||
return added
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
|
||||
@@ -8,7 +8,6 @@ from apiserver.apimodels.workers import AggregationType, GetStatsRequest, StatIt
|
||||
from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -20,7 +19,7 @@ class WorkerStats:
|
||||
@staticmethod
|
||||
def worker_stats_prefix_for_company(company_id: str) -> str:
|
||||
"""Returns the es index prefix for the company"""
|
||||
return f"worker_stats_{company_id}_"
|
||||
return f"worker_stats_{company_id.lower()}_"
|
||||
|
||||
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
@@ -126,7 +125,7 @@ class WorkerStats:
|
||||
query_terms.append(QueryBuilder.terms("worker", request.worker_ids))
|
||||
es_req["query"] = {"bool": {"must": query_terms}}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_worker_stats"):
|
||||
with translate_errors_context():
|
||||
data = self._search_company_stats(company_id, es_req)
|
||||
|
||||
return self._extract_results(data, request.items, request.split_by_variant)
|
||||
@@ -216,6 +215,10 @@ class WorkerStats:
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"fixed_interval": f"{interval}s",
|
||||
"extended_bounds": {
|
||||
"min": int(from_date) * 1000,
|
||||
"max": int(to_date) * 1000,
|
||||
}
|
||||
},
|
||||
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
|
||||
}
|
||||
@@ -223,9 +226,7 @@ class WorkerStats:
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_worker_activity_report"
|
||||
):
|
||||
with translate_errors_context():
|
||||
data = self._search_company_stats(company_id, es_req)
|
||||
|
||||
if "aggregations" not in data:
|
||||
|
||||
@@ -6,9 +6,10 @@ from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
from pathlib import Path
|
||||
from typing import List, Any, TypeVar
|
||||
from typing import List, Any, TypeVar, Sequence
|
||||
|
||||
from pyhocon import ConfigTree, ConfigFactory
|
||||
from boltons.iterutils import first
|
||||
from pyhocon import ConfigTree, ConfigFactory, ConfigValues
|
||||
from pyparsing import (
|
||||
ParseFatalException,
|
||||
ParseException,
|
||||
@@ -18,8 +19,8 @@ from pyparsing import (
|
||||
|
||||
from apiserver.utilities import json
|
||||
|
||||
EXTRA_CONFIG_PATHS = ("/opt/trains/config",)
|
||||
EXTRA_CONFIG_PATH_OVERRIDE_VAR = "TRAINS_CONFIG_DIR"
|
||||
EXTRA_CONFIG_PATHS = ("/opt/trains/config", "/opt/clearml/config")
|
||||
DEFAULT_PREFIXES = ("clearml", "trains")
|
||||
EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ";"
|
||||
|
||||
|
||||
@@ -30,7 +31,10 @@ class BasicConfig:
|
||||
default_config_dir = "default"
|
||||
|
||||
def __init__(
|
||||
self, folder: str = None, verbose: bool = True, prefix: str = "trains"
|
||||
self,
|
||||
folder: str = None,
|
||||
verbose: bool = True,
|
||||
prefix: Sequence[str] = DEFAULT_PREFIXES,
|
||||
):
|
||||
folder = (
|
||||
Path(folder)
|
||||
@@ -41,8 +45,16 @@ class BasicConfig:
|
||||
raise ValueError("Invalid configuration folder")
|
||||
|
||||
self.verbose = verbose
|
||||
self.prefix = prefix
|
||||
self.extra_config_values_env_key_prefix = f"{self.prefix.upper()}__"
|
||||
|
||||
self.extra_config_path_override_var = [
|
||||
f"{p.upper()}_CONFIG_DIR" for p in prefix
|
||||
]
|
||||
|
||||
self.prefix = prefix[0]
|
||||
self.extra_config_values_env_key_prefix = [
|
||||
f"{p.upper()}{self.extra_config_values_env_key_sep}"
|
||||
for p in reversed(prefix)
|
||||
]
|
||||
|
||||
self._paths = [folder, *self._get_paths()]
|
||||
self._config = self._reload()
|
||||
@@ -67,30 +79,32 @@ class BasicConfig:
|
||||
def logger(self, name: str) -> logging.Logger:
|
||||
if Path(name).is_file():
|
||||
name = Path(name).stem
|
||||
if name == "__init__" and Path(name).parent.stem:
|
||||
name = Path(name).parent.stem
|
||||
path = ".".join((self.prefix, name))
|
||||
return logging.getLogger(path)
|
||||
|
||||
def _read_extra_env_config_values(self) -> ConfigTree:
|
||||
""" Loads extra configuration from environment-injected values """
|
||||
result = ConfigTree()
|
||||
prefix = self.extra_config_values_env_key_prefix
|
||||
|
||||
keys = sorted(k for k in os.environ if k.startswith(prefix))
|
||||
for key in keys:
|
||||
path = (
|
||||
key[len(prefix) :]
|
||||
.replace(self.extra_config_values_env_key_sep, ".")
|
||||
.lower()
|
||||
)
|
||||
result = ConfigTree.merge_configs(
|
||||
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
|
||||
)
|
||||
for prefix in self.extra_config_values_env_key_prefix:
|
||||
keys = sorted(k for k in os.environ if k.startswith(prefix))
|
||||
for key in keys:
|
||||
path = (
|
||||
key[len(prefix) :]
|
||||
.replace(self.extra_config_values_env_key_sep, ".")
|
||||
.lower()
|
||||
)
|
||||
result = self._merge_configs(
|
||||
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _get_paths(self) -> List[Path]:
|
||||
default_paths = EXTRA_CONFIG_PATH_SEP.join(EXTRA_CONFIG_PATHS)
|
||||
value = getenv(EXTRA_CONFIG_PATH_OVERRIDE_VAR, default_paths)
|
||||
value = first(map(getenv, self.extra_config_path_override_var), default_paths)
|
||||
|
||||
paths = [
|
||||
Path(expandvars(v)).expanduser() for v in value.split(EXTRA_CONFIG_PATH_SEP)
|
||||
@@ -100,7 +114,7 @@ class BasicConfig:
|
||||
invalid = [path for path in paths if not path.is_dir()]
|
||||
if invalid:
|
||||
print(
|
||||
f"WARNING: Invalid paths in {EXTRA_CONFIG_PATH_OVERRIDE_VAR} env var: {' '.join(map(str, invalid))}"
|
||||
f"WARNING: Invalid paths in {self.extra_config_path_override_var} env var: {' '.join(map(str, invalid))}"
|
||||
)
|
||||
|
||||
return [path for path in paths if path.is_dir()]
|
||||
@@ -114,13 +128,40 @@ class BasicConfig:
|
||||
configs = [self._read_recursive(path) for path in self._paths]
|
||||
|
||||
return reduce(
|
||||
lambda last, config: ConfigTree.merge_configs(
|
||||
lambda last, config: self._merge_configs(
|
||||
last, config, copy_trees=True
|
||||
),
|
||||
configs + [extra_config_values],
|
||||
ConfigTree(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _merge_configs(cls, a, b, copy_trees=False, override_prefix="-"):
|
||||
"""Based on pyhocon.ConfigTree.merge_configs, with dict override support using a `-` key prefix"""
|
||||
for key, value in b.items():
|
||||
override = key.startswith(override_prefix)
|
||||
if override:
|
||||
key = key[len(override_prefix):]
|
||||
# if key is in both a and b and both values are dictionary then merge it otherwise override it
|
||||
if not override and key in a and isinstance(a[key], ConfigTree) and isinstance(b[key], ConfigTree):
|
||||
if copy_trees:
|
||||
a[key] = a[key].copy()
|
||||
cls._merge_configs(a[key], b[key], copy_trees=copy_trees)
|
||||
else:
|
||||
if isinstance(value, ConfigValues):
|
||||
value.parent = a
|
||||
value.key = key
|
||||
if key in a:
|
||||
value.overriden_value = a[key]
|
||||
a[key] = value
|
||||
if a.root:
|
||||
if b.root:
|
||||
a.history[key] = a.history.get(key, []) + b.history.get(key, [value])
|
||||
else:
|
||||
a.history[key] = a.history.get(key, []) + [value]
|
||||
|
||||
return a
|
||||
|
||||
def _read_recursive(self, conf_root) -> ConfigTree:
|
||||
conf = ConfigTree()
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
debug: false # Debug mode
|
||||
pretty_json: false # prettify json response
|
||||
return_stack: true # return stack trace on error
|
||||
log_calls: true # Log API Calls
|
||||
return_stack_to_caller: true # top-level control on whether to return stack trace in an API response
|
||||
|
||||
# if 'return_stack' is true and error contains a status code, return stack trace only for these status codes
|
||||
# valid values are:
|
||||
@@ -41,10 +41,6 @@
|
||||
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
|
||||
# but not declared in a data model
|
||||
strict: false
|
||||
|
||||
aggregate {
|
||||
allow_disk_use: true
|
||||
}
|
||||
}
|
||||
|
||||
elastic {
|
||||
@@ -69,7 +65,7 @@
|
||||
default_expiration_sec: 2592000
|
||||
|
||||
# cookie containing auth token, for requests arriving from a web-browser
|
||||
session_auth_cookie_name: "trains_token_basic"
|
||||
session_auth_cookie_name: "clearml_token_basic"
|
||||
|
||||
# cookie configuration for authorization cookies generated by auth.login
|
||||
cookies {
|
||||
@@ -79,9 +75,16 @@
|
||||
max_age: 99999999999
|
||||
}
|
||||
|
||||
# provide a cookie domain override per company
|
||||
# cookies_domain_override {
|
||||
# <company-id>: <domain>
|
||||
# }
|
||||
|
||||
# # A list of fixed users
|
||||
# # Note: password may be bcrypt-hashed (generate using `python -c 'import bcrypt; print(bcrypt.hashpw("password", bcrypt.gensalt()))'`)
|
||||
# fixed_users {
|
||||
# enabled: true
|
||||
# pass_hashed: false
|
||||
# users: [
|
||||
# {
|
||||
# username: "john"
|
||||
@@ -105,9 +108,15 @@
|
||||
workers {
|
||||
# Auto-register unknown workers on status reports and other calls
|
||||
auto_register: true
|
||||
# Assume unknow workers have unregistered (i.e. do not raise unregistered error)
|
||||
auto_unregister: true
|
||||
# Timeout in seconds on task status update. If exceeded
|
||||
# then task can be stopped without communicating to the worker
|
||||
task_update_timeout: 600
|
||||
|
||||
# Timeout in seconds for worker registration (or status report). If a worker did not report for this long,
|
||||
# it is discarded from the server's table
|
||||
default_timeout: 600
|
||||
}
|
||||
|
||||
check_for_updates {
|
||||
@@ -116,9 +125,9 @@
|
||||
# Check for updates every 24 hours
|
||||
check_interval_sec: 86400
|
||||
|
||||
url: "https://updates.trains.allegro.ai/updates"
|
||||
url: "https://updates.clear.ml/updates"
|
||||
|
||||
component_name: "trains-server"
|
||||
component_name: "clearml-server"
|
||||
|
||||
# GET request timeout
|
||||
request_timeout_sec: 3.0
|
||||
@@ -128,7 +137,7 @@
|
||||
# Note: statistics are sent ONLY if the user has actively opted-in
|
||||
supported: true
|
||||
|
||||
url: "https://updates.trains.allegro.ai/stats"
|
||||
url: "https://updates.clear.ml/stats"
|
||||
|
||||
report_interval_hours: 24
|
||||
agent_relevant_threshold_days: 30
|
||||
@@ -137,4 +146,11 @@
|
||||
max_backoff_sec: 5
|
||||
}
|
||||
|
||||
getting_started_info {
|
||||
"agentName": "clearml",
|
||||
"configure": "clearml-init",
|
||||
"install": "pip install clearml",
|
||||
"packageName": "clearml"
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
fileserver = "http://localhost:8081"
|
||||
|
||||
elastic {
|
||||
events {
|
||||
hosts: [{host: "127.0.0.1", port: 9200}]
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
backupCount: 3
|
||||
maxBytes: 10240000,
|
||||
class: "logging.handlers.RotatingFileHandler",
|
||||
filename: "/var/log/trains/apiserver.log"
|
||||
filename: "/var/log/clearml/apiserver.log"
|
||||
}
|
||||
}
|
||||
root {
|
||||
|
||||
@@ -23,11 +23,17 @@
|
||||
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
|
||||
revoke_in_fixed_mode: true
|
||||
}
|
||||
services_agent {
|
||||
role: "admin"
|
||||
user_key: "P4BMJA7RK3TKBXGSY8OAA1FA8TOD11"
|
||||
user_secret: "9LsgSfa0SYz0zli1_c500ZcLqanre2xkWOpepyt1w-BKK3_DKPHrtoj3JSHvyy8bIi0"
|
||||
}
|
||||
tests {
|
||||
role: "user"
|
||||
display_name: "Default User"
|
||||
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
|
||||
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
|
||||
revoke_in_fixed_mode: true
|
||||
}
|
||||
}
|
||||
}
|
||||
9
apiserver/config/default/services/_mongo.conf
Normal file
9
apiserver/config/default/services/_mongo.conf
Normal file
@@ -0,0 +1,9 @@
|
||||
max_page_size: 500
|
||||
|
||||
# expiration time in seconds for the redis scroll states in get_many family of apis
|
||||
scroll_state_expiration_seconds: 600
|
||||
|
||||
allow_disk_use {
|
||||
sort: true
|
||||
aggregate: true
|
||||
}
|
||||
12
apiserver/config/default/services/async_urls_delete.conf
Normal file
12
apiserver/config/default/services/async_urls_delete.conf
Normal file
@@ -0,0 +1,12 @@
|
||||
# if set to true then on task delete/reset external file urls for known storage types are scheduled for async delete
|
||||
# otherwise they are returned to a client for the client side delete
|
||||
enabled: true
|
||||
max_retries: 3
|
||||
retry_timeout_sec: 60
|
||||
|
||||
fileserver {
|
||||
# fileserver url prefixes. Evaluated in the order of priority
|
||||
# Can be in the form <schema>://host:port/path or /path
|
||||
url_prefixes: ["https://files.community-master.hosted.allegro.ai/"]
|
||||
timeout_sec: 300
|
||||
}
|
||||
@@ -12,11 +12,28 @@ events_retrieval {
|
||||
# should not exceed the amount of concurrent connections set in the ES driver
|
||||
max_metrics_concurrency: 4
|
||||
|
||||
# If set then max_metrics_count and max_variants_count are calculated dynamically on user data
|
||||
dynamic_metrics_count: true
|
||||
|
||||
# The percentage from the ES aggs limit (10000) to use for the max_metrics and max_variants calculation
|
||||
dynamic_metrics_count_threshold: 80
|
||||
|
||||
# the max amount of metrics to aggregate on
|
||||
max_metrics_count: 100
|
||||
|
||||
# the max amount of variants to aggregate on
|
||||
max_variants_count: 100
|
||||
|
||||
debug_images {
|
||||
# Allow to return the debug images for the variants with uninitialized valid iterations border
|
||||
allow_uninitialized_variants: true
|
||||
}
|
||||
|
||||
max_raw_scalars_size: 200000
|
||||
|
||||
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
|
||||
|
||||
multi_plots_batch_size: 1000
|
||||
}
|
||||
|
||||
# if set then plot str will be checked for the valid json on plot add
|
||||
@@ -24,4 +41,7 @@ events_retrieval {
|
||||
validate_plot_str: false
|
||||
|
||||
# If not 0 then the plots equal or greater to the size will be stored compressed in the DB
|
||||
plot_compression_threshold: 100000
|
||||
plot_compression_threshold: 100000
|
||||
|
||||
# async events delete threshold
|
||||
max_async_deleted_events_per_sec: 1000
|
||||
4
apiserver/config/default/services/models.conf
Normal file
4
apiserver/config/default/services/models.conf
Normal file
@@ -0,0 +1,4 @@
|
||||
metadata_values {
|
||||
# cache ttl sec
|
||||
cache_ttl_sec: 86400
|
||||
}
|
||||
@@ -1,3 +1,9 @@
|
||||
tags_cache {
|
||||
expiration_seconds: 3600
|
||||
}
|
||||
download {
|
||||
redis_timeout_sec: 300
|
||||
batch_size: 500
|
||||
max_download_items: 50000
|
||||
max_project_name_length: 60
|
||||
}
|
||||
@@ -10,4 +10,9 @@ featured {
|
||||
|
||||
# default featured index for public projects not specified in the order
|
||||
public_default: 9999
|
||||
}
|
||||
|
||||
sub_projects {
|
||||
# the max sub project depth
|
||||
max_depth: 10
|
||||
}
|
||||
8
apiserver/config/default/services/queues.conf
Normal file
8
apiserver/config/default/services/queues.conf
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
metrics_before_from_date: 3600
|
||||
# interval in seconds to update queue metrics. Put 0 to disable
|
||||
metrics_refresh_interval_sec: 300
|
||||
# the queues with these tags will not be returned from get_all/get_all_ex unless id or name specified
|
||||
# or search_hidden is set
|
||||
hidden_tags: [k8s-glue]
|
||||
}
|
||||
53
apiserver/config/default/services/storage_credentials.conf
Normal file
53
apiserver/config/default/services/storage_credentials.conf
Normal file
@@ -0,0 +1,53 @@
|
||||
aws {
|
||||
s3 {
|
||||
# S3 credentials, used for read/write access by various SDK elements
|
||||
# default, used for any bucket not specified below
|
||||
key: ""
|
||||
secret: ""
|
||||
region: ""
|
||||
use_credentials_chain: false
|
||||
# Additional ExtraArgs passed to boto3 when uploading files. Can also be set per-bucket under "credentials".
|
||||
extra_args: {}
|
||||
credentials: [
|
||||
# specifies key/secret credentials to use when handling s3 urls (read or write)
|
||||
# {
|
||||
# bucket: "my-bucket-name"
|
||||
# key: "my-access-key"
|
||||
# secret: "my-secret-key"
|
||||
# },
|
||||
{
|
||||
# This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
|
||||
host: "localhost:9000"
|
||||
key: "evg_user"
|
||||
secret: "evg_pass"
|
||||
multipart: false
|
||||
secure: false
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
google.storage {
|
||||
# Default project and credentials file
|
||||
# Will be used when no bucket configuration is found
|
||||
// project: "clearml"
|
||||
// credentials_json: "/path/to/credentials.json"
|
||||
//
|
||||
// # Specific credentials per bucket and sub directory
|
||||
// credentials = [
|
||||
// {
|
||||
// bucket: "my-bucket"
|
||||
// subdir: "path/in/bucket" # Not required
|
||||
// project: "clearml"
|
||||
// credentials_json: "/path/to/credentials.json"
|
||||
// },
|
||||
// ]
|
||||
}
|
||||
azure.storage {
|
||||
# containers: [
|
||||
# {
|
||||
# account_name: "clearml"
|
||||
# account_key: "secret"
|
||||
# # container_name:
|
||||
# }
|
||||
# ]
|
||||
}
|
||||
@@ -9,3 +9,20 @@ non_responsive_tasks_watchdog {
|
||||
}
|
||||
|
||||
multi_task_histogram_limit: 100
|
||||
|
||||
hyperparam_values {
|
||||
# max allowed outdate time for the cashed result
|
||||
cache_allowed_outdate_sec: 60
|
||||
|
||||
# cache ttl sec
|
||||
cache_ttl_sec: 86400
|
||||
}
|
||||
|
||||
# the maximum amount of unique last metrics/variants combinations
|
||||
# for which the last values are stored in a task
|
||||
max_last_metrics: 2000
|
||||
|
||||
# if set then call to tasks.delete/cleanup does not wait for ES events deletion
|
||||
async_events_delete: true
|
||||
# do not use async_delete if the deleted task has amount of events lower than this threshold
|
||||
async_events_delete_threshold: 100000
|
||||
|
||||
@@ -2,6 +2,8 @@ from functools import lru_cache
|
||||
from os import getenv
|
||||
from pathlib import Path
|
||||
|
||||
from boltons.iterutils import first
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.version import __version__
|
||||
|
||||
@@ -9,7 +11,9 @@ root = Path(__file__).parent.parent
|
||||
|
||||
|
||||
def _get(prop_name, env_suffix=None, default=""):
|
||||
value = getenv(f"TRAINS_SERVER_{env_suffix or prop_name}")
|
||||
suffix = env_suffix or prop_name
|
||||
keys = [f"{p}_SERVER_{suffix}" for p in ("CLEARML", "TRAINS")]
|
||||
value = first(map(getenv, keys))
|
||||
if value:
|
||||
return value
|
||||
|
||||
|
||||
@@ -17,11 +17,21 @@ log = config.logger("database")
|
||||
strict = config.get("apiserver.mongo.strict", True)
|
||||
|
||||
OVERRIDE_HOST_ENV_KEY = (
|
||||
"CLEARML_MONGODB_SERVICE_HOST",
|
||||
"TRAINS_MONGODB_SERVICE_HOST",
|
||||
"MONGODB_SERVICE_HOST",
|
||||
"MONGODB_SERVICE_SERVICE_HOST",
|
||||
)
|
||||
OVERRIDE_PORT_ENV_KEY = ("TRAINS_MONGODB_SERVICE_PORT", "MONGODB_SERVICE_PORT")
|
||||
OVERRIDE_PORT_ENV_KEY = (
|
||||
"CLEARML_MONGODB_SERVICE_PORT",
|
||||
"TRAINS_MONGODB_SERVICE_PORT",
|
||||
"MONGODB_SERVICE_PORT",
|
||||
)
|
||||
|
||||
OVERRIDE_CONNECTION_STRING_ENV_KEY = "CLEARML_MONGODB_SERVICE_CONNECTION_STRING"
|
||||
OVERRIDE_USERNAME_ENV_KEY = "CLEARML_MONGODB_SERVICE_USERNAME"
|
||||
OVERRIDE_PASSWORD_ENV_KEY = "CLEARML_MONGODB_SERVICE_PASSWORD"
|
||||
OVERRIDE_QUERY_ENV_KEY = "CLEARML_MONGODB_SERVICE_QUERY"
|
||||
|
||||
|
||||
class DatabaseEntry(models.Base):
|
||||
@@ -32,45 +42,74 @@ class DatabaseEntry(models.Base):
|
||||
class DatabaseFactory:
|
||||
_entries = []
|
||||
|
||||
@classmethod
|
||||
def _create_db_entry(cls, alias: str, settings: dict) -> DatabaseEntry:
|
||||
return DatabaseEntry(alias=alias, **settings)
|
||||
|
||||
@classmethod
|
||||
def initialize(cls):
|
||||
db_entries = config.get("hosts.mongo", {})
|
||||
missing = []
|
||||
log.info("Initializing database connections")
|
||||
|
||||
override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY)
|
||||
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
|
||||
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
override_username = getenv(OVERRIDE_USERNAME_ENV_KEY)
|
||||
override_password = getenv(OVERRIDE_PASSWORD_ENV_KEY)
|
||||
override_query = getenv(OVERRIDE_QUERY_ENV_KEY)
|
||||
|
||||
if override_connection_string:
|
||||
log.info(f"Using override mongodb connection string template {override_connection_string}")
|
||||
else:
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
if override_username:
|
||||
log.info(f"Using override mongodb username {override_username}")
|
||||
if override_password:
|
||||
log.info(f"Using override mongodb password ******")
|
||||
if override_query:
|
||||
log.info(f"Using override mongodb query {override_query}")
|
||||
|
||||
for key, alias in get_items(Database).items():
|
||||
if key not in db_entries:
|
||||
missing.append(key)
|
||||
continue
|
||||
|
||||
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
|
||||
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
|
||||
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
if override_connection_string:
|
||||
con_str = f"{override_connection_string.rstrip('/')}/{key}"
|
||||
log.info(f"Using override mongodb connection string for {alias}: {con_str}")
|
||||
entry.host = con_str
|
||||
else:
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
if override_username:
|
||||
entry.host = furl(entry.host).set(username=override_username).url
|
||||
if override_password:
|
||||
entry.host = furl(entry.host).set(password=override_password).url
|
||||
if override_query:
|
||||
entry.host = furl(entry.host).set(query=override_query).url
|
||||
|
||||
try:
|
||||
entry.validate()
|
||||
log.info(
|
||||
"Registering connection to %(alias)s (%(host)s)" % entry.to_struct()
|
||||
)
|
||||
register_connection(alias=alias, host=entry.host)
|
||||
register_connection(**entry.to_struct())
|
||||
|
||||
cls._entries.append(entry)
|
||||
except ValidationError as ex:
|
||||
raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0]))
|
||||
if missing:
|
||||
raise ValueError("Missing database configuration for %s" % ", ".join(missing))
|
||||
raise ValueError(
|
||||
"Missing database configuration for %s" % ", ".join(missing)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_entries(cls):
|
||||
@@ -91,7 +130,7 @@ class DatabaseFactory:
|
||||
# reconnection from work so workaround this
|
||||
# get_connection(entry.alias, reconnect=True)
|
||||
disconnect(entry.alias)
|
||||
register_connection(alias=entry.alias, host=entry.host)
|
||||
register_connection(**entry.to_struct())
|
||||
get_connection(entry.alias)
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from mongoengine.errors import (
|
||||
LookUpError,
|
||||
InvalidQueryError,
|
||||
)
|
||||
from pymongo.errors import PyMongoError, NotMasterError
|
||||
from pymongo.errors import PyMongoError, NotPrimaryError
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
|
||||
@@ -166,7 +166,10 @@ class MongoEngineErrorsHandler(object):
|
||||
@classmethod
|
||||
@throws_default_error(errors.server_error.InternalError)
|
||||
def invalid_query_error(cls, e, message, **_):
|
||||
pass
|
||||
if e.args:
|
||||
inner = e.args[0]
|
||||
if isinstance(inner, LookUpError):
|
||||
cls.lookup_error(inner, message)
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -195,7 +198,7 @@ def translate_errors_context(message=None, **kwargs):
|
||||
MongoEngineErrorsHandler.invalid_query_error(e, message, **kwargs)
|
||||
except PyMongoError as e:
|
||||
raise errors.server_error.InternalError(message, err=str(e))
|
||||
except NotMasterError as e:
|
||||
except NotPrimaryError as e:
|
||||
raise errors.server_error.InternalError(message, err=str(e))
|
||||
except MakeGetAllQueryError as e:
|
||||
raise errors.bad_request.ValidationError(e.error, field=e.field)
|
||||
|
||||
@@ -176,6 +176,13 @@ class SafeMapField(MapField, DictValidationMixin):
|
||||
self.error("Empty keys are not allowed in a MapField")
|
||||
|
||||
|
||||
class NullableStringField(StringField):
|
||||
def validate(self, value):
|
||||
if value is None:
|
||||
return
|
||||
super(NullableStringField, self).validate(value)
|
||||
|
||||
|
||||
class SafeDictField(DictField, DictValidationMixin):
|
||||
def validate(self, value):
|
||||
self._safe_validate(value)
|
||||
|
||||
@@ -60,3 +60,4 @@ def validate_id(cls, company, **kwargs):
|
||||
class EntityVisibility(Enum):
|
||||
active = "active"
|
||||
archived = "archived"
|
||||
hidden = "hidden"
|
||||
|
||||
@@ -48,7 +48,9 @@ class Credentials(EmbeddedDocument):
|
||||
meta = {"strict": False}
|
||||
key = StringField(required=True)
|
||||
secret = StringField(required=True)
|
||||
label = StringField()
|
||||
last_used = DateTimeField()
|
||||
last_used_from = StringField()
|
||||
|
||||
|
||||
class User(DbModelMixin, AuthDocument):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
44
apiserver/database/model/metadata.py
Normal file
44
apiserver/database/model/metadata.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Sequence, Type
|
||||
|
||||
from mongoengine import EmbeddedDocument, StringField, Document
|
||||
from pymongo import UpdateOne
|
||||
from pymongo.collection import Collection
|
||||
|
||||
from apiserver.database.model.base import ProperDictMixin
|
||||
|
||||
|
||||
class MetadataItem(EmbeddedDocument, ProperDictMixin):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
|
||||
|
||||
def metadata_add_or_update(cls: Type[Document], _id: str, items: Sequence[dict]) -> int:
|
||||
collection: Collection = cls._get_collection()
|
||||
res = collection.update_one(
|
||||
filter={"_id": _id},
|
||||
update={
|
||||
"$set": {f"metadata.$[elem{idx}]": item for idx, item in enumerate(items)}
|
||||
},
|
||||
array_filters=[
|
||||
{f"elem{idx}.key": item["key"]} for idx, item in enumerate(items)
|
||||
],
|
||||
upsert=False,
|
||||
)
|
||||
if len(items) == 1 and res.modified_count == 1:
|
||||
return res.modified_count
|
||||
|
||||
requests = [
|
||||
UpdateOne(
|
||||
filter={"_id": _id, "metadata.key": {"$ne": item["key"]}},
|
||||
update={"$push": {"metadata": item}},
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
res = collection.bulk_write(requests)
|
||||
|
||||
return 1 if res.modified_count else 0
|
||||
|
||||
|
||||
def metadata_delete(cls: Type[Document], _id: str, keys: Sequence[str]) -> int:
|
||||
return cls.objects(id=_id).update_one(pull__metadata__key__in=keys)
|
||||
@@ -1,17 +1,34 @@
|
||||
from mongoengine import Document, StringField, DateTimeField, BooleanField
|
||||
from mongoengine import (
|
||||
StringField,
|
||||
DateTimeField,
|
||||
BooleanField,
|
||||
EmbeddedDocumentField,
|
||||
IntField,
|
||||
ListField,
|
||||
)
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField, SafeDictField, SafeSortedListField
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.fields import (
|
||||
StrippedStringField,
|
||||
SafeDictField,
|
||||
SafeSortedListField,
|
||||
SafeMapField,
|
||||
)
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.metadata import MetadataItem
|
||||
from apiserver.database.model.model_labels import ModelLabels
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.metrics import MetricEvent
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.database.model.user import User
|
||||
|
||||
|
||||
class Model(DbModelMixin, Document):
|
||||
class Model(AttributedDocument):
|
||||
_field_collation_overrides = {
|
||||
"metadata.": AttributedDocument._numeric_locale,
|
||||
"last_metrics.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
@@ -19,9 +36,11 @@ class Model(DbModelMixin, Document):
|
||||
"parent",
|
||||
"project",
|
||||
"task",
|
||||
"last_update",
|
||||
("company", "framework"),
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
("company", "uri"),
|
||||
{
|
||||
"name": "%s.model.main_text_index" % Database.backend,
|
||||
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
|
||||
@@ -50,14 +69,15 @@ class Model(DbModelMixin, Document):
|
||||
"project",
|
||||
"task",
|
||||
"parent",
|
||||
"metadata.*",
|
||||
),
|
||||
range_fields=("last_metrics.*", "last_iteration"),
|
||||
datetime_fields=("last_update",),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(user_set_allowed=True, min_length=3)
|
||||
parent = StringField(reference_field="Model", required=False)
|
||||
user = StringField(required=True, reference_field=User)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
task = StringField(reference_field=Task)
|
||||
@@ -69,7 +89,19 @@ class Model(DbModelMixin, Document):
|
||||
design = SafeDictField()
|
||||
labels = ModelLabels()
|
||||
ready = BooleanField(required=True)
|
||||
last_update = DateTimeField()
|
||||
last_change = DateTimeField()
|
||||
last_changed_by = StringField()
|
||||
ui_cache = SafeDictField(
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
metadata = SafeMapField(
|
||||
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||
)
|
||||
last_iteration = IntField(default=0)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
|
||||
|
||||
def get_index_company(self) -> str:
|
||||
return self.company or self.company_origin or ""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from mongoengine import StringField, DateTimeField, IntField
|
||||
from mongoengine import StringField, DateTimeField, IntField, ListField
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField, SafeSortedListField
|
||||
@@ -9,15 +9,19 @@ from apiserver.database.model.base import GetMixin
|
||||
class Project(AttributedDocument):
|
||||
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "description"),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
pattern_fields=("name", "basename", "description"),
|
||||
list_fields=("tags", "system_tags", "id", "parent", "path"),
|
||||
range_fields=("last_update",),
|
||||
)
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
"parent",
|
||||
"path",
|
||||
("company", "name"),
|
||||
("company", "basename"),
|
||||
{
|
||||
"name": "%s.project.main_text_index" % Database.backend,
|
||||
"fields": ["$name", "$id", "$description"],
|
||||
@@ -34,7 +38,8 @@ class Project(AttributedDocument):
|
||||
min_length=3,
|
||||
sparse=True,
|
||||
)
|
||||
description = StringField(required=True)
|
||||
basename = StrippedStringField(required=True)
|
||||
description = StringField()
|
||||
created = DateTimeField(required=True)
|
||||
tags = SafeSortedListField(StringField(required=True))
|
||||
system_tags = SafeSortedListField(StringField(required=True))
|
||||
@@ -44,3 +49,5 @@ class Project(AttributedDocument):
|
||||
logo_url = StringField()
|
||||
logo_blob = StringField(exclude_by_default=True)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
parent = StringField(reference_field="Project")
|
||||
path = ListField(StringField(required=True), exclude_by_default=True)
|
||||
|
||||
@@ -4,34 +4,43 @@ from mongoengine import (
|
||||
StringField,
|
||||
DateTimeField,
|
||||
EmbeddedDocumentListField,
|
||||
EmbeddedDocumentField,
|
||||
)
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField, SafeSortedListField
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.fields import (
|
||||
StrippedStringField,
|
||||
SafeSortedListField,
|
||||
SafeMapField,
|
||||
)
|
||||
from apiserver.database.model import DbModelMixin, AttributedDocument
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.metadata import MetadataItem
|
||||
from apiserver.database.model.task.task import Task
|
||||
|
||||
|
||||
class Entry(EmbeddedDocument, ProperDictMixin):
|
||||
""" Entry representing a task waiting in the queue """
|
||||
|
||||
task = StringField(required=True, reference_field=Task)
|
||||
''' Task ID '''
|
||||
""" Task ID """
|
||||
added = DateTimeField(required=True)
|
||||
''' Added to the queue '''
|
||||
""" Added to the queue """
|
||||
|
||||
|
||||
class Queue(DbModelMixin, Document):
|
||||
_field_collation_overrides = {
|
||||
"metadata.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name",),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
pattern_fields=("name",), list_fields=("tags", "system_tags", "id", "metadata.*"),
|
||||
)
|
||||
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
@@ -40,7 +49,12 @@ class Queue(DbModelMixin, Document):
|
||||
)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
created = DateTimeField(required=True)
|
||||
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
tags = SafeSortedListField(
|
||||
StringField(required=True), default=list, user_set_allowed=True
|
||||
)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||
last_update = DateTimeField()
|
||||
metadata = SafeMapField(
|
||||
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from mongoengine import (
|
||||
DynamicField,
|
||||
LongField,
|
||||
EmbeddedDocumentField,
|
||||
IntField,
|
||||
)
|
||||
|
||||
from apiserver.database.fields import SafeMapField
|
||||
@@ -19,7 +20,9 @@ class MetricEvent(EmbeddedDocument):
|
||||
variant = StringField(required=True)
|
||||
value = DynamicField(required=True)
|
||||
min_value = DynamicField() # for backwards compatibility reasons
|
||||
min_value_iteration = IntField()
|
||||
max_value = DynamicField() # for backwards compatibility reasons
|
||||
max_value_iteration = IntField()
|
||||
|
||||
|
||||
class EventStats(EmbeddedDocument):
|
||||
|
||||
@@ -11,6 +11,5 @@ class Result(object):
|
||||
|
||||
class Output(EmbeddedDocument):
|
||||
destination = StrippedStringField()
|
||||
model = StringField(reference_field='Model')
|
||||
error = StringField(user_set_allowed=True)
|
||||
result = StringField(choices=get_options(Result))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict
|
||||
from typing import Dict, Sequence
|
||||
|
||||
from mongoengine import (
|
||||
StringField,
|
||||
@@ -17,6 +17,9 @@ from apiserver.database.fields import (
|
||||
SafeDictField,
|
||||
UnionField,
|
||||
SafeSortedListField,
|
||||
EmbeddedDocumentListField,
|
||||
NullableStringField,
|
||||
NoneType,
|
||||
)
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
@@ -79,13 +82,17 @@ DEFAULT_ARTIFACT_MODE = ArtifactModes.output
|
||||
class Artifact(EmbeddedDocument):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
mode = StringField(choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE)
|
||||
mode = StringField(
|
||||
choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE
|
||||
)
|
||||
uri = StringField()
|
||||
hash = StringField()
|
||||
content_size = LongField()
|
||||
timestamp = LongField()
|
||||
type_data = EmbeddedDocumentField(ArtifactTypeData)
|
||||
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
|
||||
display_data = SafeSortedListField(
|
||||
ListField(UnionField((int, float, str, NoneType)))
|
||||
)
|
||||
|
||||
|
||||
class ParamsItem(EmbeddedDocument, ProperDictMixin):
|
||||
@@ -103,17 +110,37 @@ class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
|
||||
description = StringField()
|
||||
|
||||
|
||||
class TaskModelTypes:
|
||||
input = "input"
|
||||
output = "output"
|
||||
|
||||
|
||||
TaskModelNames = {
|
||||
TaskModelTypes.input: "Input Model",
|
||||
TaskModelTypes.output: "Output Model",
|
||||
}
|
||||
|
||||
|
||||
class ModelItem(EmbeddedDocument, ProperDictMixin):
|
||||
name = StringField(required=True)
|
||||
model = StringField(required=True, reference_field="Model")
|
||||
updated = DateTimeField()
|
||||
|
||||
|
||||
class Models(EmbeddedDocument, ProperDictMixin):
|
||||
input: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
|
||||
output: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
|
||||
|
||||
|
||||
class Execution(EmbeddedDocument, ProperDictMixin):
|
||||
meta = {"strict": strict}
|
||||
test_split = IntField(default=0)
|
||||
parameters = SafeDictField(default=dict)
|
||||
model = StringField(reference_field="Model")
|
||||
model_desc = SafeMapField(StringField(default=""))
|
||||
model_labels = ModelLabels()
|
||||
framework = StringField()
|
||||
artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact))
|
||||
docker_cmd = StringField()
|
||||
queue = StringField()
|
||||
queue = StringField(reference_field="Queue")
|
||||
""" Queue ID where task was queued """
|
||||
|
||||
|
||||
@@ -125,6 +152,7 @@ class TaskType(object):
|
||||
application = "application"
|
||||
monitor = "monitor"
|
||||
controller = "controller"
|
||||
report = "report"
|
||||
optimizer = "optimizer"
|
||||
service = "service"
|
||||
qc = "qc"
|
||||
@@ -135,12 +163,10 @@ external_task_types = set(get_options(TaskType))
|
||||
|
||||
|
||||
class Task(AttributedDocument):
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {
|
||||
"execution.parameters.": _numeric_locale,
|
||||
"last_metrics.": _numeric_locale,
|
||||
"hyperparams.": _numeric_locale,
|
||||
"configuration.": _numeric_locale,
|
||||
"execution.parameters.": AttributedDocument._numeric_locale,
|
||||
"last_metrics.": AttributedDocument._numeric_locale,
|
||||
"hyperparams.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
@@ -153,6 +179,9 @@ class Task(AttributedDocument):
|
||||
"active_duration",
|
||||
"parent",
|
||||
"project",
|
||||
"last_update",
|
||||
"status_changed",
|
||||
"models.input.model",
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
("company", "status", "type"),
|
||||
@@ -160,14 +189,19 @@ class Task(AttributedDocument):
|
||||
("company", "type", "system_tags", "status"),
|
||||
("company", "project", "type", "system_tags", "status"),
|
||||
("status", "last_update"), # for maintenance tasks
|
||||
{
|
||||
"fields": ["company", "project"],
|
||||
"collation": AttributedDocument._numeric_locale,
|
||||
},
|
||||
{
|
||||
"name": "%s.task.main_text_index" % Database.backend,
|
||||
"fields": [
|
||||
"$name",
|
||||
"$id",
|
||||
"$comment",
|
||||
"$execution.model",
|
||||
"$output.model",
|
||||
"$report",
|
||||
"$models.input.model",
|
||||
"$models.output.model",
|
||||
"$script.repository",
|
||||
"$script.entry_point",
|
||||
],
|
||||
@@ -176,8 +210,9 @@ class Task(AttributedDocument):
|
||||
"name": 10,
|
||||
"id": 10,
|
||||
"comment": 10,
|
||||
"execution.model": 2,
|
||||
"output.model": 2,
|
||||
"report": 10,
|
||||
"models.output.model": 2,
|
||||
"models.input.model": 2,
|
||||
"script.repository": 1,
|
||||
"script.entry_point": 1,
|
||||
},
|
||||
@@ -185,9 +220,22 @@ class Task(AttributedDocument):
|
||||
],
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project", "parent"),
|
||||
datetime_fields=("status_changed",),
|
||||
pattern_fields=("name", "comment"),
|
||||
list_fields=(
|
||||
"id",
|
||||
"user",
|
||||
"tags",
|
||||
"system_tags",
|
||||
"type",
|
||||
"status",
|
||||
"project",
|
||||
"parent",
|
||||
"hyperparams.*",
|
||||
"execution.queue",
|
||||
),
|
||||
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
|
||||
datetime_fields=("status_changed", "last_update"),
|
||||
pattern_fields=("name", "comment", "report"),
|
||||
fields=("runtime.*", "models.input.model"),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
@@ -198,9 +246,11 @@ class Task(AttributedDocument):
|
||||
type = StringField(required=True, choices=get_options(TaskType))
|
||||
status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
|
||||
status_reason = StringField()
|
||||
status_message = StringField()
|
||||
status_message = StringField(user_set_allowed=True)
|
||||
status_changed = DateTimeField()
|
||||
comment = StringField(user_set_allowed=True)
|
||||
report = StringField()
|
||||
report_assets = ListField(StringField())
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
started = DateTimeField()
|
||||
completed = DateTimeField()
|
||||
@@ -219,13 +269,19 @@ class Task(AttributedDocument):
|
||||
last_change = DateTimeField()
|
||||
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
|
||||
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
duration = IntField() # task duration in seconds
|
||||
duration = IntField() # obsolete, do not use
|
||||
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
|
||||
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
|
||||
runtime = SafeDictField(default=dict)
|
||||
docker_init_script = StringField()
|
||||
models: Models = EmbeddedDocumentField(Models, default=Models)
|
||||
container = SafeMapField(field=NullableStringField())
|
||||
enqueue_status = StringField(
|
||||
choices=get_options(TaskStatus), exclude_by_default=True
|
||||
)
|
||||
last_changed_by = StringField()
|
||||
|
||||
def get_index_company(self) -> str:
|
||||
"""
|
||||
|
||||
52
apiserver/database/model/url_to_delete.py
Normal file
52
apiserver/database/model/url_to_delete.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from enum import Enum
|
||||
|
||||
from mongoengine import StringField, DateTimeField, IntField, EnumField
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.model import AttributedDocument
|
||||
|
||||
|
||||
class StorageType(str, Enum):
|
||||
fileserver = "fileserver"
|
||||
s3 = "s3"
|
||||
azure = "azure"
|
||||
gs = "gs"
|
||||
unknown = "unknown"
|
||||
|
||||
|
||||
class FileType(str, Enum):
|
||||
file = "file"
|
||||
folder = "folder"
|
||||
|
||||
|
||||
class DeletionStatus(str, Enum):
|
||||
created = "created"
|
||||
retrying = "retrying"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class UrlToDelete(AttributedDocument):
|
||||
_field_collation_overrides = {
|
||||
"url": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
("company", "user", "task"),
|
||||
("company", "storage_type", "url"),
|
||||
("status", "retry_count", "storage_type"),
|
||||
],
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
url = StringField(required=True, unique_with="company")
|
||||
task = StringField(required=True)
|
||||
created = DateTimeField(required=True)
|
||||
storage_type = EnumField(StorageType, default=StorageType.unknown)
|
||||
type = EnumField(FileType, default=FileType.file)
|
||||
retry_count = IntField(default=0)
|
||||
last_failure_time = DateTimeField()
|
||||
last_failure_reason = StringField()
|
||||
status = EnumField(DeletionStatus, default=DeletionStatus.created)
|
||||
@@ -1,4 +1,4 @@
|
||||
from mongoengine import Document, StringField, DynamicField
|
||||
from mongoengine import Document, StringField, DynamicField, DateTimeField
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.model import DbModelMixin
|
||||
@@ -20,3 +20,4 @@ class User(DbModelMixin, Document):
|
||||
given_name = StringField(user_set_allowed=True)
|
||||
avatar = StringField()
|
||||
preferences = DynamicField(default="", exclude_by_default=True)
|
||||
created = DateTimeField()
|
||||
|
||||
@@ -1,73 +1,16 @@
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import groupby, chain
|
||||
from typing import Sequence, Dict, Callable, Tuple, Any, Type
|
||||
from typing import Sequence, Dict, Callable
|
||||
|
||||
import dpath.path
|
||||
from boltons import iterutils
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.props import PropsMixin
|
||||
|
||||
SEP = "."
|
||||
|
||||
|
||||
def project_dict(data, projection, separator=SEP):
|
||||
"""
|
||||
Project partial data from a dictionary into a new dictionary
|
||||
:param data: Input dictionary
|
||||
:param projection: List of dictionary paths (each a string with field names separated using a separator)
|
||||
:param separator: Separator (default is '.')
|
||||
:return: A new dictionary containing only the projected parts from the original dictionary
|
||||
"""
|
||||
assert isinstance(data, dict)
|
||||
result = {}
|
||||
|
||||
def copy_path(path_parts, source, destination):
|
||||
src, dst = source, destination
|
||||
try:
|
||||
for depth, path_part in enumerate(path_parts[:-1]):
|
||||
src_part = src[path_part]
|
||||
if isinstance(src_part, dict):
|
||||
src = src_part
|
||||
dst = dst.setdefault(path_part, {})
|
||||
elif isinstance(src_part, (list, tuple)):
|
||||
if path_part not in dst:
|
||||
dst[path_part] = [{} for _ in range(len(src_part))]
|
||||
elif not isinstance(dst[path_part], (list, tuple)):
|
||||
raise TypeError(
|
||||
"Incompatible destination type %s for %s (list expected)"
|
||||
% (type(dst), separator.join(path_parts[: depth + 1]))
|
||||
)
|
||||
elif not len(dst[path_part]) == len(src_part):
|
||||
raise ValueError(
|
||||
"Destination list length differs from source length for %s"
|
||||
% separator.join(path_parts[: depth + 1])
|
||||
)
|
||||
|
||||
dst[path_part] = [
|
||||
copy_path(path_parts[depth + 1 :], s, d)
|
||||
for s, d in zip(src_part, dst[path_part])
|
||||
]
|
||||
|
||||
return destination
|
||||
else:
|
||||
raise TypeError(
|
||||
"Unsupported projection type %s for %s"
|
||||
% (type(src), separator.join(path_parts[: depth + 1]))
|
||||
)
|
||||
|
||||
last_part = path_parts[-1]
|
||||
dst[last_part] = src[last_part]
|
||||
except KeyError:
|
||||
# Projection field not in source, no biggie.
|
||||
pass
|
||||
return destination
|
||||
|
||||
for projection_path in sorted(projection):
|
||||
copy_path(
|
||||
path_parts=projection_path.split(separator), source=data, destination=result
|
||||
)
|
||||
return result
|
||||
max_items_per_fetch = config.get("services._mongo.max_page_size", 500)
|
||||
|
||||
|
||||
class _ReferenceProxy(dict):
|
||||
@@ -110,9 +53,6 @@ class ProjectionHelper(object):
|
||||
self._ref_projection = None
|
||||
self._proxy_manager = _ProxyManager()
|
||||
|
||||
# Cached dpath paths for each of the result documents
|
||||
self._cached_results_paths: Dict[int, Sequence[Tuple[Any, Type]]] = {}
|
||||
|
||||
self._parse_projection(projection)
|
||||
|
||||
def _collect_projection_fields(self, doc_cls, projection):
|
||||
@@ -275,25 +215,26 @@ class ProjectionHelper(object):
|
||||
norm_path = doc_cls.get_dpath_translated_path(path)
|
||||
globlist = norm_path.strip(SEP).split(SEP)
|
||||
|
||||
obj_paths = self._cached_results_paths.get(id(obj))
|
||||
if obj_paths is None:
|
||||
obj_paths = self._cached_results_paths[id(obj)] = list(
|
||||
dpath.path.paths(obj, dirs=True, skip=True)
|
||||
)
|
||||
|
||||
paths = [p for p in obj_paths if dpath.path.match(p, globlist)]
|
||||
|
||||
def search_and_replace(p: Sequence[Tuple[str, Type]]) -> Any:
|
||||
def _search_and_replace(target: dict, p: Sequence[str]) -> Sequence[str]:
|
||||
parent = None
|
||||
target = obj
|
||||
for part in p:
|
||||
parent = target
|
||||
target = target[part[0]]
|
||||
if parent and factory:
|
||||
parent[p[-1][0]] = factory(target)
|
||||
return target
|
||||
for idx, part in enumerate(p):
|
||||
if isinstance(target, dict) and part in target:
|
||||
parent = target
|
||||
target = target[part]
|
||||
elif isinstance(target, list) and part == "*":
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
_search_and_replace(t, p[idx + 1 :]) for t in target
|
||||
)
|
||||
)
|
||||
else:
|
||||
return []
|
||||
|
||||
return [search_and_replace(p) for p in paths]
|
||||
if parent and factory:
|
||||
parent[p[-1]] = factory(target)
|
||||
return [target]
|
||||
|
||||
return _search_and_replace(obj, globlist)
|
||||
|
||||
def project(self, results, projection_func):
|
||||
"""
|
||||
@@ -341,10 +282,11 @@ class ProjectionHelper(object):
|
||||
doc_only = list(filter(None, data["only"]))
|
||||
doc_only = list({"id"} | set(doc_only)) if doc_only else None
|
||||
|
||||
for res in projection_func(
|
||||
doc_type=doc_type, projection=doc_only, ids=ids
|
||||
):
|
||||
self._proxy_manager.update(res)
|
||||
for ids_chunk in iterutils.chunked_iter(ids, max_items_per_fetch):
|
||||
for res in projection_func(
|
||||
doc_type=doc_type, projection=doc_only, ids=ids_chunk
|
||||
):
|
||||
self._proxy_manager.update(res)
|
||||
|
||||
if len(ref_projection) == 1:
|
||||
do_projection(items[0])
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
from collections import OrderedDict, defaultdict
|
||||
from itertools import chain
|
||||
from collections import OrderedDict
|
||||
from operator import attrgetter
|
||||
from threading import Lock
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
|
||||
from mongoengine.base import get_document, BaseField
|
||||
from mongoengine import (
|
||||
EmbeddedDocumentField,
|
||||
EmbeddedDocumentListField,
|
||||
EmbeddedDocument,
|
||||
Document,
|
||||
)
|
||||
from mongoengine.base import get_document
|
||||
|
||||
from apiserver.database.fields import (
|
||||
LengthRangeEmbeddedDocumentListField,
|
||||
@@ -21,11 +25,18 @@ class PropsMixin(object):
|
||||
__cached_reference_fields = None
|
||||
__cached_exclude_fields = None
|
||||
__cached_fields_with_instance = None
|
||||
__cached_field_names_per_type = None
|
||||
__cached_all_fields_with_instance = None
|
||||
|
||||
__cached_dpath_computed_fields_lock = Lock()
|
||||
__cached_dpath_computed_fields = None
|
||||
|
||||
_document_classes = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if issubclass(cls, (Document, EmbeddedDocument)):
|
||||
PropsMixin._document_classes[cls._class_name] = cls
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
if cls.__cached_fields is None:
|
||||
@@ -33,37 +44,12 @@ class PropsMixin(object):
|
||||
return cls.__cached_fields
|
||||
|
||||
@classmethod
|
||||
def get_field_names_for_type(cls, of_type=BaseField):
|
||||
"""
|
||||
Return field names per type including subfields
|
||||
The fields of derived types are also returned
|
||||
"""
|
||||
assert issubclass(of_type, BaseField)
|
||||
if cls.__cached_field_names_per_type is None:
|
||||
fields = defaultdict(list)
|
||||
for name, field in get_fields(cls, return_instance=True, subfields=True):
|
||||
fields[type(field)].append(name)
|
||||
for type_ in fields:
|
||||
fields[type_].extend(
|
||||
chain.from_iterable(
|
||||
fields[other_type]
|
||||
for other_type in fields
|
||||
if other_type != type_ and issubclass(other_type, type_)
|
||||
)
|
||||
)
|
||||
cls.__cached_field_names_per_type = fields
|
||||
|
||||
if of_type not in cls.__cached_field_names_per_type:
|
||||
names = list(
|
||||
chain.from_iterable(
|
||||
field_names
|
||||
for type_, field_names in cls.__cached_field_names_per_type.items()
|
||||
if issubclass(type_, of_type)
|
||||
)
|
||||
def get_all_fields_with_instance(cls):
|
||||
if cls.__cached_all_fields_with_instance is None:
|
||||
cls.__cached_all_fields_with_instance = get_fields(
|
||||
cls, return_instance=True, subfields=True
|
||||
)
|
||||
cls.__cached_field_names_per_type[of_type] = names
|
||||
|
||||
return cls.__cached_field_names_per_type[of_type]
|
||||
return cls.__cached_all_fields_with_instance
|
||||
|
||||
@classmethod
|
||||
def get_fields_with_instance(cls, doc_cls):
|
||||
@@ -83,8 +69,14 @@ class PropsMixin(object):
|
||||
def resolve_doc(v):
|
||||
if not isinstance(v, six.string_types):
|
||||
return v
|
||||
if v == 'self':
|
||||
|
||||
if v == "self":
|
||||
return cls_.owner_document
|
||||
|
||||
doc_cls = PropsMixin._document_classes.get(v)
|
||||
if doc_cls:
|
||||
return doc_cls
|
||||
|
||||
return get_document(v)
|
||||
|
||||
fields = {k: resolve_doc(v) for k, v in res.items()}
|
||||
@@ -98,7 +90,7 @@ class PropsMixin(object):
|
||||
).document_type
|
||||
fields.update(
|
||||
{
|
||||
'.'.join((field, subfield)): doc
|
||||
".".join((field, subfield)): doc
|
||||
for subfield, doc in PropsMixin._get_fields_with_attr(
|
||||
embedded_doc_cls, attr
|
||||
).items()
|
||||
@@ -106,10 +98,10 @@ class PropsMixin(object):
|
||||
)
|
||||
|
||||
collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
|
||||
collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))
|
||||
collect_embedded_docs(EmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter("field"))
|
||||
|
||||
return fields
|
||||
|
||||
@@ -120,7 +112,7 @@ class PropsMixin(object):
|
||||
for depth, part in enumerate(parts):
|
||||
if current_cls is None:
|
||||
raise ValueError(
|
||||
'Invalid path (non-document encountered at %s)' % parts[: depth - 1]
|
||||
"Invalid path (non-document encountered at %s)" % parts[: depth - 1]
|
||||
)
|
||||
try:
|
||||
field_name, field = next(
|
||||
@@ -129,7 +121,7 @@ class PropsMixin(object):
|
||||
if k == part
|
||||
)
|
||||
except StopIteration:
|
||||
raise ValueError('Invalid field path %s' % parts[:depth])
|
||||
raise ValueError("Invalid field path %s" % parts[:depth])
|
||||
|
||||
translated_parts.append(part)
|
||||
|
||||
@@ -145,7 +137,7 @@ class PropsMixin(object):
|
||||
),
|
||||
):
|
||||
current_cls = field.field.document_type
|
||||
translated_parts.append('*')
|
||||
translated_parts.append("*")
|
||||
else:
|
||||
current_cls = None
|
||||
|
||||
@@ -154,7 +146,7 @@ class PropsMixin(object):
|
||||
@classmethod
|
||||
def get_reference_fields(cls):
|
||||
if cls.__cached_reference_fields is None:
|
||||
fields = cls._get_fields_with_attr(cls, 'reference_field')
|
||||
fields = cls._get_fields_with_attr(cls, "reference_field")
|
||||
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
|
||||
return cls.__cached_reference_fields
|
||||
|
||||
@@ -169,12 +161,12 @@ class PropsMixin(object):
|
||||
@classmethod
|
||||
def get_exclude_fields(cls):
|
||||
if cls.__cached_exclude_fields is None:
|
||||
fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
|
||||
fields = cls._get_fields_with_attr(cls, "exclude_by_default")
|
||||
cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
|
||||
return cls.__cached_exclude_fields
|
||||
|
||||
@classmethod
|
||||
def get_dpath_translated_path(cls, path, separator='.'):
|
||||
def get_dpath_translated_path(cls, path, separator="."):
|
||||
if cls.__cached_dpath_computed_fields is None:
|
||||
cls.__cached_dpath_computed_fields = {}
|
||||
if path not in cls.__cached_dpath_computed_fields:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import hashlib
|
||||
from inspect import ismethod, getmembers
|
||||
from typing import Sequence, Tuple, Set, Optional, Callable, Any
|
||||
from typing import Sequence, Tuple, Set, Optional, Callable, Any, Mapping
|
||||
from uuid import uuid4
|
||||
|
||||
from mongoengine import EmbeddedDocumentField, ListField, Document, Q
|
||||
@@ -203,18 +203,22 @@ def _names_set(*names: str) -> Set[str]:
|
||||
return set(names) | set(f"-{name}" for name in names)
|
||||
|
||||
|
||||
system_tag_names = {
|
||||
_system_tag_names = {
|
||||
"model": _names_set("active", "archived"),
|
||||
"project": _names_set("archived", "public", "default"),
|
||||
"task": _names_set("active", "archived", "development"),
|
||||
"queue": _names_set("default"),
|
||||
}
|
||||
|
||||
system_tag_prefixes = {"task": _names_set("annotat")}
|
||||
_system_tag_prefixes = {"task": _names_set("annotat")}
|
||||
|
||||
|
||||
def partition_tags(
|
||||
entity: str, tags: Sequence[str], system_tags: Optional[Sequence[str]] = ()
|
||||
entity: str,
|
||||
tags: Sequence[str],
|
||||
system_tags: Optional[Sequence[str]] = (),
|
||||
system_tag_names: Mapping = _system_tag_names,
|
||||
system_tag_prefixes: Mapping = _system_tag_prefixes,
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
"""
|
||||
Partition the given tags sequence into system and user-defined tags
|
||||
|
||||
19
apiserver/documentation/api_versions.md
Normal file
19
apiserver/documentation/api_versions.md
Normal file
@@ -0,0 +1,19 @@
|
||||
### Supported api versions
|
||||
|
||||
| Release | ApiVersion |
|
||||
|---------|------------|
|
||||
| v1.13 | 2.27 |
|
||||
| v1.12 | 2.26 |
|
||||
| v1.11 | 2.25 |
|
||||
| v1.10 | 2.24 |
|
||||
| v1.9 | 2.23 |
|
||||
| v1.8 | 2.22 |
|
||||
| v1.7 | 2.21 |
|
||||
| v1.6 | 2.20 |
|
||||
| v1.5 | 2.19 |
|
||||
| v1.4 | 2.18 |
|
||||
| v1.3 | 2.17 |
|
||||
| v1.2 | 2.16 |
|
||||
| v1.1 | 2.15 |
|
||||
| v1.0 | 2.14 |
|
||||
| v0.17 | 2.13 |
|
||||
@@ -5,7 +5,7 @@ Apply elasticsearch mappings to given hosts.
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
@@ -13,7 +13,7 @@ HERE = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def apply_mappings_to_cluster(
|
||||
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
|
||||
hosts: Sequence, key: Optional[str] = None, es_args: dict = None, http_auth: Tuple = None
|
||||
):
|
||||
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
|
||||
|
||||
@@ -21,7 +21,7 @@ def apply_mappings_to_cluster(
|
||||
with f.open() as json_data:
|
||||
data = json.load(json_data)
|
||||
template_name = f.stem
|
||||
res = es.indices.put_template(template_name, body=data)
|
||||
res = es.indices.put_template(name=template_name, body=data)
|
||||
return {"mapping": template_name, "result": res}
|
||||
|
||||
p = HERE / "mappings"
|
||||
@@ -30,7 +30,7 @@ def apply_mappings_to_cluster(
|
||||
else:
|
||||
files = p.glob("**/*.json")
|
||||
|
||||
es = Elasticsearch(hosts=hosts, **(es_args or {}))
|
||||
es = Elasticsearch(hosts=hosts, http_auth=http_auth, **(es_args or {}))
|
||||
return [_send_template(f) for f in files]
|
||||
|
||||
|
||||
|
||||
@@ -82,7 +82,11 @@ def check_elastic_empty() -> bool:
|
||||
es_logger.addFilter(log_filter)
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
|
||||
es = Elasticsearch(
|
||||
hosts=cluster_conf.get("hosts", None),
|
||||
http_auth=es_factory.get_credentials("events", cluster_conf),
|
||||
**cluster_conf.get("args", {}),
|
||||
)
|
||||
return not es.indices.get_template(name="events*")
|
||||
except exceptions.NotFoundError as ex:
|
||||
log.error(ex)
|
||||
@@ -109,5 +113,9 @@ def init_es_data():
|
||||
|
||||
log.info(f"Applying mappings to ES host: {hosts_config}")
|
||||
args = cluster_conf.get("args", {})
|
||||
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
|
||||
http_auth = es_factory.get_credentials(name)
|
||||
|
||||
res = apply_mappings_to_cluster(
|
||||
hosts_config, name, es_args=args, http_auth=http_auth
|
||||
)
|
||||
log.info(res)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"index_patterns": "events-*",
|
||||
"settings": {
|
||||
"number_of_replicas": 0,
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
@@ -34,6 +35,12 @@
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
},
|
||||
"company_id": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"model_event": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"index_patterns": "queue_metrics_*",
|
||||
"settings": {
|
||||
"number_of_replicas": 0,
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
@@ -19,6 +20,9 @@
|
||||
},
|
||||
"queue_length": {
|
||||
"type": "integer"
|
||||
},
|
||||
"company_id": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"index_patterns": "worker_stats_*",
|
||||
"settings": {
|
||||
"number_of_replicas": 0,
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
@@ -31,6 +32,9 @@
|
||||
},
|
||||
"task": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"company_id": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,19 +1,30 @@
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from os import getenv
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch, Transport
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.config_repo import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
OVERRIDE_HOST_ENV_KEY = (
|
||||
"CLEARML_ELASTIC_SERVICE_HOST",
|
||||
"TRAINS_ELASTIC_SERVICE_HOST",
|
||||
"ELASTIC_SERVICE_HOST",
|
||||
"ELASTIC_SERVICE_SERVICE_HOST",
|
||||
)
|
||||
OVERRIDE_PORT_ENV_KEY = ("TRAINS_ELASTIC_SERVICE_PORT", "ELASTIC_SERVICE_PORT")
|
||||
OVERRIDE_PORT_ENV_KEY = (
|
||||
"CLEARML_ELASTIC_SERVICE_PORT",
|
||||
"TRAINS_ELASTIC_SERVICE_PORT",
|
||||
"ELASTIC_SERVICE_PORT",
|
||||
)
|
||||
|
||||
OVERRIDE_USERNAME_ENV_KEY = ("CLEARML_ELASTIC_SERVICE_USERNAME",)
|
||||
|
||||
OVERRIDE_PASSWORD_ENV_KEY = ("CLEARML_ELASTIC_SERVICE_PASSWORD",)
|
||||
|
||||
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
|
||||
if OVERRIDE_HOST:
|
||||
@@ -23,6 +34,14 @@ OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
|
||||
if OVERRIDE_PORT:
|
||||
log.info(f"Using override elastic port {OVERRIDE_PORT}")
|
||||
|
||||
OVERRIDE_USERNAME = first(filter(None, map(getenv, OVERRIDE_USERNAME_ENV_KEY)))
|
||||
if OVERRIDE_USERNAME:
|
||||
log.info(f"Using override elastic username {OVERRIDE_USERNAME}")
|
||||
|
||||
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
|
||||
if OVERRIDE_PASSWORD:
|
||||
log.info("Using override elastic password ********")
|
||||
|
||||
_instances = {}
|
||||
|
||||
|
||||
@@ -42,9 +61,13 @@ class InvalidClusterConfiguration(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MissingPasswordForElasticUser(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ESFactory:
|
||||
@classmethod
|
||||
def connect(cls, cluster_name):
|
||||
def connect(cls, cluster_name) -> Elasticsearch:
|
||||
"""
|
||||
Returns the es client for the cluster.
|
||||
Connects to the cluster if did not connect previously
|
||||
@@ -59,18 +82,45 @@ class ESFactory:
|
||||
if not hosts:
|
||||
raise InvalidClusterConfiguration(cluster_name)
|
||||
|
||||
http_auth = cls.get_credentials(cluster_name)
|
||||
|
||||
args = cluster_config.get("args", {})
|
||||
_instances[cluster_name] = Elasticsearch(
|
||||
hosts=hosts, transport_class=Transport, **args
|
||||
hosts=hosts, http_auth=http_auth, **args
|
||||
)
|
||||
|
||||
return _instances[cluster_name]
|
||||
|
||||
@classmethod
|
||||
def get_credentials(cls, cluster_name: str, cluster_config: dict = None) -> Optional[Tuple[str, str]]:
|
||||
cluster_config = cluster_config or cls.get_cluster_config(cluster_name)
|
||||
if not cluster_config.get("secure", True):
|
||||
return None
|
||||
|
||||
elastic_user = OVERRIDE_USERNAME or config.get("secure.elastic.user", None)
|
||||
if not elastic_user:
|
||||
return None
|
||||
|
||||
elastic_password = OVERRIDE_PASSWORD or config.get(
|
||||
"secure.elastic.password", None
|
||||
)
|
||||
if not elastic_password:
|
||||
raise MissingPasswordForElasticUser(
|
||||
f"cluster={cluster_name}, username={elastic_user}"
|
||||
)
|
||||
|
||||
return elastic_user, elastic_password
|
||||
|
||||
@classmethod
|
||||
def get_all_cluster_names(cls):
|
||||
return list(config.get("hosts.elastic"))
|
||||
|
||||
@classmethod
|
||||
def get_override_host(cls, cluster_name: str) -> Tuple[str, str]:
|
||||
return OVERRIDE_HOST, OVERRIDE_PORT
|
||||
|
||||
@classmethod
|
||||
@lru_cache()
|
||||
def get_cluster_config(cls, cluster_name):
|
||||
"""
|
||||
Returns cluster config for the specified cluster path
|
||||
@@ -84,14 +134,16 @@ class ESFactory:
|
||||
raise MissingClusterConfiguration(cluster_name)
|
||||
|
||||
def set_host_prop(key, value):
|
||||
for host in cluster_config.get("hosts", []):
|
||||
host[key] = value
|
||||
for entry in cluster_config.get("hosts", []):
|
||||
entry[key] = value
|
||||
|
||||
if OVERRIDE_HOST:
|
||||
set_host_prop("host", OVERRIDE_HOST)
|
||||
host, port = cls.get_override_host(cluster_name)
|
||||
|
||||
if OVERRIDE_PORT:
|
||||
set_host_prop("port", OVERRIDE_PORT)
|
||||
if host:
|
||||
set_host_prop("host", host)
|
||||
|
||||
if port:
|
||||
set_host_prop("port", port)
|
||||
|
||||
return cluster_config
|
||||
|
||||
@@ -120,7 +172,9 @@ class ESFactory:
|
||||
@classmethod
|
||||
def get_es_timestamp_str(cls):
|
||||
now = datetime.utcnow()
|
||||
return now.strftime("%Y-%m-%dT%H:%M:%S") + ".%03d" % (now.microsecond / 1000) + "Z"
|
||||
return (
|
||||
now.strftime("%Y-%m-%dT%H:%M:%S") + ".%03d" % (now.microsecond / 1000) + "Z"
|
||||
)
|
||||
|
||||
|
||||
es_factory = ESFactory()
|
||||
|
||||
611
apiserver/jobs/async_urls_delete.py
Normal file
611
apiserver/jobs/async_urls_delete.py
Normal file
@@ -0,0 +1,611 @@
|
||||
import os
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from argparse import ArgumentParser
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from typing import Sequence, Optional, Tuple, Mapping, TypeVar, Hashable, Generic
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
import requests
|
||||
from azure.storage.blob import ContainerClient, PartialBatchErrorException
|
||||
from boltons.iterutils import bucketize, chunked_iter
|
||||
from furl import furl
|
||||
from google.cloud import storage as google_storage
|
||||
from mongoengine import Q
|
||||
from mypy_boto3_s3.service_resource import Bucket as AWSBucket
|
||||
|
||||
from apiserver.bll.storage import StorageBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database import db
|
||||
from apiserver.database.model.url_to_delete import UrlToDelete, StorageType, DeletionStatus
|
||||
|
||||
log = config.logger(f"JOB-{Path(__file__).name}")
|
||||
conf = config.get("services.async_urls_delete")
|
||||
max_retries = conf.get("max_retries", 3)
|
||||
retry_timeout = timedelta(seconds=conf.get("retry_timeout_sec", 60))
|
||||
storage_bll = StorageBLL()
|
||||
|
||||
|
||||
def mark_retry_failed(ids: Sequence[str], reason: str):
|
||||
UrlToDelete.objects(id__in=ids).update(
|
||||
last_failure_time=datetime.utcnow(),
|
||||
last_failure_reason=reason,
|
||||
inc__retry_count=1,
|
||||
)
|
||||
UrlToDelete.objects(id__in=ids, retry_count__gte=max_retries).update(
|
||||
status=DeletionStatus.failed
|
||||
)
|
||||
|
||||
|
||||
def mark_failed(query: Q, reason: str):
|
||||
UrlToDelete.objects(query).update(
|
||||
status=DeletionStatus.failed,
|
||||
last_failure_time=datetime.utcnow(),
|
||||
last_failure_reason=reason,
|
||||
)
|
||||
|
||||
|
||||
def scheme_prefix(scheme: str) -> str:
|
||||
return str(furl(scheme=scheme, netloc=""))
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Hashable)
|
||||
|
||||
|
||||
class Storage(Generic[T], metaclass=ABCMeta):
|
||||
class Client(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def chunk_size(self) -> int:
|
||||
pass
|
||||
|
||||
def get_path(self, url: UrlToDelete) -> str:
|
||||
pass
|
||||
|
||||
def delete_many(
|
||||
self, paths: Sequence[str]
|
||||
) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
pass
|
||||
|
||||
def group_urls(
|
||||
self, urls: Sequence[UrlToDelete]
|
||||
) -> Mapping[T, Sequence[UrlToDelete]]:
|
||||
pass
|
||||
|
||||
def get_client(self, base: T, urls: Sequence[UrlToDelete]) -> Client:
|
||||
pass
|
||||
|
||||
|
||||
def delete_urls(urls_query: Q, storage: Storage):
|
||||
to_delete = list(UrlToDelete.objects(urls_query).order_by("url").limit(10000))
|
||||
if not to_delete:
|
||||
return
|
||||
|
||||
grouped_urls = storage.group_urls(to_delete)
|
||||
for base, urls in grouped_urls.items():
|
||||
if not base:
|
||||
msg = f"Invalid {storage.name} url or missing {storage.name} configuration for account"
|
||||
mark_failed(
|
||||
Q(id__in=[url.id for url in urls]), msg,
|
||||
)
|
||||
log.warning(
|
||||
f"Failed to delete {len(urls)} files from {storage.name} due to: {msg}"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
client = storage.get_client(base, urls)
|
||||
except Exception as ex:
|
||||
failed = [url.id for url in urls]
|
||||
mark_retry_failed(failed, reason=str(ex))
|
||||
log.warning(
|
||||
f"Failed to delete {len(failed)} files from {storage.name} due to: {str(ex)}"
|
||||
)
|
||||
continue
|
||||
|
||||
for chunk in chunked_iter(urls, client.chunk_size):
|
||||
paths = []
|
||||
path_to_id_mapping = defaultdict(list)
|
||||
ids_to_delete = set()
|
||||
for url in chunk:
|
||||
try:
|
||||
path = client.get_path(url)
|
||||
except Exception as ex:
|
||||
err = str(ex)
|
||||
mark_failed(Q(id=url.id), err)
|
||||
log.warning(f"Error getting path for {url.url}: {err}")
|
||||
continue
|
||||
|
||||
paths.append(path)
|
||||
path_to_id_mapping[path].append(url.id)
|
||||
ids_to_delete.add(url.id)
|
||||
|
||||
if not paths:
|
||||
continue
|
||||
|
||||
try:
|
||||
deleted_paths, errors = client.delete_many(paths)
|
||||
except Exception as ex:
|
||||
mark_retry_failed([url.id for url in urls], str(ex))
|
||||
log.warning(
|
||||
f"Error deleting {len(paths)} files from {storage.name}: {str(ex)}"
|
||||
)
|
||||
continue
|
||||
|
||||
failed_ids = set()
|
||||
for reason, err_paths in errors.items():
|
||||
error_ids = set(
|
||||
chain.from_iterable(
|
||||
path_to_id_mapping.get(p, []) for p in err_paths
|
||||
)
|
||||
)
|
||||
mark_retry_failed(list(error_ids), reason)
|
||||
log.warning(
|
||||
f"Failed to delete {len(error_ids)} files from {storage.name} storage due to: {reason}"
|
||||
)
|
||||
failed_ids.update(error_ids)
|
||||
|
||||
deleted_ids = set(
|
||||
chain.from_iterable(
|
||||
path_to_id_mapping.get(p, []) for p in deleted_paths
|
||||
)
|
||||
)
|
||||
if deleted_ids:
|
||||
UrlToDelete.objects(id__in=list(deleted_ids)).delete()
|
||||
log.info(
|
||||
f"{len(deleted_ids)} files deleted from {storage.name} storage"
|
||||
)
|
||||
|
||||
missing_ids = ids_to_delete - deleted_ids - failed_ids
|
||||
if missing_ids:
|
||||
mark_retry_failed(list(missing_ids), "Not succeeded")
|
||||
|
||||
|
||||
class FileserverStorage(Storage):
|
||||
class Client(Storage.Client):
|
||||
timeout = conf.get("fileserver.timeout_sec", 300)
|
||||
|
||||
def __init__(self, session: requests.Session, host: str):
|
||||
self.session = session
|
||||
self.delete_url = furl(host).add(path="delete_many").url
|
||||
|
||||
@property
|
||||
def chunk_size(self) -> int:
|
||||
return 10000
|
||||
|
||||
def get_path(self, url: UrlToDelete) -> str:
|
||||
path = url.url.strip("/")
|
||||
if not path:
|
||||
raise ValueError("Empty path")
|
||||
|
||||
return path
|
||||
|
||||
def delete_many(
|
||||
self, paths: Sequence[str]
|
||||
) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
|
||||
res = self.session.post(
|
||||
url=self.delete_url, json={"files": list(paths)}, timeout=self.timeout
|
||||
)
|
||||
res.raise_for_status()
|
||||
res_data = res.json()
|
||||
return list(res_data.get("deleted", {})), res_data.get("errors", {})
|
||||
|
||||
def __init__(self, company: str, fileserver_host: str = None):
|
||||
fileserver_host = fileserver_host or config.get("hosts.fileserver", None)
|
||||
self.host = fileserver_host.rstrip("/")
|
||||
if not self.host:
|
||||
log.warning(f"Fileserver host not configured")
|
||||
|
||||
def _parse_url_prefix(prefix) -> Tuple[str, str]:
|
||||
url = furl(prefix)
|
||||
host = f"{url.scheme}://{url.netloc}" if url.scheme else None
|
||||
return host, str(url.path).rstrip("/")
|
||||
|
||||
url_prefixes = [
|
||||
_parse_url_prefix(p) for p in conf.get("fileserver.url_prefixes", [])
|
||||
]
|
||||
if not any(self.host == host for host, _ in url_prefixes):
|
||||
url_prefixes.append((self.host, ""))
|
||||
self.url_prefixes = url_prefixes
|
||||
|
||||
self.company = company
|
||||
|
||||
# @classmethod
|
||||
# def validate_fileserver_access(cls, fileserver_host: str):
|
||||
# res = requests.get(
|
||||
# url=fileserver_host
|
||||
# )
|
||||
# res.raise_for_status()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "Fileserver"
|
||||
|
||||
def _resolve_base_url(self, url: UrlToDelete) -> Optional[str]:
|
||||
"""
|
||||
For the url return the base_url containing schema, optional host and bucket name
|
||||
"""
|
||||
if not url.url:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed = furl(url.url)
|
||||
url_host = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme else None
|
||||
url_path = str(parsed.path)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
for host, path_prefix in self.url_prefixes:
|
||||
if host and url_host != host:
|
||||
continue
|
||||
if path_prefix and not url_path.startswith(path_prefix + "/"):
|
||||
continue
|
||||
url.url = url_path[len(path_prefix or "") :]
|
||||
return self.host
|
||||
|
||||
def group_urls(
|
||||
self, urls: Sequence[UrlToDelete]
|
||||
) -> Mapping[str, Sequence[UrlToDelete]]:
|
||||
return bucketize(urls, key=self._resolve_base_url)
|
||||
|
||||
def get_client(self, base: str, urls: Sequence[UrlToDelete]) -> Client:
|
||||
host = base
|
||||
session = requests.session()
|
||||
res = session.get(url=host, timeout=self.Client.timeout)
|
||||
res.raise_for_status()
|
||||
|
||||
return self.Client(session, host)
|
||||
|
||||
|
||||
class AzureStorage(Storage):
|
||||
class Client(Storage.Client):
|
||||
def __init__(self, container: ContainerClient):
|
||||
self.container = container
|
||||
|
||||
@property
|
||||
def chunk_size(self) -> int:
|
||||
return 256
|
||||
|
||||
def get_path(self, url: UrlToDelete) -> str:
|
||||
parsed = furl(url.url)
|
||||
if (
|
||||
not parsed.path
|
||||
or not parsed.path.segments
|
||||
or len(parsed.path.segments) <= 1
|
||||
):
|
||||
raise ValueError("No path found following container name")
|
||||
|
||||
return os.path.join(*parsed.path.segments[1:])
|
||||
|
||||
@staticmethod
|
||||
def _path_from_request_url(request_url: str) -> str:
|
||||
try:
|
||||
return furl(request_url).path.segments[-1]
|
||||
except Exception:
|
||||
return request_url
|
||||
|
||||
def delete_many(
|
||||
self, paths: Sequence[str]
|
||||
) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
|
||||
try:
|
||||
res = self.container.delete_blobs(*paths)
|
||||
except PartialBatchErrorException as pex:
|
||||
deleted = []
|
||||
errors = defaultdict(list)
|
||||
for part in pex.parts:
|
||||
if 300 >= part.status_code >= 200:
|
||||
deleted.append(self._path_from_request_url(part.request.url))
|
||||
else:
|
||||
errors[part.reason].append(
|
||||
self._path_from_request_url(part.request.url)
|
||||
)
|
||||
return deleted, errors
|
||||
|
||||
return [self._path_from_request_url(part.request.url) for part in res], {}
|
||||
|
||||
def __init__(self, company: str):
|
||||
self.configs = storage_bll.get_azure_settings_for_company(company)
|
||||
self.scheme = "azure"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "Azure"
|
||||
|
||||
def _resolve_base_url(self, url: UrlToDelete) -> Optional[Tuple]:
|
||||
"""
|
||||
For the url return the base_url containing schema, optional host and bucket name
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url.url)
|
||||
if parsed.scheme != self.scheme:
|
||||
return None
|
||||
|
||||
azure_conf = self.configs.get_config_by_uri(url.url)
|
||||
if azure_conf is None:
|
||||
return None
|
||||
|
||||
account_url = parsed.netloc
|
||||
return account_url, azure_conf.container_name
|
||||
except Exception as ex:
|
||||
log.warning(f"Error resolving base url for {url.url}: " + str(ex))
|
||||
return None
|
||||
|
||||
def group_urls(
|
||||
self, urls: Sequence[UrlToDelete]
|
||||
) -> Mapping[Tuple, Sequence[UrlToDelete]]:
|
||||
return bucketize(urls, key=self._resolve_base_url)
|
||||
|
||||
def get_client(self, base: Tuple, urls: Sequence[UrlToDelete]) -> Client:
|
||||
account_url, container_name = base
|
||||
sample_url = urls[0].url
|
||||
cfg = self.configs.get_config_by_uri(sample_url)
|
||||
if not cfg or not cfg.account_name or not cfg.account_key:
|
||||
raise ValueError(
|
||||
f"Missing account name or key for Azure Blob Storage "
|
||||
f"account: {account_url}, container: {container_name}"
|
||||
)
|
||||
|
||||
return self.Client(
|
||||
ContainerClient(
|
||||
account_url=account_url,
|
||||
container_name=cfg.container_name,
|
||||
credential={
|
||||
"account_name": cfg.account_name,
|
||||
"account_key": cfg.account_key,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AWSStorage(Storage):
|
||||
class Client(Storage.Client):
|
||||
def __init__(self, base_url: str, container: AWSBucket):
|
||||
self.container = container
|
||||
self.base_url = base_url
|
||||
|
||||
@property
|
||||
def chunk_size(self) -> int:
|
||||
return 1000
|
||||
|
||||
def get_path(self, url: UrlToDelete) -> str:
|
||||
""" Normalize remote path. Remove any prefix that is already handled by the container """
|
||||
path = url.url
|
||||
if path.startswith(self.base_url):
|
||||
path = path[len(self.base_url) :]
|
||||
path = path.lstrip("/")
|
||||
return path
|
||||
|
||||
@staticmethod
|
||||
def _path_from_request_url(request_url: str) -> str:
|
||||
try:
|
||||
return furl(request_url).path.segments[-1]
|
||||
except Exception:
|
||||
return request_url
|
||||
|
||||
def delete_many(
|
||||
self, paths: Sequence[str]
|
||||
) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
|
||||
res = self.container.delete_objects(
|
||||
Delete={"Objects": [{"Key": p} for p in paths]}
|
||||
)
|
||||
errors = defaultdict(list)
|
||||
for err in res.get("Errors", []):
|
||||
msg = err.get("Message", "")
|
||||
errors[msg].append(err.get("Key"))
|
||||
|
||||
return [d.get("Key") for d in res.get("Deleted", [])], errors
|
||||
|
||||
def __init__(self, company: str):
|
||||
self.configs = storage_bll.get_aws_settings_for_company(company)
|
||||
self.scheme = "s3"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "AWS"
|
||||
|
||||
def _resolve_base_url(self, url: UrlToDelete) -> Optional[str]:
|
||||
"""
|
||||
For the url return the base_url containing schema, optional host and bucket name
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url.url)
|
||||
if parsed.scheme != self.scheme:
|
||||
return None
|
||||
|
||||
s3_conf = self.configs.get_config_by_uri(url.url)
|
||||
if s3_conf is None:
|
||||
return None
|
||||
|
||||
s3_bucket = s3_conf.bucket
|
||||
if not s3_bucket:
|
||||
parts = Path(parsed.path.strip("/")).parts
|
||||
if parts:
|
||||
s3_bucket = parts[0]
|
||||
return "/".join(filter(None, ("s3:/", s3_conf.host, s3_bucket)))
|
||||
except Exception as ex:
|
||||
log.warning(f"Error resolving base url for {url.url}: " + str(ex))
|
||||
return None
|
||||
|
||||
def group_urls(
|
||||
self, urls: Sequence[UrlToDelete]
|
||||
) -> Mapping[str, Sequence[UrlToDelete]]:
|
||||
return bucketize(urls, key=self._resolve_base_url)
|
||||
|
||||
def get_client(self, base: str, urls: Sequence[UrlToDelete]) -> Client:
|
||||
sample_url = urls[0].url
|
||||
cfg = self.configs.get_config_by_uri(sample_url)
|
||||
boto_kwargs = {
|
||||
"endpoint_url": (("https://" if cfg.secure else "http://") + cfg.host)
|
||||
if cfg.host
|
||||
else None,
|
||||
"use_ssl": cfg.secure,
|
||||
"verify": cfg.verify,
|
||||
}
|
||||
name = base[len(scheme_prefix(self.scheme)) :]
|
||||
bucket_name = name[len(cfg.host) + 1 :] if cfg.host else name
|
||||
if not cfg.use_credentials_chain:
|
||||
if not cfg.key or not cfg.secret:
|
||||
raise ValueError(
|
||||
f"Missing key or secret for AWS S3 host: {cfg.host}, bucket: {str(bucket_name)}"
|
||||
)
|
||||
|
||||
boto_kwargs["aws_access_key_id"] = cfg.key
|
||||
boto_kwargs["aws_secret_access_key"] = cfg.secret
|
||||
if cfg.token:
|
||||
boto_kwargs["aws_session_token"] = cfg.token
|
||||
|
||||
return self.Client(
|
||||
base, boto3.resource("s3", **boto_kwargs).Bucket(bucket_name)
|
||||
)
|
||||
|
||||
|
||||
class GoogleCloudStorage(Storage):
|
||||
class Client(Storage.Client):
|
||||
def __init__(self, base_url: str, container: google_storage.Bucket):
|
||||
self.container = container
|
||||
self.base_url = base_url
|
||||
|
||||
@property
|
||||
def chunk_size(self) -> int:
|
||||
return 100
|
||||
|
||||
def get_path(self, url: UrlToDelete) -> str:
|
||||
""" Normalize remote path. Remove any prefix that is already handled by the container """
|
||||
path = url.url
|
||||
if path.startswith(self.base_url):
|
||||
path = path[len(self.base_url) :]
|
||||
path = path.lstrip("/")
|
||||
return path
|
||||
|
||||
def delete_many(
|
||||
self, paths: Sequence[str]
|
||||
) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
|
||||
not_found = set()
|
||||
|
||||
def error_callback(blob: google_storage.Blob):
|
||||
not_found.add(blob.name)
|
||||
|
||||
self.container.delete_blobs(
|
||||
[self.container.blob(p) for p in paths], on_error=error_callback,
|
||||
)
|
||||
errors = {"Not found": list(not_found)} if not_found else {}
|
||||
return list(set(paths) - not_found), errors
|
||||
|
||||
def __init__(self, company: str):
|
||||
self.configs = storage_bll.get_gs_settings_for_company(company)
|
||||
self.scheme = "gs"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "Google Storage"
|
||||
|
||||
def _resolve_base_url(self, url: UrlToDelete) -> Optional[str]:
|
||||
"""
|
||||
For the url return the base_url containing schema, optional host and bucket name
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url.url)
|
||||
if parsed.scheme != self.scheme:
|
||||
return None
|
||||
|
||||
gs_conf = self.configs.get_config_by_uri(url.url)
|
||||
if gs_conf is None:
|
||||
return None
|
||||
|
||||
return str(furl(scheme=parsed.scheme, netloc=gs_conf.bucket))
|
||||
except Exception as ex:
|
||||
log.warning(f"Error resolving base url for {url.url}: " + str(ex))
|
||||
return None
|
||||
|
||||
def group_urls(
|
||||
self, urls: Sequence[UrlToDelete]
|
||||
) -> Mapping[str, Sequence[UrlToDelete]]:
|
||||
return bucketize(urls, key=self._resolve_base_url)
|
||||
|
||||
def get_client(self, base: str, urls: Sequence[UrlToDelete]) -> Client:
|
||||
sample_url = urls[0].url
|
||||
cfg = self.configs.get_config_by_uri(sample_url)
|
||||
if cfg.credentials_json:
|
||||
from google.oauth2 import service_account
|
||||
|
||||
credentials = service_account.Credentials.from_service_account_file(
|
||||
cfg.credentials_json
|
||||
)
|
||||
else:
|
||||
credentials = None
|
||||
|
||||
bucket_name = base[len(scheme_prefix(self.scheme)) :]
|
||||
return self.Client(
|
||||
base,
|
||||
google_storage.Client(project=cfg.project, credentials=credentials).bucket(
|
||||
bucket_name
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def run_delete_loop(fileserver_host: str):
|
||||
storage_helpers = {
|
||||
StorageType.fileserver: partial(
|
||||
FileserverStorage, fileserver_host=fileserver_host
|
||||
),
|
||||
StorageType.s3: AWSStorage,
|
||||
StorageType.azure: AzureStorage,
|
||||
StorageType.gs: GoogleCloudStorage,
|
||||
}
|
||||
while True:
|
||||
now = datetime.utcnow()
|
||||
urls_query = (
|
||||
Q(status__ne=DeletionStatus.failed)
|
||||
& Q(retry_count__lt=max_retries)
|
||||
& (
|
||||
Q(last_failure_time__exists=False)
|
||||
| Q(last_failure_time__lt=now - retry_timeout)
|
||||
)
|
||||
)
|
||||
|
||||
url_to_delete: UrlToDelete = UrlToDelete.objects(
|
||||
urls_query & Q(storage_type__in=list(storage_helpers))
|
||||
).order_by("retry_count").limit(1).first()
|
||||
if not url_to_delete:
|
||||
sleep(10)
|
||||
continue
|
||||
|
||||
company = url_to_delete.company
|
||||
user = url_to_delete.user
|
||||
storage_type = url_to_delete.storage_type
|
||||
log.info(
|
||||
f"Deleting {storage_type} objects for company: {company}, user: {user}"
|
||||
)
|
||||
company_storage_urls_query = urls_query & Q(
|
||||
company=company, storage_type=storage_type,
|
||||
)
|
||||
delete_urls(
|
||||
urls_query=company_storage_urls_query,
|
||||
storage=storage_helpers[storage_type](company=company),
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser(description=__doc__)
|
||||
|
||||
parser.add_argument(
|
||||
"--fileserver-host", "-fh", help="Fileserver host address", type=str,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
db.initialize()
|
||||
run_delete_loop(args.fileserver_host)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user