Compare commits

177 Commits

Author SHA1 Message Date
allegroai
69737308fe Version bump to v1.4.0 2022-04-18 16:38:22 +03:00
allegroai
a6dbea808a Add indices for task.last_update and task.status_changed 2022-04-18 16:37:22 +03:00
allegroai
5131b17901 Support not returning hidden sub-projects when include_stats is specified without search_hidden 2022-04-18 16:36:14 +03:00
allegroai
5f21c3a56d Add support for searching hidden projects and tasks 2022-04-18 16:34:18 +03:00
allegroai
2350ac64ed Fix internal error on count task events if there is no events index 2022-04-18 16:31:02 +03:00
allegroai
d146127c18 Add events.clear_scroll endpoint to clear event search scrolls 2022-04-18 16:29:57 +03:00
Mal Miller
abd65e103e Ensure agent-services waits for API server to be ready (#129) 2022-03-31 11:10:45 +03:00
pollfly
bf65ea7bd0 Resize admonitions (#126) 2022-03-27 15:04:43 +03:00
pollfly
73e278a8ed Add deprecation notes to legacy docs (#124) 2022-03-23 23:51:55 +02:00
Zied ANDOLSI
d92dfbbdb7 Allow ClearML to be served with a URL path prefix (#121)
* add server root url

* [Feature Request] Add proxy_pass for root url other than /

* [Feature Request] Add proxy_pass for root url other than /

* add support for web sub path

* add support for web sub path

* use default conf instead of created a custom one

* code reivew: move cp command in if block

* Add commented env var in the docker-compose file

Co-authored-by: Zied ANDOLSI <zandolsi@prophesee.ai>
2022-03-22 17:21:58 +02:00
Zied ANDOLSI
5c1e419eb5 Allow overriding clearml web git url on build (#122)
* add server root url

* [Feature Request] Add possibility to override clearml web git url

Co-authored-by: Zied ANDOLSI <zandolsi@prophesee.ai>
2022-03-17 14:35:50 +02:00
allegroai
124684f53f Version bump to v1.3.0 2022-03-15 16:34:35 +02:00
allegroai
455b5d6758 Fix pre-populate to convert model metadata from the old format 2022-03-15 16:30:14 +02:00
allegroai
c04e2e498b Support credentials label and last_used_from fields 2022-03-15 16:29:37 +02:00
allegroai
da8a45072f Add pipelines support 2022-03-15 16:28:59 +02:00
allegroai
e1992e2054 Fix queue metrics calculation 2022-03-15 16:28:49 +02:00
allegroai
c17cedd93a Support disabling response compression in fileserver 2022-03-15 16:27:31 +02:00
allegroai
b6ad8f8790 Add support for worker auto-unregister (instead of raising an error) 2022-03-15 16:25:14 +02:00
allegroai
5acc7eebc3 Set API version to 2.17 2022-03-15 16:22:51 +02:00
allegroai
941927dfcd Return fixed fileserver header 2022-03-15 16:21:52 +02:00
allegroai
02933a9c93 Support disabling response compression
Return fixed server header
2022-03-15 16:21:14 +02:00
allegroai
e537651f29 Better support for assets upload/download 2022-03-15 16:19:52 +02:00
allegroai
af09fba755 Add metadata dict support for models, queues
Add more info for projects
2022-03-15 16:18:57 +02:00
Reuben Morais
04ea9018a3 Add missing g++ dep to server build (#111) 2022-02-21 22:14:22 +02:00
allegroai
ff7e1be24f Updated docker-compose files for v1.2.0 2022-02-14 15:27:23 +02:00
allegroai
fc4fd9e61c Version bump to v1.2.0 2022-02-14 15:26:27 +02:00
allegroai
8908c7dcf9 Update driver requirements
Refactor ES initialization
2022-02-13 20:27:12 +02:00
allegroai
b9996e2c1a Protect against multiple connects to the update server from different processes
Code cleanup
2022-02-13 20:12:12 +02:00
allegroai
afdc56f37c Use task active duration for worker task running time 2022-02-13 20:01:47 +02:00
allegroai
a25cd5dae8 Fix version conflicts when deleting task events cause an error 2022-02-13 20:01:25 +02:00
allegroai
447adb9090 Add support for credentials label
Support no_scroll in events.get_task_plots
Support better project stats
Fix Redis required on mongodb initialization
Update tests
2022-02-13 19:59:58 +02:00
allegroai
92fd98d5ad Add support for lists and nested fields in URL args and form 2022-02-13 19:52:05 +02:00
allegroai
c4001b4037 Add Redis cluster support
Fix for lru_cache usage
2022-02-13 19:48:26 +02:00
allegroai
970a32287a Add Redis password support 2022-02-13 19:37:52 +02:00
allegroai
17cd48dada Add support for override cookie domains
Support for community invitation alarms
Remove duplicate property
Add query optimizations
2022-02-13 19:35:35 +02:00
allegroai
ea3b6e955f Optimize nested_get() 2022-02-13 19:32:22 +02:00
allegroai
843450bb9b Fix add_or_update_artifacts should always be allowed on in_progress tasks
Fix delete_artifacts should always be allowed on in_progress tasks
Fix query code
2022-02-13 19:31:54 +02:00
allegroai
e149af58b1 Support for additional mata data in api call response 2022-02-13 19:30:36 +02:00
allegroai
604a38035b Add organization.update_company_name
Fix unit-tests
2022-02-13 19:29:46 +02:00
allegroai
cae38a365b Fix base query building
Fix schema
Improve events.scalar_metrics_iter_raw implementation
2022-02-13 19:28:23 +02:00
allegroai
e334246b46 Add support for project stats with children flag 2022-02-13 19:26:47 +02:00
allegroai
36e013b40c Add support for events.scalar_metrics_iter_raw 2022-02-13 19:26:03 +02:00
allegroai
f20cd6536e Add scroll support to *.get_* 2022-02-13 19:23:29 +02:00
allegroai
446bd35006 Refactor debug images response, model ORM 2022-02-13 19:21:07 +02:00
allegroai
a377a7e315 Support status_message and status_reason in tasks.delete 2022-02-13 19:20:31 +02:00
allegroai
3d046ac282 Fix project should not be merged into itself 2022-02-13 19:18:08 +02:00
allegroai
a08fa9a0e1 Add missing API Errors 2022-02-13 19:16:58 +02:00
allegroai
5856ed2836 Update Model.last_update on changes to tags and system tags 2022-02-13 19:15:37 +02:00
allegroai
d295355d99 Better logger name if called from __init__.py 2022-02-13 19:15:10 +02:00
pollfly
77350f6119 Fix link (#104) 2022-01-27 12:15:55 +02:00
Niels ten Boom
bc2c2ebbfd Add connection string functionality for MongoDB access (#102) 2022-01-08 12:07:59 +02:00
allegroai
1502e02a1a Update ES version to 7.16.2 2021-12-22 13:53:34 +02:00
allegroai
d0e2313a24 Update README regarding CVE-2021-45046 2021-12-15 15:51:18 +02:00
allegroai
d8ba1a8ea7 Fix README 2021-12-14 15:52:53 +02:00
allegroai
ca7937fc4e Fix README 2021-12-14 15:50:30 +02:00
allegroai
df89bcceef Update README with a note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31 2021-12-14 15:48:54 +02:00
allegroai
cfccbe05c1 Add precautionary mitigation for Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31 2021-12-14 15:15:11 +02:00
Théo Mathieu
e352a6a1e7 Fix elasticsearch authentication when initializing (#98) 2021-12-05 09:55:06 +02:00
Théo Mathieu
8a3d992aaf Support MongoDB SRV endpoints (#96) 2021-12-02 10:07:33 +02:00
allegroai
c37f3d8d5b Fix set() not supported in ConfigTree()
Add user/pass config support
2021-11-15 18:33:49 +02:00
allegroai
a96870e092 Add admonition in case only username or password were provided 2021-11-15 15:19:07 +02:00
allegroai
6bf1032237 Rename back to docker-compose.yml 2021-11-15 15:13:09 +02:00
Weixiao Huang
3d816c747d Add ES http_auth credentials support (#93)
Also update ES and MongoDB versions and fix nginx configuration bug

Co-authored-by: huangweixiao <huangweixiao@megvii.com>
2021-11-15 15:01:27 +02:00
Jake Henning
3f2b96266b Merge pull request #91 from valeriano-manassero/fix-dockerfile-chmod
Fix chmod for file copy in Dockerfile
2021-10-19 11:35:00 +03:00
Valeriano Manassero
22b16d12eb fix chmod for file copy 2021-10-19 09:06:51 +02:00
allegroai
c55b6f30df Add Dockerfile 2021-10-18 16:52:17 +03:00
allegroai
b7045d3d28 Fix docker-compose escaping 2021-10-18 16:49:51 +03:00
Jake Henning
e31a404885 Remove README mentions of demo server (#90) 2021-10-10 16:32:51 +03:00
Revital
643588b71a edit README mention of demo server 2021-10-10 11:27:44 +03:00
Jake Henning
a64c4d264d Merge pull request #82 from IgorKasianenko/IgorKasianenko-patch-1
Fix typo TRAINS > CLEARML for env variables in README
2021-08-12 11:47:20 +03:00
Igor Kasianenko
567780e188 Fix typo TRAINS > CLEARML for env variables 2021-08-11 16:21:02 +03:00
allegroai
1bc8529d83 Version bump 2021-08-05 16:46:29 +03:00
allegroai
6b480d7e87 Fix file server GET response for gzipped data-files contains Content-Encoding: gz header, causing clients to automatically decompress the file 2021-08-05 16:46:25 +03:00
allegroai
083fd315e9 Fix server error when running with non-migrated v0.16 ElasticSearch data 2021-08-05 16:46:05 +03:00
Jake Henning
ef20e76174 Update README with artifact.io badge 2021-07-27 19:53:41 +03:00
Jake Henning
8c8910808e Merge pull request #80 from pollfly/master
Fix README links
2021-07-27 12:58:45 +03:00
Revital
f6ad379310 link to clear.ml docs in readme, add image 2021-07-27 12:54:41 +03:00
allegroai
c5d6ce3e65 Version bump 2021-07-25 14:40:57 +03:00
allegroai
694dbc31c4 Fix incorrect ES query (merge issue) 2021-07-25 14:40:49 +03:00
allegroai
6488dc54e6 Better handling of stack trace report on 500 error 2021-07-25 14:39:59 +03:00
allegroai
158da9b480 Allow setting status_message in tasks.update
Optimizations and refactoring
2021-07-25 14:35:36 +03:00
allegroai
ec2e071ab7 Fix mongoengine cannot handle field name with leading or trailing "_" when used in fields query within get_all endpoints 2021-07-25 14:34:04 +03:00
allegroai
465e270342 Fix queued task is not dequeued on tasks.stop 2021-07-25 14:32:09 +03:00
allegroai
6705aff56f Allow requesting plots and iter_histograms for all variants 2021-07-25 14:30:38 +03:00
allegroai
9069cfe1da Support querying task events per specific metrics and variants 2021-07-25 14:29:41 +03:00
allegroai
677bb3ba6d Add force parameter to tasks.enqueue 2021-07-25 14:27:46 +03:00
allegroai
cb253cff9e Don't use special characters in secrets 2021-07-25 14:26:49 +03:00
allegroai
39ceb5ac5c Fix pre-populate logic to avoid overriding existing users 2021-07-25 14:26:31 +03:00
allegroai
d4edeaaf1b Add projects.validate_delete 2021-07-25 14:17:29 +03:00
allegroai
56aea1ffb8 Fix filtering on hyperparams (https://github.com/allegroai/clearml/issues/385, https://clearml.slack.com/archives/CTK20V944/p1626600582284700) 2021-07-25 13:55:09 +03:00
allegroai
09ab2af34c Version bump 2021-05-27 17:13:19 +03:00
allegroai
8bb26a6b0b Fix fileserver depends on deprecated flask._compat.fspath and safe_join 2021-05-27 17:13:02 +03:00
allegroai
3f2304549d Move new migrations to 1_0_2 2021-05-27 16:56:47 +03:00
allegroai
ad72a435f1 Clean Task runtime on reset 2021-05-27 16:56:03 +03:00
allegroai
f34332344e Fix Task container raises validation error on null values 2021-05-27 16:55:32 +03:00
allegroai
d324b57dd7 Fix bad error message format 2021-05-27 16:55:00 +03:00
allegroai
2216bfe875 Version bump 2021-05-11 16:12:48 +03:00
allegroai
9beefa7473 Add missing login.logout endpoint 2021-05-11 16:12:27 +03:00
allegroai
8ebc334889 Fix broken config dir backwards compatibility (/opt/trains/config should still be supported) 2021-05-11 16:12:13 +03:00
allegroai
e662c850af Update config file in docs 2021-05-04 11:07:38 +03:00
allegroai
1e5163e530 Upgrade jinja2 version due to CVE-2020-28493 2021-05-03 23:23:06 +03:00
allegroai
1567774765 Version bump 2021-05-03 18:20:32 +03:00
allegroai
babfcbb707 Update migration script 2021-05-03 18:15:43 +03:00
allegroai
027edd86bb Fix actual file path reported in error/success message 2021-05-03 18:14:56 +03:00
allegroai
cc83aadae6 Fix file delete (bad merge) 2021-05-03 18:14:30 +03:00
allegroai
8c18660a82 Fix inconsistency in accessing files between download and delete 2021-05-03 18:14:08 +03:00
allegroai
4fe61ee25c Fix running migration scripts calling other files 2021-05-03 18:13:49 +03:00
allegroai
e18b21639c Fix regex query for fields containing "_" 2021-05-03 18:13:00 +03:00
allegroai
1cef03b8c2 Add check_contents flag for projects.get_all_ex 2021-05-03 18:12:44 +03:00
allegroai
d60d6dfe99 Move to clearml in docker-compose files 2021-05-03 18:12:21 +03:00
allegroai
27d086bca2 Fix schema for Task.runtime
Add infrastructure for API calls limits handling
2021-05-03 18:11:46 +03:00
allegroai
add3f011a0 Add runtime to tasks.edit 2021-05-03 18:10:48 +03:00
allegroai
ee90b0b024 Remove "Auto-generated while cloning" project description 2021-05-03 18:10:32 +03:00
allegroai
9bf107866f Fix crash in models publish_many without model task 2021-05-03 18:10:09 +03:00
allegroai
4d2f282950 Add Model.last_update to schema 2021-05-03 18:09:54 +03:00
allegroai
b55fad1b59 Remove "Auto-generated during move" project description 2021-05-03 18:09:31 +03:00
allegroai
ba77ff11e9 Fix missing custom metric values turn up first in sorting 2021-05-03 18:08:39 +03:00
allegroai
b67aa05d6f Return results per task iterations in debug images request 2021-05-03 18:08:14 +03:00
allegroai
6b0c45a861 Fix batch operations results 2021-05-03 18:07:37 +03:00
allegroai
dc9623e964 Fix docker_cmd projection in backwards compatibility
Fix support to clear input/output models and docker_cmd in backwards compatibility mode
Fix schema
2021-05-03 18:06:39 +03:00
allegroai
3d73d60826 Better handling of invalid iterations on add_batch 2021-05-03 18:05:24 +03:00
allegroai
9f0c9c3690 Fix open ranges 2021-05-03 18:05:03 +03:00
allegroai
1a3d3494ce Fix numeric locale 2021-05-03 18:04:45 +03:00
allegroai
b99f620073 Added unarchive APIs 2021-05-03 18:04:17 +03:00
allegroai
e2f265b4bc Unify batch operations 2021-05-03 18:03:54 +03:00
allegroai
251ee57ffd Fix rapidjson dumps does not support ensure_ascii, only Encoder initialization does
Add task enqueue status
2021-05-03 18:03:17 +03:00
allegroai
7e03104f1c Add Model last_update field 2021-05-03 18:02:25 +03:00
allegroai
f1a258208e Disable backwards compatibility for 2.13 clients 2021-05-03 18:01:59 +03:00
allegroai
66cc49313b Fix schema 2021-05-03 18:01:29 +03:00
allegroai
9ae2943f7d Fix crash in tasks.reset 2021-05-03 17:59:44 +03:00
allegroai
54326f707b Add JSON flags support to APICall 2021-05-03 17:58:57 +03:00
allegroai
3a3b57c15f Support mongodb authentication 2021-05-03 17:57:53 +03:00
allegroai
8ea8ad34e6 Remove collecting task output models from Models collection during migration 2021-05-03 17:57:27 +03:00
allegroai
179661a0d4 Rename default input and output models
Better handling of backwards compatibility in task models
Code cleanup
2021-05-03 17:56:50 +03:00
allegroai
3d22ca1888 Escape task.container and task.execution.model_labels fields in DB 2021-05-03 17:56:17 +03:00
allegroai
fdf6798d0c Don't unset Task's execution.queue on dequeue 2021-05-03 17:54:16 +03:00
allegroai
9d9a44b927 Add skip_empty parameter in get_configuration_names 2021-05-03 17:53:56 +03:00
allegroai
dad935e81d Remove webserver project 2021-05-03 17:53:24 +03:00
allegroai
a75534ec34 Add batch operations support 2021-05-03 17:52:54 +03:00
allegroai
eab33de97e Add bcrypt support to fixed user password 2021-05-03 17:52:25 +03:00
allegroai
29de110abb Add support for queue and model metadata 2021-05-03 17:50:25 +03:00
allegroai
2e7f418ee2 Fix Task.container backwards-compatibility
Fix sub-projects
2021-05-03 17:49:48 +03:00
allegroai
dadb996d22 Refactor es_factory to better support override host/port 2021-05-03 17:48:41 +03:00
allegroai
174f692edf Code cleanup 2021-05-03 17:48:24 +03:00
allegroai
f4d5168a20 Add Task.container support 2021-05-03 17:48:01 +03:00
allegroai
5a438e8435 Fix projects.move 2021-05-03 17:47:11 +03:00
allegroai
ce4814dc47 Add field override support in config (using "-" prefix) 2021-05-03 17:46:36 +03:00
allegroai
ef42d0265d Add multi-models support 2021-05-03 17:46:00 +03:00
allegroai
3c5195028e More sub-projects support and fixes 2021-05-03 17:44:54 +03:00
allegroai
0d5174c453 Support iterating over all task metrics in task debug images 2021-05-03 17:43:02 +03:00
allegroai
c034c1a986 Add sub-projects support 2021-05-03 17:42:10 +03:00
allegroai
1b49da8748 Revoke tests account in fixed mode, cleanup 2021-05-03 17:40:41 +03:00
allegroai
26bda01a28 Add missing errors 2021-05-03 17:39:49 +03:00
allegroai
f5008d80ad Optimize and improve tasks/models/projects.delete 2021-05-03 17:39:13 +03:00
allegroai
8b464e7ae6 Return file urls for tasks.delete/reset and models.delete 2021-05-03 17:38:09 +03:00
allegroai
78e4a58c91 Fix API enum fields and add last_iteration to range queries 2021-05-03 17:37:49 +03:00
allegroai
7a4a5eb03e Fix dropping index by name during the migration fails if the index does not exist 2021-05-03 17:36:49 +03:00
allegroai
d029d56508 Support active users in projects 2021-05-03 17:36:04 +03:00
allegroai
6411954002 Improve visibility for distributed lock hanging 2021-05-03 17:35:17 +03:00
allegroai
7f4ad0d1ca Support projects.get_hyperparam_values 2021-05-03 17:34:40 +03:00
allegroai
4cd4b2914d Add range queries
Switch from sematic_version to packaging.version in db migrations
2021-05-03 17:33:47 +03:00
allegroai
1d55710a0b Update max API version 2021-05-03 17:33:12 +03:00
allegroai
8f646043bb Allow enqueueing stopped tasks
More clearml stuff
2021-05-03 17:31:02 +03:00
allegroai
4b11a6efcd Move apiserver to clearml 2021-05-03 17:26:44 +03:00
allegroai
cb3a7c90a8 Move fileserver to clearml 2021-05-03 17:00:38 +03:00
allegroai
074842a122 Improve fileserver delete code 2021-05-03 16:58:11 +03:00
allegroai
749ff4a44f Fix Tasks.reset does not mark children's parent as deleted 2021-05-03 16:57:06 +03:00
allegroai
7d6918ecb0 Fix large plots comparison 2021-05-03 16:55:59 +03:00
allegroai
47184c2833 Fix querying by task parent 2021-05-03 16:55:03 +03:00
allegroai
6434f1028e Update docker-compose files 2021-01-14 12:37:25 +02:00
allegroai
daade08940 Update docker-compose-win10.yml
Remove deprecated docker-compose-unified.yml
2021-01-07 00:21:24 +02:00
Allegro AI
a1d289822f Update docker-compose-unified.yml
Reduce ES watermark
2021-01-06 17:46:09 +02:00
Allegro AI
1ce34f2c74 Update docker-compose-win10.yml
Reduce ES watermark
2021-01-06 17:45:27 +02:00
Allegro AI
c2dc73a71f Update docker-compose.yml
Reduce ES watermark
2021-01-06 17:44:45 +02:00
allegroai
07bb3b5df8 Update README 2021-01-06 00:32:52 +02:00
allegroai
067ef82576 Update README 2021-01-05 22:56:43 +02:00
allegroai
59fc98e0c4 Upgrade Jinja2 version (vulnerability found in older versions) 2021-01-05 20:18:09 +02:00
174 changed files with 12340 additions and 5039 deletions

1
.gitignore vendored
View File

@@ -12,7 +12,6 @@ test-reports
.pytest_cache
venv
*.noseids
build
*.egg-info
.cache
.mypy_cache

View File

@@ -8,28 +8,43 @@
[![GitHub license](https://img.shields.io/badge/license-SSPL-green.svg)](https://img.shields.io/badge/license-SSPL-green.svg)
[![Python versions](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
[![GitHub version](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/allegroai)](https://artifacthub.io/packages/search?repo=allegroai)
</div>
---
<div align="center">
**v0.16 Upgrade Notice**
**Note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31**
</div>
In v0.16, the Elasticsearch subsystem of ClearML Server has been upgraded from version 5.6 to version 7.6. This change necessitates the migration of the database contents to accommodate the change in index structure across the different versions.
According to [ElasticSearch's latest report](https://discuss.elastic.co/t/apache-log4j2-remote-code-execution-rce-vulnerability-cve-2021-44228-esa-2021-31/291476),
supported versions of Elasticsearch (6.8.9+, 7.8+) used with recent versions of the JDK (JDK9+) **are not susceptible to either remote code execution or information leakage**
due to Elasticsearchs usage of the Java Security Manager.
Follow [this procedure](https://allegro.ai/docs/deploying_trains/trains_server_es7_migration/) to migrate existing data.
**As the latest version of ClearML Server uses Elasticsearch 7.10+ with JDK15, it is not affected by these vulnerabilities.**
As a precaution, we've upgraded the ES version to 7.16.2 and added the mitigation recommended by ElasticSearch to our latest [docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
While previous Elasticsearch versions (5.6.11+, 6.4.0+ and 7.0.0+) used by older ClearML Server versions are only susceptible to the information leakage vulnerability
(which in any case **does not permit access to data within the Elasticsearch cluster**),
we still recommend upgrading to the latest version of ClearML Server. Alternatively, you can apply the mitigation as implemented in our latest
[docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
**Update 15 December**: A further vulnerability (CVE-2021-45046) was disclosed on December 14th.
ElasticSearch's guidance for Elasticsearch remains unchanged by this new vulnerability, thus **not affecting ClearML Server**.
**Update 22 December**: To keep with ElasticSearch's recommendations, we've upgraded the ES version to the newly released 7.16.2
---
### ClearML Server
## ClearML Server
#### *Formerly known as Trains Server*
The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/allegroai/clearml).
It allows multiple users to collaborate and manage their experiments.
By default, **ClearML** is set up to work with the **ClearML** demo server, which is open to anyone and resets periodically.
**ClearML** offers a [free hosted service](https://app.clear.ml/), which is maintained by **ClearML** and open to anyone.
In order to host your own server, you will need to launch the **ClearML Server** and point **ClearML** to it.
The **ClearML Server** contains the following components:
@@ -45,7 +60,7 @@ You can quickly [deploy](#launching-the-clearml-server) your **ClearML Server**
## System design
![Alt Text](https://allegro.ai/clearml/docs/_images/ClearML_Server_Diagram.png)
![Alt Text](docs/ClearML_Server_Diagram.png)
The **ClearML Server** has two supported configurations:
- Single IP (domain) with the following open ports
@@ -78,20 +93,19 @@ For example, to see if port `8080` is in use:
Launch The **ClearML Server** in any of the following formats:
- Pre-built [AWS EC2 AMI](https://allegro.ai/docs/deploying_trains/trains_server_aws_ec2_ami/)
- Pre-built [GCP Custom Image](https://allegro.ai/docs/deploying_trains/trains_server_gcp/)
- Pre-built [AWS EC2 AMI](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_aws_ec2_ami)
- Pre-built [GCP Custom Image](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_gcp)
- Pre-built Docker Image
- [Linux](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
- [macOS](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
- [Windows 10](https://allegro.ai/docs/deploying_trains/trains_server_win/)
- [Linux](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
- [macOS](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
- [Windows 10](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_win)
- Kubernetes
- [Kubernetes Helm](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes_helm/)
- Manual [Kubernetes installation](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes/)
- [Kubernetes Helm](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes_helm)
- Manual [Kubernetes installation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes)
## Connecting ClearML to your ClearML Server
By default, the **ClearML** client is set up to work with the [**ClearML** demo server](https://demoapp.demo.clear.ml/).
To have the **ClearML** client use your **ClearML Server** instead:
In order to set up the **ClearML** client to work with your **ClearML Server**:
- Run the `clearml-init` command for an interactive setup.
- Or manually edit `~/clearml.conf` file, making sure the server settings (`api_server`, `web_server`, `file_server`) are configured correctly, for example:
@@ -138,8 +152,8 @@ Do not enqueue training / inference tasks into the `services` queue, as it will
The **ClearML Server** provides a few additional useful features, which can be manually enabled:
* [Web login authentication](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#web-login-authentication)
* [Non-responsive experiments watchdog](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#task_watchdog)
* [Web login authentication](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#web-login-authentication)
* [Non-responsive experiments watchdog](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#non-responsive-task-watchdog)
## Restarting ClearML Server
@@ -189,14 +203,14 @@ To upgrade your existing **ClearML Server** deployment:
```
1. Configure the ClearML-Agent Services (not supported on Windows installation).
If `TRAINS_HOST_IP` is not provided, ClearML-Agent Services will use the external
public address of the **ClearML Server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
If `CLEARML_HOST_IP` is not provided, ClearML-Agent Services will use the external
public address of the **ClearML Server**. If `CLEARML_AGENT_GIT_USER` / `CLEARML_AGENT_GIT_PASS` are not provided,
the ClearML-Agent Services will not be able to access any private repositories for running service tasks.
```bash
export TRAINS_HOST_IP=server_host_ip_here
export TRAINS_AGENT_GIT_USER=git_username_here
export TRAINS_AGENT_GIT_PASS=git_password_here
export CLEARML_HOST_IP=server_host_ip_here
export CLEARML_AGENT_GIT_USER=git_username_here
export CLEARML_AGENT_GIT_PASS=git_password_here
```
1. Spin up the docker containers, it will automatically pull the latest **ClearML Server** build
@@ -205,15 +219,15 @@ To upgrade your existing **ClearML Server** deployment:
docker-compose -f docker-compose.yml up
```
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/clearml/docs/docs/faq/faq.html).**
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://clear.ml/docs/latest/docs/faq/).**
## Community & Support
If you have any questions, look to the ClearML [FAQ](https://allegro.ai/clearml/docs/docs/faq/faq.html), or
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/trains) with '**trains**' tag.
If you have any questions, look to the ClearML [FAQ](https://clear.ml/docs/latest/docs/faq), or
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/clearml) with '**clearml**' tag.
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/trains-server/issues).
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/clearml-server/issues).
Additionally, you can always find us at *clearml@allegro.ai*

View File

@@ -1,3 +1,8 @@
301 {
_: "moved_permanently"
1: ["not_supported", "this endpoint is no longer supported for the requested API version"]
}
400 {
_: "bad_request"
1: ["not_supported", "endpoint is not supported"]
@@ -62,6 +67,12 @@
402: ["project_has_tasks", "project has associated tasks"]
403: ["project_not_found", "project not found"]
405: ["project_has_models", "project has associated models"]
407: ["invalid_project_name", "invalid project name"]
408: ["cannot_update_project_location", "Cannot update project location. Use projects.move instead"]
409: ["project_path_exceeds_max", "Project path exceed the maximum allowed depth"]
410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"]
411: ["project_cannot_be_moved_under_itself", "Project can not be moved under itself in the projects hierarchy"]
412: ["project_cannot_be_merged_into_its_child", "Project can not be merged into its own child"]
# Queues
701: ["invalid_queue_id", "invalid queue id"]
@@ -108,6 +119,11 @@
21: ["no_write_permission", "forbidden (modification not allowed)"]
}
410: {
_: "gone"
1: ["not_supported", "thus endpoint is not supported any more"]
}
500 {
_: "server_error"
0: ["general_error", "general server error"]

View File

@@ -218,7 +218,7 @@ class ActualEnumField(fields.StringField):
)
def parse_value(self, value):
if value is None and not self.required:
if value is NotSet and not self.required:
return self.get_default_value()
try:
# noinspection PyArgumentList

View File

@@ -75,11 +75,17 @@ class CreateUserResponse(Base):
class Credentials(Base):
access_key = StringField(required=True)
secret_key = StringField(required=True)
label = StringField()
class CredentialsResponse(Credentials):
secret_key = StringField()
last_used = DateTimeField(default=None)
last_used_from = StringField()
class CreateCredentialsRequest(Base):
label = StringField()
class CreateCredentialsResponse(Base):

View File

@@ -0,0 +1,25 @@
from typing import Sequence
from jsonmodels.fields import StringField
from jsonmodels.models import Base
from jsonmodels.validators import Length
from apiserver.apimodels import ListField
from apiserver.apimodels.base import UpdateResponse
class BatchRequest(Base):
ids: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
class BatchResponse(Base):
succeeded: Sequence[dict] = ListField([dict])
failed: Sequence[dict] = ListField([dict])
class UpdateBatchItem(UpdateResponse):
id: str = StringField()
class UpdateBatchResponse(BatchResponse):
succeeded: Sequence[UpdateBatchItem] = ListField(UpdateBatchItem)

View File

@@ -2,7 +2,7 @@ from enum import auto
from typing import Sequence, Optional
from jsonmodels import validators
from jsonmodels.fields import StringField, BoolField
from jsonmodels.fields import StringField, BoolField, EmbeddedField
from jsonmodels.models import Base
from jsonmodels.validators import Length, Min, Max
@@ -14,12 +14,18 @@ from apiserver.utilities.stringenum import StringEnum
class HistogramRequestBase(Base):
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
class MetricVariants(Base):
metric: str = StringField(required=True)
variants: Sequence[str] = ListField(items_types=str)
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
task: str = StringField(required=True)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
@@ -38,7 +44,8 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
class TaskMetric(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
metric: str = StringField(default=None)
variants: Sequence[str] = ListField(items_types=str)
class DebugImagesRequest(Base):
@@ -59,8 +66,8 @@ class TaskMetricVariant(Base):
class GetDebugImageSampleRequest(TaskMetricVariant):
iteration: Optional[int] = IntField()
scroll_id: Optional[str] = StringField()
refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()
class NextDebugImageSampleRequest(Base):
@@ -74,14 +81,34 @@ class LogOrderEnum(StringEnum):
desc = auto()
class LogEventsRequest(Base):
class TaskEventsRequestBase(Base):
task: str = StringField(required=True)
batch_size: int = IntField(default=500)
class TaskEventsRequest(TaskEventsRequestBase):
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
scroll_id: str = StringField()
count_total: bool = BoolField(default=True)
class LogEventsRequest(TaskEventsRequestBase):
batch_size: int = IntField(default=5000)
navigate_earlier: bool = BoolField(default=True)
from_timestamp: Optional[int] = IntField()
order: Optional[str] = ActualEnumField(LogOrderEnum)
class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
batch_size: int = IntField()
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
count_total: bool = BoolField(default=False)
scroll_id: str = StringField()
class IterationEvents(Base):
iter: int = IntField()
events: Sequence[dict] = ListField(items_types=dict)
@@ -89,7 +116,6 @@ class IterationEvents(Base):
class MetricEvents(Base):
task: str = StringField()
metric: str = StringField()
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
@@ -103,3 +129,15 @@ class TaskMetricsRequest(Base):
items_types=str, validators=[Length(minimum_value=1)]
)
event_type: EventType = ActualEnumField(EventType, required=True)
class TaskPlotsRequest(Base):
task: str = StringField(required=True)
iters: int = IntField(default=1)
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class ClearScrollRequest(Base):
scroll_id: str = StringField()

View File

@@ -31,3 +31,4 @@ class GetSupportedModesResponse(Base):
server_errors = EmbeddedField(ServerErrors)
sso = DictField([str, type(None)])
sso_providers = ListField([dict])
authenticated = BoolField(default=False)

View File

@@ -0,0 +1,24 @@
from typing import Sequence
from jsonmodels import validators
from jsonmodels.fields import StringField, BoolField
from jsonmodels.models import Base
from apiserver.apimodels import ListField
class MetadataItem(Base):
key = StringField(required=True)
type = StringField(required=True)
value = StringField(required=True)
class DeleteMetadata(Base):
keys: Sequence[str] = ListField(str, validators=validators.Length(minimum_value=1))
class AddOrUpdateMetadata(Base):
metadata: Sequence[MetadataItem] = ListField(
[MetadataItem], validators=validators.Length(minimum_value=1)
)
replace_metadata = BoolField(default=False)

View File

@@ -3,7 +3,12 @@ from six import string_types
from apiserver.apimodels import ListField, DictField
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.tasks import PublishResponse as TaskPublishResponse
from apiserver.apimodels.batch import BatchRequest
from apiserver.apimodels.metadata import (
MetadataItem,
DeleteMetadata,
AddOrUpdateMetadata,
)
class GetFrameworksRequest(models.Base):
@@ -13,7 +18,7 @@ class GetFrameworksRequest(models.Base):
class CreateModelRequest(models.Base):
name = fields.StringField(required=True)
uri = fields.StringField(required=True)
labels = DictField(value_types=string_types+(int,))
labels = DictField(value_types=string_types + (int,))
tags = ListField(items_types=string_types)
system_tags = ListField(items_types=string_types)
comment = fields.StringField()
@@ -25,6 +30,7 @@ class CreateModelRequest(models.Base):
ready = fields.BoolField(default=True)
ui_cache = DictField()
task = fields.StringField()
metadata = DictField(value_types=[MetadataItem])
class CreateModelResponse(models.Base):
@@ -32,17 +38,40 @@ class CreateModelResponse(models.Base):
created = fields.BoolField(required=True)
class PublishModelRequest(models.Base):
class ModelRequest(models.Base):
model = fields.StringField(required=True)
class DeleteModelRequest(ModelRequest):
force = fields.BoolField(default=False)
class ModelsDeleteManyRequest(BatchRequest):
force = fields.BoolField(default=False)
class PublishModelRequest(ModelRequest):
force_publish_task = fields.BoolField(default=False)
publish_task = fields.BoolField(default=True)
class ModelTaskPublishResponse(models.Base):
id = fields.StringField(required=True)
data = fields.EmbeddedField(TaskPublishResponse)
data = fields.EmbeddedField(UpdateResponse)
class PublishModelResponse(UpdateResponse):
published_task = fields.EmbeddedField(ModelTaskPublishResponse)
updated = fields.IntField()
class ModelsPublishManyRequest(BatchRequest):
force_publish_task = fields.BoolField(default=False)
publish_task = fields.BoolField(default=True)
class DeleteMetadataRequest(DeleteMetadata):
model = fields.StringField(required=True)
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
model = fields.StringField(required=True)

View File

@@ -0,0 +1,19 @@
from jsonmodels import models, fields
from apiserver.apimodels import ListField
class Arg(models.Base):
name = fields.StringField(required=True)
value = fields.StringField(required=True)
class StartPipelineRequest(models.Base):
task = fields.StringField(required=True)
queue = fields.StringField(required=True)
args = ListField(Arg)
class StartPipelineResponse(models.Base):
pipeline = fields.StringField(required=True)
enqueued = fields.BoolField(required=True)

View File

@@ -1,15 +1,33 @@
from jsonmodels import models, fields
from apiserver.apimodels import ListField, ActualEnumField
from apiserver.apimodels import ListField, ActualEnumField, DictField
from apiserver.apimodels.organization import TagsRequest
from apiserver.database.model import EntityVisibility
class ProjectReq(models.Base):
class ProjectRequest(models.Base):
project = fields.StringField(required=True)
class MergeRequest(ProjectRequest):
destination_project = fields.StringField()
class MoveRequest(ProjectRequest):
new_location = fields.StringField()
class DeleteRequest(ProjectRequest):
force = fields.BoolField(default=False)
delete_contents = fields.BoolField(default=False)
class ProjectOrNoneRequest(models.Base):
project = fields.StringField()
include_subprojects = fields.BoolField(default=True)
class GetHyperParamReq(ProjectReq):
class GetParamsRequest(ProjectOrNoneRequest):
page = fields.IntField(default=0)
page_size = fields.IntField(default=500)
@@ -18,7 +36,33 @@ class ProjectTagsRequest(TagsRequest):
projects = ListField(str)
class ProjectTaskParentsRequest(ProjectReq):
projects = ListField(str)
class MultiProjectRequest(models.Base):
projects = fields.ListField(str)
include_subprojects = fields.BoolField(default=True)
class ProjectTaskParentsRequest(MultiProjectRequest):
tasks_state = ActualEnumField(EntityVisibility)
class ProjectHyperparamValuesRequest(MultiProjectRequest):
section = fields.StringField(required=True)
name = fields.StringField(required=True)
allow_public = fields.BoolField(default=True)
class ProjectModelMetadataValuesRequest(MultiProjectRequest):
key = fields.StringField(required=True)
allow_public = fields.BoolField(default=True)
class ProjectsGetRequest(models.Base):
include_stats = fields.BoolField(default=False)
include_stats_filter = DictField()
stats_with_children = fields.BoolField(default=True)
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
non_public = fields.BoolField(default=False)
active_users = fields.ListField(str)
check_own_contents = fields.BoolField(default=False)
shallow_search = fields.BoolField(default=False)
search_hidden = fields.BoolField(default=False)

View File

@@ -2,7 +2,12 @@ from jsonmodels import validators
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
from jsonmodels.models import Base
from apiserver.apimodels import ListField
from apiserver.apimodels import ListField, DictField
from apiserver.apimodels.metadata import (
MetadataItem,
DeleteMetadata,
AddOrUpdateMetadata,
)
class GetDefaultResp(Base):
@@ -14,12 +19,18 @@ class CreateRequest(Base):
name = StringField(required=True)
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = DictField(value_types=[MetadataItem])
class QueueRequest(Base):
queue = StringField(required=True)
class GetNextTaskRequest(QueueRequest):
queue = StringField(required=True)
get_task_info = BoolField(default=False)
class DeleteRequest(QueueRequest):
force = BoolField(default=False)
@@ -28,6 +39,7 @@ class UpdateRequest(QueueRequest):
name = StringField()
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = DictField(value_types=[MetadataItem])
class TaskRequest(QueueRequest):
@@ -58,3 +70,11 @@ class QueueMetrics(Base):
class GetMetricsResponse(Base):
queues = ListField(QueueMetrics)
class DeleteMetadataRequest(DeleteMetadata):
queue = StringField(required=True)
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
queue = StringField(required=True)

View File

@@ -1,16 +1,17 @@
from typing import Sequence
import six
from jsonmodels import models
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
from jsonmodels.validators import Enum, Length
from apiserver.apimodels import DictField, ListField
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.batch import BatchRequest, UpdateBatchItem, BatchResponse
from apiserver.database.model.task.task import (
TaskType,
ArtifactModes,
DEFAULT_ARTIFACT_MODE,
TaskModelTypes,
)
from apiserver.database.utils import get_options
@@ -43,26 +44,54 @@ class EnqueueResponse(UpdateResponse):
queued = IntField()
class EnqueueBatchItem(UpdateBatchItem):
queued: bool = BoolField()
class EnqueueManyResponse(BatchResponse):
succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem)
class DequeueResponse(UpdateResponse):
dequeued = IntField()
class DequeueBatchItem(UpdateBatchItem):
dequeued: bool = BoolField()
class DequeueManyResponse(BatchResponse):
succeeded: Sequence[DequeueBatchItem] = ListField(DequeueBatchItem)
class ResetResponse(UpdateResponse):
deleted_indices = ListField(items_types=six.string_types)
dequeued = DictField()
frames = DictField()
events = DictField()
model_deleted = IntField()
deleted_models = IntField()
urls = DictField()
class ResetBatchItem(UpdateBatchItem):
dequeued: bool = BoolField()
deleted_models = IntField()
urls = DictField()
class ResetManyResponse(BatchResponse):
succeeded: Sequence[ResetBatchItem] = ListField(ResetBatchItem)
class TaskRequest(models.Base):
task = StringField(required=True)
class UpdateRequest(TaskRequest):
class TaskUpdateRequest(TaskRequest):
force = BoolField(default=False)
class UpdateRequest(TaskUpdateRequest):
status_reason = StringField(default="")
status_message = StringField(default="")
force = BoolField(default=False)
class EnqueueRequest(UpdateRequest):
@@ -71,6 +100,8 @@ class EnqueueRequest(UpdateRequest):
class DeleteRequest(UpdateRequest):
move_to_trash = BoolField(default=True)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
class SetRequirementsRequest(TaskRequest):
@@ -81,10 +112,6 @@ class PublishRequest(UpdateRequest):
publish_model = BoolField(default=True)
class PublishResponse(UpdateResponse):
pass
class TaskData(models.Base):
"""
This is a partial description of task can be updated incrementally
@@ -104,6 +131,11 @@ class GetTypesRequest(models.Base):
projects = ListField(items_types=[str])
class TaskInputModel(models.Base):
name = StringField()
model = StringField()
class CloneRequest(TaskRequest):
new_task_name = StringField()
new_task_comment = StringField()
@@ -113,14 +145,15 @@ class CloneRequest(TaskRequest):
new_task_project = StringField()
new_task_hyperparams = DictField()
new_task_configuration = DictField()
new_task_container = DictField()
new_task_input_models = ListField([TaskInputModel])
execution_overrides = DictField()
validate_references = BoolField(default=False)
new_project_name = StringField()
class AddOrUpdateArtifactsRequest(TaskRequest):
class AddOrUpdateArtifactsRequest(TaskUpdateRequest):
artifacts = ListField([Artifact], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ArtifactId(models.Base):
@@ -130,13 +163,14 @@ class ArtifactId(models.Base):
)
class DeleteArtifactsRequest(TaskRequest):
class DeleteArtifactsRequest(TaskUpdateRequest):
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ResetRequest(UpdateRequest):
clear_all = BoolField(default=False)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
class MultiTaskRequest(models.Base):
@@ -161,7 +195,7 @@ class ReplaceHyperparams(object):
all = "all"
class EditHyperParamsRequest(TaskRequest):
class EditHyperParamsRequest(TaskUpdateRequest):
hyperparams: Sequence[HyperParamItem] = ListField(
[HyperParamItem], validators=Length(minimum_value=1)
)
@@ -169,7 +203,6 @@ class EditHyperParamsRequest(TaskRequest):
validators=Enum(*get_options(ReplaceHyperparams)),
default=ReplaceHyperparams.none,
)
force = BoolField(default=False)
class HyperParamKey(models.Base):
@@ -177,11 +210,10 @@ class HyperParamKey(models.Base):
name = StringField(nullable=True)
class DeleteHyperParamsRequest(TaskRequest):
class DeleteHyperParamsRequest(TaskUpdateRequest):
hyperparams: Sequence[HyperParamKey] = ListField(
[HyperParamKey], validators=Length(minimum_value=1)
)
force = BoolField(default=False)
class GetConfigurationsRequest(MultiTaskRequest):
@@ -189,7 +221,7 @@ class GetConfigurationsRequest(MultiTaskRequest):
class GetConfigurationNamesRequest(MultiTaskRequest):
pass
skip_empty = BoolField(default=True)
class Configuration(models.Base):
@@ -199,17 +231,15 @@ class Configuration(models.Base):
description = StringField()
class EditConfigurationRequest(TaskRequest):
class EditConfigurationRequest(TaskUpdateRequest):
configuration: Sequence[Configuration] = ListField(
[Configuration], validators=Length(minimum_value=1)
)
replace_configuration = BoolField(default=False)
force = BoolField(default=False)
class DeleteConfigurationRequest(TaskRequest):
class DeleteConfigurationRequest(TaskUpdateRequest):
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ArchiveRequest(MultiTaskRequest):
@@ -219,3 +249,54 @@ class ArchiveRequest(MultiTaskRequest):
class ArchiveResponse(models.Base):
archived = IntField()
class TaskBatchRequest(BatchRequest):
status_reason = StringField(default="")
status_message = StringField(default="")
class StopManyRequest(TaskBatchRequest):
force = BoolField(default=False)
class EnqueueManyRequest(TaskBatchRequest):
queue = StringField()
validate_tasks = BoolField(default=False)
class DeleteManyRequest(TaskBatchRequest):
move_to_trash = BoolField(default=True)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
force = BoolField(default=False)
class ResetManyRequest(TaskBatchRequest):
clear_all = BoolField(default=False)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
force = BoolField(default=False)
class PublishManyRequest(TaskBatchRequest):
publish_model = BoolField(default=True)
force = BoolField(default=False)
class AddUpdateModelRequest(TaskRequest):
name = StringField(required=True)
model = StringField(required=True)
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
iteration = IntField()
class ModelItemKey(models.Base):
name = StringField(required=True)
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
class DeleteModelsRequest(TaskRequest):
models: Sequence[ModelItemKey] = ListField(
[ModelItemKey], validators=Length(minimum_value=1)
)

View File

@@ -2,7 +2,11 @@ from datetime import datetime
from apiserver import database
from apiserver.apierrors import errors
from apiserver.apimodels.auth import GetTokenResponse, CreateUserRequest, Credentials as CredModel
from apiserver.apimodels.auth import (
GetTokenResponse,
CreateUserRequest,
Credentials as CredModel,
)
from apiserver.apimodels.users import CreateRequest as Users_CreateRequest
from apiserver.bll.user import UserBLL
from apiserver.config_repo import config
@@ -57,6 +61,7 @@ class AuthBLL:
api_version=str(ServiceRepo.max_endpoint_version()),
server_version=str(get_version()),
server_build=str(get_build_number()),
feature_set="basic",
)
return GetTokenResponse(token=token.decode("ascii"))
@@ -144,7 +149,7 @@ class AuthBLL:
@classmethod
def create_credentials(
cls, user_id: str, company_id: str, role: str = None
cls, user_id: str, company_id: str, role: str = None, label: str = None,
) -> CredModel:
with translate_errors_context():
@@ -153,9 +158,11 @@ class AuthBLL:
if not user:
raise errors.bad_request.InvalidUserId(**query)
cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key())
cred = CredModel(
access_key=get_client_id(), secret_key=get_secret_key(), label=label
)
user.credentials.append(
Credentials(key=cred.access_key, secret=cred.secret_key)
Credentials(key=cred.access_key, secret=cred.secret_key, label=label)
)
user.save()

View File

@@ -1,25 +1,24 @@
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter
from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping
import attr
import dpath
from boltons.iterutils import bucketize
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
EventSettings,
check_empty_data,
search_company_events,
EventType,
get_metric_variants_condition,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
@@ -28,19 +27,22 @@ from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
class VariantScrollState(Base):
name: str = StringField(required=True)
recycle_url_marker: str = StringField()
class VariantState(Base):
variant: str = StringField(required=True)
last_invalid_iteration: int = IntField()
class MetricScrollState(Base):
class MetricState(Base):
metric: str = StringField(required=True)
variants: Sequence[VariantState] = ListField([VariantState], required=True)
timestamp: int = IntField(default=0)
class TaskScrollState(Base):
task: str = StringField(required=True)
name: str = StringField(required=True)
metrics: Sequence[MetricState] = ListField([MetricState], required=True)
last_min_iter: Optional[int] = IntField()
last_max_iter: Optional[int] = IntField()
timestamp: int = IntField(default=0)
variants: Sequence[VariantScrollState] = ListField([VariantScrollState])
def reset(self):
"""Reset the scrolling state for the metric"""
@@ -49,7 +51,7 @@ class MetricScrollState(Base):
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
warning: str = StringField()
@@ -73,7 +75,7 @@ class DebugImagesIterator:
def get_task_events(
self,
company_id: str,
metrics: Sequence[Tuple[str, str]],
task_metrics: Mapping[str, dict],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
@@ -83,8 +85,7 @@ class DebugImagesIterator:
return DebugImagesResult()
def init_state(state_: DebugImageEventsScrollState):
unique_metrics = set(metrics)
state_.metrics = self._init_metric_states(company_id, list(unique_metrics))
state_.tasks = self._init_task_states(company_id, task_metrics)
def validate_state(state_: DebugImageEventsScrollState):
"""
@@ -92,16 +93,8 @@ class DebugImagesIterator:
as requested in the current call.
Refresh the state if requested
"""
state_metrics = set((m.task, m.name) for m in state_.metrics)
if state_metrics != set(metrics):
raise errors.bad_request.InvalidScrollId(
"Task metrics stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reinit_outdated_metric_states(company_id, state_)
for metric_state in state_.metrics:
metric_state.reset()
self._reinit_outdated_task_states(company_id, state_, task_metrics)
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
@@ -116,101 +109,125 @@ class DebugImagesIterator:
iter_count=iter_count,
navigate_earlier=navigate_earlier,
),
state.metrics,
state.tasks,
)
)
return res
def _reinit_outdated_metric_states(
self, company_id, state: DebugImageEventsScrollState
def _reinit_outdated_task_states(
self,
company_id,
state: DebugImageEventsScrollState,
task_metrics: Mapping[str, dict],
):
"""
Determines the metrics for which new debug image events were added
since their states were initialized and reinits these states
Determine the metrics for which new debug image events were added
since their states were initialized and re-init these states
"""
task_ids = set(metric.task for metric in state.metrics)
tasks = Task.objects(id__in=list(task_ids), company=company_id).only(
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
"id", "metric_stats"
)
def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
"""For metrics that reported debug image events get tuples of task_id/metric_name and last update times"""
def get_last_update_times_for_task_metrics(
task: Task,
) -> Mapping[str, datetime]:
"""For metrics that reported debug image events get mapping of the metric name to the last update times"""
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
if not metric_stats:
return []
return {}
return [
(
(task.id, stats.metric),
stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
)
requested_metrics = task_metrics[task.id]
return {
stats.metric: stats.event_stats_by_type[
self.EVENT_TYPE.value
].last_update
for stats in metric_stats.values()
if self.EVENT_TYPE.value in stats.event_stats_by_type
]
and (not requested_metrics or stats.metric in requested_metrics)
}
update_times = dict(
chain.from_iterable(
get_last_update_times_for_task_metrics(task) for task in tasks
update_times = {
task.id: get_last_update_times_for_task_metrics(task) for task in tasks
}
task_metric_states = {
task_state.task: {
metric_state.metric: metric_state for metric_state in task_state.metrics
}
for task_state in state.tasks
}
task_metrics_to_recalc = {}
for task, metrics_times in update_times.items():
old_metric_states = task_metric_states[task]
metrics_to_recalc = {
m: task_metrics[task].get(m)
for m, t in metrics_times.items()
if m not in old_metric_states or old_metric_states[m].timestamp < t
}
if metrics_to_recalc:
task_metrics_to_recalc[task] = metrics_to_recalc
updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
def merge_with_updated_task_states(
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
) -> TaskScrollState:
task = old_state.task
updated_state = first(uts for uts in updates if uts.task == task)
if not updated_state:
old_state.reset()
return old_state
updated_metrics = [m.metric for m in updated_state.metrics]
return TaskScrollState(
task=task,
metrics=[
*updated_state.metrics,
*(
old_metric
for old_metric in old_state.metrics
if old_metric.metric not in updated_metrics
),
],
)
)
outdated_metrics = [
metric
for metric in state.metrics
if (metric.task, metric.name) in update_times
and update_times[metric.task, metric.name] > metric.timestamp
]
state.metrics = [
*(metric for metric in state.metrics if metric not in outdated_metrics),
*(
self._init_metric_states(
company_id,
[(metric.task, metric.name) for metric in outdated_metrics],
)
),
state.tasks = [
merge_with_updated_task_states(task_state, updated_task_states)
for task_state in state.tasks
]
def _init_metric_states(
self, company_id: str, metrics: Sequence[Tuple[str, str]]
) -> Sequence[MetricScrollState]:
def _init_task_states(
self, company_id: str, task_metrics: Mapping[str, dict]
) -> Sequence[TaskScrollState]:
"""
Returned initialized metric scroll stated for the requested task metrics
"""
tasks = defaultdict(list)
for (task, metric) in metrics:
tasks[task].append(metric)
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
return list(
chain.from_iterable(
pool.map(
partial(
self._init_metric_states_for_task, company_id=company_id
),
tasks.items(),
)
)
task_metric_states = pool.map(
partial(self._init_metric_states_for_task, company_id=company_id),
task_metrics.items(),
)
return [
TaskScrollState(task=task, metrics=metric_states,)
for task, metric_states in zip(task_metrics, task_metric_states)
]
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, Sequence[str]], company_id: str
) -> Sequence[MetricScrollState]:
self, task_metrics: Tuple[str, dict], company_id: str
) -> Sequence[MetricState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any debug images
"""
task, metrics = task_metrics
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
if metrics:
must.append(get_metric_variants_condition(metrics))
query = {"bool": {"must": must}}
es_req: dict = {
"size": 0,
"query": {
"bool": {
"must": [
{"term": {"task": task}},
{"terms": {"metric": metrics}},
{"exists": {"field": "url"}},
]
}
},
"query": query,
"aggs": {
"metrics": {
"terms": {
@@ -254,20 +271,17 @@ class DebugImagesIterator:
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
)
if "aggregations" not in es_res:
return []
def init_variant_scroll_state(variant: dict):
def init_variant_state(variant: dict):
"""
Return new variant scroll state for the passed variant bucket
Return new variant state for the passed variant bucket
If the image urls get recycled then fill the last_invalid_iteration field
"""
state = VariantScrollState(name=variant["key"])
state = VariantState(variant=variant["key"])
top_iter_url = dpath.get(variant, "urls/buckets")[0]
iters = dpath.get(top_iter_url, "iters/hits/hits")
if len(iters) > 1:
@@ -275,102 +289,52 @@ class DebugImagesIterator:
return state
return [
MetricScrollState(
task=task,
name=metric["key"],
MetricState(
metric=metric["key"],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
variants=[
init_variant_scroll_state(variant)
init_variant_state(variant)
for variant in dpath.get(metric, "variants/buckets")
],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
)
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
]
def _get_task_metric_events(
self,
metric: MetricScrollState,
task_state: TaskScrollState,
company_id: str,
iter_count: int,
navigate_earlier: bool,
) -> Tuple:
"""
Return task metric events grouped by iterations
Update metric scroll state
Update task scroll state
"""
if metric.last_max_iter is None:
if not task_state.metrics:
return task_state.task, []
if task_state.last_max_iter is None:
# the first fetch is always from the latest iteration to the earlier ones
navigate_earlier = True
must_conditions = [
{"term": {"task": metric.task}},
{"term": {"metric": metric.name}},
{"term": {"task": task_state.task}},
{"terms": {"metric": [m.metric for m in task_state.metrics]}},
{"exists": {"field": "url"}},
]
must_not_conditions = []
range_condition = None
if navigate_earlier and metric.last_min_iter is not None:
range_condition = {"lt": metric.last_min_iter}
elif not navigate_earlier and metric.last_max_iter is not None:
range_condition = {"gt": metric.last_max_iter}
if navigate_earlier and task_state.last_min_iter is not None:
range_condition = {"lt": task_state.last_min_iter}
elif not navigate_earlier and task_state.last_max_iter is not None:
range_condition = {"gt": task_state.last_max_iter}
if range_condition:
must_conditions.append({"range": {"iter": range_condition}})
if navigate_earlier:
"""
When navigating to earlier iterations consider only
variants whose invalid iterations border is lower than
our starting iteration. For these variants make sure
that only events from the valid iterations are returned
"""
if not metric.last_min_iter:
variants = metric.variants
else:
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is None
or v.last_invalid_iteration < metric.last_min_iter
)
if not variants:
return metric.task, metric.name, []
must_conditions.append(
{"terms": {"variant": list(v.name for v in variants)}}
)
else:
"""
When navigating to later iterations all variants may be relevant.
For the variants whose invalid border is higher than our starting
iteration make sure that only events from valid iterations are returned
"""
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is not None
and v.last_invalid_iteration > metric.last_max_iter
)
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"lte": v.last_invalid_iteration}}},
]
}
}
for v in variants
if v.last_invalid_iteration is not None
]
if variants_conditions:
must_not_conditions.append({"bool": {"should": variants_conditions}})
es_req = {
"size": 0,
"query": {
"bool": {"must": must_conditions, "must_not": must_not_conditions}
},
"query": {"bool": {"must": must_conditions}},
"aggs": {
"iters": {
"terms": {
@@ -379,15 +343,26 @@ class DebugImagesIterator:
"order": {"_key": "desc" if navigate_earlier else "asc"},
},
"aggs": {
"variants": {
"metrics": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"field": "metric",
"size": EventSettings.max_metrics_count,
"order": {"_key": "asc"},
},
"aggs": {
"events": {
"top_hits": {"sort": {"url": {"order": "desc"}}}
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": {
"events": {
"top_hits": {
"sort": {"url": {"order": "desc"}}
}
}
},
}
},
}
@@ -397,80 +372,44 @@ class DebugImagesIterator:
}
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
)
if "aggregations" not in es_res:
return metric.task, metric.name, []
return task_state.task, []
def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
invalid_iterations = {
(m.metric, v.variant): v.last_invalid_iteration
for m in task_state.metrics
for v in m.variants
}
def is_valid_event(event: dict) -> bool:
key = event.get("metric"), event.get("variant")
if key not in invalid_iterations:
return False
max_invalid = invalid_iterations[key]
return max_invalid is None or event.get("iter") > max_invalid
def get_iteration_events(it_: dict) -> Sequence:
return [
ev["_source"]
for v in variant_buckets
for m in dpath.get(it_, "metrics/buckets")
for v in dpath.get(m, "variants/buckets")
for ev in dpath.get(v, "events/hits/hits")
if is_valid_event(ev["_source"])
]
iterations = [
{
"iter": it["key"],
"events": get_iteration_events(dpath.get(it, "variants/buckets")),
}
for it in dpath.get(es_res, "aggregations/iters/buckets")
]
iterations = []
for it in dpath.get(es_res, "aggregations/iters/buckets"):
events = get_iteration_events(it)
if events:
iterations.append({"iter": it["key"], "events": events})
if not navigate_earlier:
iterations.sort(key=itemgetter("iter"), reverse=True)
if iterations:
metric.last_max_iter = iterations[0]["iter"]
metric.last_min_iter = iterations[-1]["iter"]
task_state.last_max_iter = iterations[0]["iter"]
task_state.last_min_iter = iterations[-1]["iter"]
# Commented for now since the last invalid iteration is calculated in the beginning
# if navigate_earlier and any(
# variant.last_invalid_iteration is None for variant in variants
# ):
# """
# Variants validation flags due to recycling can
# be set only on navigation to earlier frames
# """
# iterations = self._update_variants_invalid_iterations(variants, iterations)
return metric.task, metric.name, iterations
@staticmethod
def _update_variants_invalid_iterations(
variants: Sequence[VariantScrollState], iterations: Sequence[dict]
) -> Sequence[dict]:
"""
This code is currently not in used since the invalid iterations
are calculated during MetricState initialization
For variants that do not have recycle url marker set it from the
first event
For variants that do not have last_invalid_iteration set check if the
recycle marker was reached on a certain iteration and set it to the
corresponding iteration
For variants that have a newly set last_invalid_iteration remove
events from the invalid iterations
Return the updated iterations list
"""
variants_lookup = bucketize(variants, attrgetter("name"))
for it in iterations:
iteration = it["iter"]
events_to_remove = []
for event in it["events"]:
variant = variants_lookup[event["variant"]][0]
if (
variant.last_invalid_iteration
and variant.last_invalid_iteration >= iteration
):
events_to_remove.append(event)
continue
event_url = event.get("url")
if not variant.recycle_url_marker:
variant.recycle_url_marker = event_url
elif variant.recycle_url_marker == event_url:
variant.last_invalid_iteration = iteration
events_to_remove.append(event)
if events_to_remove:
it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
return [it for it in iterations if it["events"]]
return task_state.task, iterations

View File

@@ -1,14 +1,15 @@
import base64
import hashlib
import re
import zlib
from collections import defaultdict
from contextlib import closing
from datetime import datetime
from operator import attrgetter
from typing import Sequence, Set, Tuple, Optional, Dict
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
import six
from elasticsearch import helpers
import elasticsearch
from elasticsearch.helpers import BulkIndexError
from mongoengine import Q
from nested_dict import nested_dict
@@ -20,14 +21,16 @@ from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
delete_company_events,
MetricVariants,
get_metric_variants_condition,
)
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.bll.event.debug_images_iterator import DebugImagesIterator
from apiserver.bll.event.event_metrics import EventMetrics
from apiserver.bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
@@ -36,12 +39,16 @@ from apiserver.redis_manager import redman
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
from apiserver.utilities.dicts import flatten_nested_items
# noinspection PyTypeChecker
from apiserver.utilities.json import loads
EVENT_TYPES = set(map(attrgetter("value"), EventType))
# noinspection PyTypeChecker
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
MAX_LONG = 2 ** 63 - 1
MIN_LONG = -(2 ** 63)
log = config.logger(__file__)
class PlotFields:
@@ -49,11 +56,16 @@ class PlotFields:
plot_len = "plot_len"
plot_str = "plot_str"
plot_data = "plot_data"
source_urls = "source_urls"
class EventBLL(object):
id_fields = ("task", "iter", "metric", "variant", "key")
empty_scroll = "FFFF"
img_source_regex = re.compile(
r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]",
flags=re.IGNORECASE,
)
def __init__(self, events_es=None, redis=None):
self.es = events_es or es_factory.connect("events")
@@ -64,7 +76,7 @@ class EventBLL(object):
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
self.log_events_iterator = LogEventsIterator(es=self.es)
self.events_iterator = EventsIterator(es=self.es)
@property
def metrics(self) -> EventMetrics:
@@ -86,7 +98,7 @@ class EventBLL(object):
def add_events(
self, company_id, events, worker, allow_locked_tasks=False
) -> Tuple[int, int, dict]:
actions = []
actions: List[dict] = []
task_ids = set()
task_iteration = defaultdict(lambda: 0)
task_last_scalar_events = nested_dict(
@@ -96,6 +108,7 @@ class EventBLL(object):
3, dict
) # task_id -> metric_hash -> event_type -> MetricEvent
errors_per_type = defaultdict(int)
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
valid_tasks = self._get_valid_tasks(
company_id,
task_ids={
@@ -145,6 +158,9 @@ class EventBLL(object):
iter = event.get("iter")
if iter is not None:
iter = int(iter)
if iter > MAX_LONG or iter < MIN_LONG:
errors_per_type[invalid_iteration_error] += 1
continue
event["iter"] = iter
# used to have "values" to indicate array. no need anymore
@@ -185,7 +201,6 @@ class EventBLL(object):
actions.append(es_action)
action: Dict[dict]
plot_actions = [
action["_source"]
for action in actions
@@ -201,47 +216,56 @@ class EventBLL(object):
)
added = 0
if actions:
chunk_size = 500
with translate_errors_context(), TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += 1
else:
errors_per_type["Error when indexing events batch"] += 1
with translate_errors_context():
if actions:
chunk_size = 500
with TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
elasticsearch.helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += 1
else:
errors_per_type["Error when indexing events batch"] += 1
remaining_tasks = set()
now = datetime.utcnow()
for task_id in task_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those who's events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
)
remaining_tasks = set()
now = datetime.utcnow()
for task_id in task_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those who's events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
)
if not updated:
remaining_tasks.add(task_id)
continue
if not updated:
remaining_tasks.add(task_id)
continue
if remaining_tasks:
TaskBLL.set_last_update(
remaining_tasks, company_id, last_update=now
)
if remaining_tasks:
TaskBLL.set_last_update(
remaining_tasks, company_id, last_update=now
)
# this is for backwards compatibility with streaming bulk throwing exception on those
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
if invalid_iterations_count:
raise BulkIndexError(
f"{invalid_iterations_count} document(s) failed to index.",
[invalid_iteration_error],
)
if not added:
raise errors.bad_request.EventsNotAdded(**errors_per_type)
@@ -269,6 +293,11 @@ class EventBLL(object):
event[PlotFields.plot_len] = plot_len
if validate:
event[PlotFields.valid_plot] = self._is_valid_json(plot_str)
urls = {match for match in self.img_source_regex.findall(plot_str)}
if urls:
event[PlotFields.source_urls] = list(urls)
if compression_threshold and plot_len >= compression_threshold:
event[PlotFields.plot_data] = base64.encodebytes(
zlib.compress(plot_str.encode(), level=1)
@@ -430,6 +459,9 @@ class EventBLL(object):
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
if event_type in (EventType.metrics_plot, EventType.all):
self.uncompress_plots(events)
return events, next_scroll_id, total_events
def get_last_iterations_per_event_metric_variant(
@@ -438,10 +470,16 @@ class EventBLL(object):
task_id: str,
num_last_iterations: int,
event_type: EventType,
metric_variants: MetricVariants = None,
):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return []
must = [{"term": {"task": task_id}}]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
es_req: dict = {
"size": 0,
"aggs": {
@@ -471,7 +509,7 @@ class EventBLL(object):
},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": query,
}
with translate_errors_context(), TimingContext(
@@ -499,9 +537,11 @@ class EventBLL(object):
sort=None,
size: int = 500,
scroll_id: str = None,
no_scroll: bool = False,
metric_variants: MetricVariants = None,
):
if scroll_id == self.empty_scroll:
return [], scroll_id, 0
return TaskEventsResult()
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
@@ -527,6 +567,8 @@ class EventBLL(object):
if last_iterations_per_plot is None:
must.append({"terms": {"task": tasks}})
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
else:
should = []
for i, task_id in enumerate(tasks):
@@ -535,6 +577,7 @@ class EventBLL(object):
task_id=task_id,
num_last_iterations=last_iterations_per_plot,
event_type=event_type,
metric_variants=metric_variants,
)
if not last_iters:
continue
@@ -572,7 +615,7 @@ class EventBLL(object):
event_type=event_type,
body=es_req,
ignore=404,
scroll="1h",
**({} if no_scroll else {"scroll": "1h"}),
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@@ -595,6 +638,41 @@ class EventBLL(object):
return events, total_events, next_scroll_id
def get_plot_image_urls(
self, company_id: str, task_id: str, scroll_id: Optional[str]
) -> Tuple[Sequence[dict], Optional[str]]:
if scroll_id == self.empty_scroll:
return [], None
if scroll_id:
es_res = self.es.scroll(scroll_id=scroll_id, scroll="10m")
else:
if check_empty_data(self.es, company_id, EventType.metrics_plot):
return [], None
es_req = {
"size": 1000,
"_source": [PlotFields.source_urls],
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"exists": {"field": PlotFields.source_urls}},
]
}
},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=EventType.metrics_plot,
body=es_req,
scroll="10m",
)
events, _, next_scroll_id = self._get_events_from_es_res(es_res)
return events, next_scroll_id
def get_task_events(
self,
company_id: str,
@@ -606,19 +684,20 @@ class EventBLL(object):
sort=None,
size=500,
scroll_id=None,
):
no_scroll=False,
) -> TaskEventsResult:
if scroll_id == self.empty_scroll:
return [], scroll_id, 0
return TaskEventsResult()
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return TaskEventsResult()
task_ids = [task_id] if isinstance(task_id, str) else task_id
must = []
if metric:
must.append({"term": {"metric": metric}})
@@ -628,26 +707,24 @@ class EventBLL(object):
if last_iter_count is None:
must.append({"terms": {"task": task_ids}})
else:
should = []
for i, task_id in enumerate(task_ids):
last_iters = self.get_last_iters(
company_id=company_id,
event_type=event_type,
task_id=task_id,
iters=last_iter_count,
)
if not last_iters:
continue
should.append(
{
"bool": {
"must": [
{"term": {"task": task_id}},
{"terms": {"iter": last_iters}},
]
}
tasks_iters = self.get_last_iters(
company_id=company_id,
event_type=event_type,
task_id=task_ids,
iters=last_iter_count,
)
should = [
{
"bool": {
"must": [
{"term": {"task": task}},
{"terms": {"iter": last_iters}},
]
}
)
}
for task, last_iters in tasks_iters.items()
if last_iters
]
if not should:
return TaskEventsResult()
must.append({"bool": {"should": should}})
@@ -668,10 +745,13 @@ class EventBLL(object):
event_type=event_type,
body=es_req,
ignore=404,
scroll="1h",
**({} if no_scroll else {"scroll": "1h"}),
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
if event_type in (EventType.metrics_plot, EventType.all):
self.uncompress_plots(events)
return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events
)
@@ -682,6 +762,7 @@ class EventBLL(object):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {}
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
es_req = {
"size": 0,
"aggs": {
@@ -702,7 +783,7 @@ class EventBLL(object):
},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": query,
}
with translate_errors_context(), TimingContext(
@@ -721,21 +802,24 @@ class EventBLL(object):
return metrics
def get_task_latest_scalar_values(self, company_id: str, task_id: str):
def get_task_latest_scalar_values(
self, company_id, task_id
) -> Tuple[Sequence[dict], int]:
event_type = EventType.metrics_scalar
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {}
return [], 0
query = {
"bool": {
"must": [
{"query_string": {"query": "value:>0"}},
{"term": {"task": task_id}},
]
}
}
es_req = {
"size": 0,
"query": {
"bool": {
"must": [
{"query_string": {"query": "value:>0"}},
{"term": {"task": task_id}},
]
}
},
"query": query,
"aggs": {
"metrics": {
"terms": {
@@ -839,34 +923,47 @@ class EventBLL(object):
return iterations, vectors
def get_last_iters(
self, company_id: str, event_type: EventType, task_id: str, iters: int
):
self,
company_id: str,
event_type: EventType,
task_id: Union[str, Sequence[str]],
iters: int,
) -> Mapping[str, Sequence]:
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return []
return {}
task_ids = [task_id] if isinstance(task_id, str) else task_id
es_req: dict = {
"size": 0,
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iters,
"order": {"_key": "desc"},
}
"tasks": {
"terms": {"field": "task"},
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iters,
"order": {"_key": "desc"},
}
}
},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": {"bool": {"must": [{"terms": {"task": task_ids}}]}},
}
with translate_errors_context(), TimingContext("es", "task_last_iter"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
if "aggregations" not in es_res:
return []
return {}
return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]]
return {
tb["key"]: [ib["key"] for ib in tb["iters"]["buckets"]]
for tb in es_res["aggregations"]["tasks"]["buckets"]
}
def delete_task_events(self, company_id, task_id, allow_locked=False):
with translate_errors_context():
@@ -892,3 +989,35 @@ class EventBLL(object):
)
return es_res.get("deleted", 0)
def delete_multi_task_events(self, company_id: str, task_ids: Sequence[str]):
"""
Delete mutliple task events. No check is done for tasks write access
so it should be checked by the calling code
"""
es_req = {"query": {"terms": {"task": task_ids}}}
with translate_errors_context(), TimingContext(
"es", "delete_multi_tasks_events"
):
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.all,
body=es_req,
refresh=True,
)
return es_res.get("deleted", 0)
def clear_scroll(self, scroll_id: str):
if scroll_id == self.empty_scroll:
return
# noinspection PyBroadException
try:
self.es.clear_scroll(scroll_id=scroll_id)
except elasticsearch.exceptions.NotFoundError:
pass
except elasticsearch.exceptions.RequestError:
pass
except Exception as ex:
log.exception("Failed clearing scroll %s", scroll_id)

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Union, Sequence
from typing import Union, Sequence, Mapping
from boltons.typeutils import classproperty
from elasticsearch import Elasticsearch
@@ -16,6 +16,9 @@ class EventType(Enum):
all = "*"
MetricVariants = Mapping[str, Sequence[str]]
class EventSettings:
@classproperty
def max_workers(self):
@@ -63,4 +66,31 @@ def delete_company_events(
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
) -> dict:
es_index = get_index_name(company_id, event_type.value)
return es.delete_by_query(index=es_index, body=body, **kwargs)
return es.delete_by_query(
index=es_index, body=body, conflicts="proceed", **kwargs
)
def count_company_events(
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
) -> dict:
es_index = get_index_name(company_id, event_type.value)
return es.count(index=es_index, body=body, **kwargs)
def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
conditions = [
{
"bool": {
"must": [
{"term": {"metric": metric}},
{"terms": {"variant": variants}},
]
}
}
if variants
else {"term": {"metric": metric}}
for metric, variants in metric_variants.items()
]
return {"bool": {"should": conditions}}

View File

@@ -15,6 +15,8 @@ from apiserver.bll.event.event_common import (
EventSettings,
search_company_events,
check_empty_data,
MetricVariants,
get_metric_variants_condition,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.config_repo import config
@@ -34,7 +36,12 @@ class EventMetrics:
self.es = es
def get_scalar_metrics_average_per_iter(
self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
self,
company_id: str,
task_id: str,
samples: int,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
) -> dict:
"""
Get scalar metric histogram per metric and variant
@@ -46,7 +53,12 @@ class EventMetrics:
return {}
return self._get_scalar_average_per_iter_core(
task_id, company_id, event_type, samples, ScalarKey.resolve(key)
task_id=task_id,
company_id=company_id,
event_type=event_type,
samples=samples,
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
)
def _get_scalar_average_per_iter_core(
@@ -57,6 +69,7 @@ class EventMetrics:
samples: int,
key: ScalarKey,
run_parallel: bool = True,
metric_variants: MetricVariants = None,
) -> dict:
intervals = self._get_task_metric_intervals(
company_id=company_id,
@@ -64,6 +77,7 @@ class EventMetrics:
task_id=task_id,
samples=samples,
field=key.field,
metric_variants=metric_variants,
)
if not intervals:
return {}
@@ -197,6 +211,7 @@ class EventMetrics:
task_id: str,
samples: int,
field: str = "iter",
metric_variants: MetricVariants = None,
) -> Sequence[MetricInterval]:
"""
Calculate interval per task metric variant so that the resulting
@@ -204,9 +219,14 @@ class EventMetrics:
Return the list og metric variant intervals as the following tuple:
(metric, variant, interval, samples)
"""
must = [{"term": {"task": task_id}}]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
es_req = {
"size": 0,
"query": {"term": {"task": task_id}},
"query": query,
"aggs": {
"metrics": {
"terms": {

View File

@@ -0,0 +1,208 @@
from typing import Optional, Tuple, Sequence, Any
import attr
import jsonmodels.models
import jwt
from elasticsearch import Elasticsearch
from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
MetricVariants,
get_metric_variants_condition,
count_company_events,
)
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
next_scroll_id: str = None
events: list = attr.Factory(list)
class EventsIterator:
def __init__(self, es: Elasticsearch):
self.es = es
def get_task_events(
self,
event_type: EventType,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
from_key_value: Optional[Any] = None,
metric_variants: MetricVariants = None,
key: ScalarKeyEnum = ScalarKeyEnum.timestamp,
**kwargs,
) -> TaskEventsResult:
if check_empty_data(self.es, company_id, event_type):
return TaskEventsResult()
from_key_value = kwargs.pop("from_timestamp", from_key_value)
res = TaskEventsResult()
res.events, res.total_events = self._get_events(
event_type=event_type,
company_id=company_id,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
from_key_value=from_key_value,
metric_variants=metric_variants,
key=ScalarKey.resolve(key),
)
return res
def count_task_events(
self,
event_type: EventType,
company_id: str,
task_id: str,
metric_variants: MetricVariants = None,
) -> int:
if check_empty_data(self.es, company_id, event_type):
return 0
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
es_req = {
"query": query,
}
with translate_errors_context(), TimingContext("es", "count_task_events"):
es_result = count_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
return es_result["count"]
def _get_events(
self,
event_type: EventType,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool,
key: ScalarKey,
from_key_value: Optional[Any],
metric_variants: MetricVariants = None,
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous key-field value (timestamp or iter) either in the
direction of earlier events (navigate_earlier=True) or in the direction of later events.
If from_key_field is not set then start either from latest or earliest.
For the last key-field value all the events are brought (even if the resulting size exceeds batch_size)
so that events with this value will not be lost between the calls.
"""
query, must = self._get_initial_query_and_must(task_id, metric_variants)
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": query,
"sort": {key.field: "desc" if navigate_earlier else "asc"},
}
if from_key_value:
es_req["search_after"] = [from_key_value]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
es_req = {
"size": 10000,
"query": {
"bool": {
"must": must + [{"term": {key.field: events[-1][key.field]}}]
}
},
}
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:
# if only one element is returned for the last timestamp
# then it is already present in the events
return events, hits_total
already_present_ids = set(hit["_id"] for hit in hits)
last_second_events = [
hit["_source"]
for hit in last_second_hits
if hit["_id"] not in already_present_ids
]
# return the list merged from original query results +
# leftovers from the last timestamp
return (
[*events, *last_second_events],
hits_total,
)
@staticmethod
def _get_initial_query_and_must(
task_id: str, metric_variants: MetricVariants = None
) -> Tuple[dict, list]:
if not metric_variants:
must = [{"term": {"task": task_id}}]
query = {"term": {"task": task_id}}
else:
must = [
{"term": {"task": task_id}},
get_metric_variants_condition(metric_variants),
]
query = {"bool": {"must": must}}
return query, must
class Scroll(jsonmodels.models.Base):
def get_scroll_id(self) -> str:
return jwt.encode(
self.to_struct(),
key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890"
),
).decode()
@classmethod
def from_scroll_id(cls, scroll_id: str):
try:
return cls(
**jwt.decode(
scroll_id,
key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890"
),
)
)
except jwt.PyJWTError:
raise ValueError("Invalid Scroll ID")

View File

@@ -1,127 +0,0 @@
from typing import Optional, Tuple, Sequence
import attr
from elasticsearch import Elasticsearch
from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
)
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
next_scroll_id: str = None
events: list = attr.Factory(list)
class LogEventsIterator:
EVENT_TYPE = EventType.task_log
def __init__(self, es: Elasticsearch):
self.es = es
def get_task_events(
self,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
from_timestamp: Optional[int] = None,
) -> TaskEventsResult:
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
return TaskEventsResult()
res = TaskEventsResult()
res.events, res.total_events = self._get_events(
company_id=company_id,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
from_timestamp=from_timestamp,
)
return res
def _get_events(
self,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool,
from_timestamp: Optional[int],
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous timestamp either in the
direction of earlier events (navigate_earlier=True) or in the direction of later events.
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
For the last timestamp all the events are brought (even if the resulting size
exceeds batch_size) so that this timestamp events will not be lost between the calls.
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
"""
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": {"term": {"task": task_id}},
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
}
if from_timestamp:
es_req["search_after"] = [from_timestamp]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"timestamp": events[-1]["timestamp"]}},
]
}
},
}
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:
# if only one element is returned for the last timestamp
# then it is already present in the events
return events, hits_total
already_present_ids = set(hit["_id"] for hit in hits)
last_second_events = [
hit["_source"]
for hit in last_second_hits
if hit["_id"] not in already_present_ids
]
# return the list merged from original query results +
# leftovers from the last timestamp
return (
[*events, *last_second_events],
hits_total,
)

View File

@@ -4,8 +4,10 @@ Module for polymorphism over different types of X axes in scalar aggregations
from abc import ABC, abstractmethod
from enum import auto
from typing import Any
from apiserver.utilities import extract_properties_to_lists
from apiserver.utilities.stringenum import StringEnum
from apiserver.bll.util import extract_properties_to_lists
from apiserver.config_repo import config
log = config.logger(__file__)
@@ -96,6 +98,10 @@ class ScalarKey(ABC):
"""
return int(iter_data[self.bucket_key_key]), iter_data["avg_val"]["value"]
def cast_value(self, value: Any) -> Any:
"""Cast value to appropriate type"""
return value
class TimestampKey(ScalarKey):
"""
@@ -117,6 +123,9 @@ class TimestampKey(ScalarKey):
}
}
def cast_value(self, value: Any) -> int:
return int(value)
class IterKey(ScalarKey):
"""
@@ -134,6 +143,9 @@ class IterKey(ScalarKey):
}
}
def cast_value(self, value: Any) -> int:
return int(value)
class ISOTimeKey(ScalarKey):
"""

View File

@@ -1,18 +1,130 @@
from typing import Optional, Sequence
from mongoengine import Q
from datetime import datetime
from typing import Callable, Tuple
from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse
from apiserver.bll.task.utils import deleted_prefix
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.utils import get_company_or_none_constraint
from apiserver.database.model.task.task import Task, TaskStatus
from .metadata import Metadata
class ModelBLL:
def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
"""
Return the list of unique frameworks used by company and public models
If project ids passed then only models from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
query &= Q(project__in=project_ids)
return Model.objects(query).distinct(field="framework")
@classmethod
def get_company_model_by_id(
cls, company_id: str, model_id: str, only_fields=None
) -> Model:
query = dict(company=company_id, id=model_id)
qs = Model.objects(**query)
if only_fields:
qs = qs.only(*only_fields)
model = qs.first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
return model
@classmethod
def publish_model(
cls,
model_id: str,
company_id: str,
force_publish_task: bool = False,
publish_task_func: Callable[[str, str, bool], dict] = None,
) -> Tuple[int, ModelTaskPublishResponse]:
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
if model.ready:
raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
published_task = None
if model.task and publish_task_func:
task = (
Task.objects(id=model.task, company=company_id)
.only("id", "status")
.first()
)
if task and task.status != TaskStatus.published:
task_publish_res = publish_task_func(
model.task, company_id, force_publish_task
)
published_task = ModelTaskPublishResponse(
id=model.task, data=task_publish_res
)
updated = model.update(upsert=False, ready=True, last_update=datetime.utcnow())
return updated, published_task
@classmethod
def delete_model(
cls, model_id: str, company_id: str, force: bool
) -> Tuple[int, Model]:
model = cls.get_company_model_by_id(
company_id=company_id,
model_id=model_id,
only_fields=("id", "task", "project", "uri"),
)
deleted_model_id = f"{deleted_prefix}{model_id}"
using_tasks = Task.objects(models__input__model=model_id).only("id")
if using_tasks:
if not force:
raise errors.bad_request.ModelInUse(
"as execution model, use force=True to delete",
num_tasks=len(using_tasks),
)
# update deleted model id in using tasks
Task._get_collection().update_many(
filter={"_id": {"$in": [t.id for t in using_tasks]}},
update={"$set": {"models.input.$[elem].model": deleted_model_id}},
array_filters=[{"elem.model": model_id}],
upsert=False,
)
if model.task:
task = Task.objects(id=model.task).first()
if task and task.status == TaskStatus.published:
if not force:
raise errors.bad_request.ModelCreatingTaskExists(
"and published, use force=True to delete", task=model.task
)
if task.models.output and model_id in task.models.output:
now = datetime.utcnow()
Task._get_collection().update_one(
filter={"_id": model.task, "models.output.model": model_id},
update={
"$set": {
"models.output.$[elem].model": deleted_model_id,
"output.error": f"model deleted on {now.isoformat()}",
},
"last_change": now,
},
array_filters=[{"elem.model": model_id}],
upsert=False,
)
del_count = Model.objects(id=model_id, company=company_id).delete()
return del_count, model
@classmethod
def archive_model(cls, model_id: str, company_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
archived = Model.objects(company=company_id, id=model_id).update(
add_to_set__system_tags=EntityVisibility.archived.value,
last_update=datetime.utcnow(),
)
return archived
@classmethod
def unarchive_model(cls, model_id: str, company_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
unarchived = Model.objects(company=company_id, id=model_id).update(
pull__system_tags=EntityVisibility.archived.value,
last_update=datetime.utcnow(),
)
return unarchived

View File

@@ -0,0 +1,111 @@
from typing import Sequence, Union, Mapping
from mongoengine import Document
from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem
from apiserver.database.model.base import GetMixin
from apiserver.service_repo import APICall
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
)
from apiserver.config_repo import config
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
class Metadata:
@staticmethod
def metadata_from_api(
api_data: Union[Mapping[str, MetadataItem], Sequence[MetadataItem]]
) -> dict:
if not api_data:
return {}
if isinstance(api_data, dict):
return {
ParameterKeyEscaper.escape(k): v.to_struct()
for k, v in api_data.items()
}
return {
ParameterKeyEscaper.escape(item.key): item.to_struct() for item in api_data
}
@classmethod
def edit_metadata(
cls,
obj: Document,
items: Sequence[MetadataItem],
replace_metadata: bool,
**more_updates,
) -> int:
with TimingContext("mongo", "edit_metadata"):
update_cmds = dict()
metadata = cls.metadata_from_api(items)
if replace_metadata:
update_cmds["set__metadata"] = metadata
else:
for key, value in metadata.items():
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
return obj.update(**update_cmds, **more_updates)
@classmethod
def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
with TimingContext("mongo", "delete_metadata"):
return obj.update(
**{
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
for key in set(keys)
},
**more_updates,
)
@staticmethod
def _process_path(path: str):
"""
Frontend does a partial escaping on the path so the all '.' in key names are escaped
Need to unescape and apply a full mongo escaping
"""
parts = path.split(".")
if len(parts) < 2 or len(parts) > 3:
raise errors.bad_request.ValidationError("invalid field", path=path)
return ".".join(
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
)
@classmethod
def escape_paths(cls, paths: Sequence[str]) -> Sequence[str]:
for prefix in (
"metadata.",
"-metadata.",
):
paths = [
cls._process_path(path) if path.startswith(prefix) else path
for path in paths
]
return paths
@classmethod
def escape_query_parameters(cls, call: APICall) -> dict:
if not call.data:
return call.data
keys = list(call.data)
call_data = {
safe_key: call.data[key]
for key, safe_key in zip(keys, Metadata.escape_paths(keys))
}
projection = GetMixin.get_projection(call_data)
if projection:
GetMixin.set_projection(call_data, Metadata.escape_paths(projection))
ordering = GetMixin.get_ordering(call_data)
if ordering:
GetMixin.set_ordering(call_data, Metadata.escape_paths(ordering))
return call_data

View File

@@ -1,12 +1,8 @@
from collections import defaultdict
from enum import Enum
from operator import itemgetter
from typing import Sequence, Dict, Optional
from mongoengine import Q
from typing import Sequence, Dict
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
@@ -65,34 +61,3 @@ class OrgBLL:
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
return self._task_tags if entity == Tags.Task else self._model_tags
@classmethod
def get_parent_tasks(
cls,
company_id: str,
projects: Sequence[str],
state: Optional[EntityVisibility] = None,
) -> Sequence[dict]:
"""
Get list of unique parent tasks sorted by task name for the passed company projects
If projects is None or empty then get parents for all the company tasks
"""
query = Q(company=company_id)
if projects:
query &= Q(project__in=projects)
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
elif state == EntityVisibility.active:
query &= Q(system_tags__nin=[EntityVisibility.archived.value])
parent_ids = set(Task.objects(query).distinct("parent"))
if not parent_ids:
return []
parents = Task.get_many_with_join(
company_id,
query=Q(id__in=parent_ids),
allow_public=True,
override_projection=("id", "name", "project.name"),
)
return sorted(parents, key=itemgetter("name"))

View File

@@ -5,6 +5,8 @@ from mongoengine import Q
from redis import Redis
from apiserver.config_repo import config
from apiserver.bll.project import project_ids_with_children
from apiserver.database.model import EntityVisibility
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
@@ -40,7 +42,9 @@ class _TagsCache:
if vals:
query &= GetMixin.get_list_field_query(name, vals)
if project:
query &= Q(project=project)
query &= Q(project__in=project_ids_with_children([project]))
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
return self.db_cls.objects(query).distinct(field)

View File

@@ -1 +1,3 @@
from .project_bll import ProjectBLL
from .project_queries import ProjectQueries
from .sub_projects import _ids_with_children as project_ids_with_children

View File

@@ -1,40 +1,200 @@
from datetime import datetime
from typing import Sequence, Optional, Type
import itertools
from collections import defaultdict
from datetime import datetime, timedelta
from functools import reduce
from itertools import groupby
from operator import itemgetter
from typing import (
Sequence,
Optional,
Type,
Tuple,
Dict,
Set,
TypeVar,
Callable,
Mapping,
Any,
)
from mongoengine import Q, Document
from apiserver import database
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility, AttributedDocument
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
from apiserver.database.utils import get_options, get_company_or_none_constraint
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
from .sub_projects import (
_reposition_project_with_children,
_ensure_project,
_validate_project_name,
_update_subproject_names,
_save_under_parent,
_get_sub_projects,
_ids_with_children,
_ids_with_parents,
_get_project_depth,
)
log = config.logger(__file__)
max_depth = config.get("services.projects.sub_projects.max_depth", 10)
class ProjectBLL:
@classmethod
def get_active_users(
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
) -> set:
def merge_project(
cls, company, source_id: str, destination_id: str
) -> Tuple[int, int, Set[str]]:
"""
Get the set of user ids that created tasks/models in the given projects
If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned
Move all the tasks and sub projects from the source project to the destination
Remove the source project
Return the amounts of moved entities and subprojects + set of all the affected project ids
"""
with TimingContext("mongo", "active_users_in_projects"):
res = set()
query = Q(company=company)
if project_ids:
query &= Q(project__in=project_ids)
if user_ids:
query &= Q(user__in=user_ids)
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="user"))
with TimingContext("mongo", "move_project"):
if source_id == destination_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
source=source_id
)
source = Project.get(company, source_id)
destination = Project.get(company, destination_id)
if source_id in destination.path:
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
source=source_id, destination=destination_id
)
return res
children = _get_sub_projects(
[source.id], _only=("id", "name", "parent", "path")
)[source.id]
cls.validate_projects_depth(
projects=children,
old_parent_depth=len(source.path) + 1,
new_parent_depth=len(destination.path) + 1,
)
moved_entities = 0
for entity_type in (Task, Model):
moved_entities += entity_type.objects(
company=company,
project=source_id,
system_tags__nin=[EntityVisibility.archived.value],
).update(upsert=False, project=destination_id)
moved_sub_projects = 0
for child in Project.objects(company=company, parent=source_id):
_reposition_project_with_children(
project=child,
children=[c for c in children if c.parent == child.id],
parent=destination,
)
moved_sub_projects += 1
affected = {source.id, *(source.path or [])}
source.delete()
if destination:
destination.update(last_update=datetime.utcnow())
affected.update({destination.id, *(destination.path or [])})
return moved_entities, moved_sub_projects, affected
@staticmethod
def validate_projects_depth(
projects: Sequence[Project], old_parent_depth: int, new_parent_depth: int
):
for current in projects:
current_depth = len(current.path) + 1
if current_depth - old_parent_depth + new_parent_depth > max_depth:
raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
@classmethod
def move_project(
cls, company: str, user: str, project_id: str, new_location: str
) -> Tuple[int, Set[str]]:
"""
Move project with its sub projects from its current location to the target one.
If the target location does not exist then it will be created. If it exists then
it should be writable. The source location should be writable too.
Return the number of moved projects + set of all the affected project ids
"""
with TimingContext("mongo", "move_project"):
project = Project.get(company, project_id)
old_parent_id = project.parent
old_parent = (
Project.get_for_writing(company=project.company, id=old_parent_id)
if old_parent_id
else None
)
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
project.id
]
cls.validate_projects_depth(
projects=[project, *children],
old_parent_depth=len(project.path),
new_parent_depth=_get_project_depth(new_location),
)
new_parent = _ensure_project(company=company, user=user, name=new_location)
new_parent_id = new_parent.id if new_parent else None
if old_parent_id == new_parent_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
location=new_parent.name if new_parent else ""
)
if (
new_parent
and project_id == new_parent.id
or project_id in new_parent.path
):
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
project=project_id, parent=new_parent.id
)
moved = _reposition_project_with_children(
project, children=children, parent=new_parent
)
now = datetime.utcnow()
affected = set()
for p in filter(None, (old_parent, new_parent)):
p.update(last_update=now)
affected.update({p.id, *(p.path or [])})
return moved, affected
@classmethod
def update(cls, company: str, project_id: str, **fields):
with TimingContext("mongo", "projects_update"):
project = Project.get_for_writing(company=company, id=project_id)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
new_name = fields.pop("name", None)
if new_name:
new_name, new_location = _validate_project_name(new_name)
old_name, old_location = _validate_project_name(project.name)
if new_location != old_location:
raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
fields["name"] = new_name
fields["last_update"] = datetime.utcnow()
updated = project.update(upsert=False, **fields)
if new_name:
old_name = project.name
project.name = new_name
children = _get_sub_projects(
[project.id], _only=("id", "name", "path")
)[project.id]
_update_subproject_names(
project=project, children=children, old_name=old_name
)
return updated
@classmethod
def create(
@@ -42,15 +202,20 @@ class ProjectBLL:
user: str,
company: str,
name: str,
description: str,
description: str = "",
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
default_output_destination: str = None,
parent_creation_params: dict = None,
) -> str:
"""
Create a new project.
Returns project ID
"""
if _get_project_depth(name) > max_depth:
raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
name, location = _validate_project_name(name)
now = datetime.utcnow()
project = Project(
id=database.utils.id(),
@@ -64,7 +229,16 @@ class ProjectBLL:
created=now,
last_update=now,
)
project.save()
parent = _ensure_project(
company=company,
user=user,
name=location,
creation_params=parent_creation_params,
)
_save_under_parent(project=project, parent=parent)
if parent:
parent.update(last_update=now)
return project.id
@classmethod
@@ -78,13 +252,14 @@ class ProjectBLL:
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
default_output_destination: str = None,
parent_creation_params: dict = None,
) -> str:
"""
Find a project named `project_name` or create a new one.
Returns project ID
"""
if not project_id and not project_name:
raise ValueError("project id or name required")
raise errors.bad_request.ValidationError("project id or name required")
if project_id:
project = Project.objects(company=company, id=project_id).only("id").first()
@@ -92,6 +267,7 @@ class ProjectBLL:
raise errors.bad_request.InvalidProjectId(id=project_id)
return project_id
project_name, _ = _validate_project_name(project_name)
project = Project.objects(company=company, name=project_name).only("id").first()
if project:
return project.id
@@ -104,6 +280,7 @@ class ProjectBLL:
tags=tags,
system_tags=system_tags,
default_output_destination=default_output_destination,
parent_creation_params=parent_creation_params,
)
@classmethod
@@ -125,13 +302,566 @@ class ProjectBLL:
company=company,
project_id=project,
project_name=project_name,
description="Auto-generated during move",
description="",
)
extra = (
{"set__last_change": datetime.utcnow()}
if hasattr(entity_cls, "last_change")
else {}
)
entity_cls.objects(company=company, id__in=ids).update(set__project=project, **extra)
entity_cls.objects(company=company, id__in=ids).update(
set__project=project, **extra
)
return project
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
visibility_states = [EntityVisibility.archived, EntityVisibility.active]
@classmethod
def make_projects_get_all_pipelines(
cls,
company_id: str,
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
filter_: Mapping[str, Any] = None,
) -> Tuple[Sequence, Sequence]:
archived = EntityVisibility.archived.value
def ensure_valid_fields():
"""
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
"""
return {
"$addFields": {
"system_tags": {
"$cond": {
"if": {"$ne": [{"$type": "$system_tags"}, "array"]},
"then": [],
"else": "$system_tags",
}
},
"status": {"$ifNull": ["$status", "unknown"]},
}
}
status_count_pipeline = [
# count tasks per project per status
{
"$match": cls.get_match_conditions(
company=company_id, project_ids=project_ids, filter_=filter_
)
},
ensure_valid_fields(),
{
"$group": {
"_id": {
"project": "$project",
"status": "$status",
archived: cls.archived_tasks_cond,
},
"count": {"$sum": 1},
}
},
# for each project, create a list of (status, count, archived)
{
"$group": {
"_id": "$_id.project",
"counts": {
"$push": {
"status": "$_id.status",
"count": "$count",
archived: "$_id.%s" % archived,
}
},
}
},
]
def completed_after_subquery(additional_cond, time_thresh: datetime):
return {
# the sum of
"$sum": {
# for each task
"$cond": {
# if completed after the time_thresh
"if": {
"$and": [
"$completed",
{"$gt": ["$completed", time_thresh]},
additional_cond,
]
},
"then": 1,
"else": 0,
}
}
}
def max_started_subquery(condition):
return {
"$max": {
"$cond": {
"if": condition,
"then": "$started",
"else": datetime.min,
}
}
}
def runtime_subquery(additional_cond):
return {
# the sum of
"$sum": {
# for each task
"$cond": {
# if completed and started and completed > started
"if": {
"$and": [
"$started",
"$completed",
{"$gt": ["$completed", "$started"]},
additional_cond,
]
},
# then: floor((completed - started) / 1000)
"then": {
"$floor": {
"$divide": [
{"$subtract": ["$completed", "$started"]},
1000.0,
]
}
},
"else": 0,
}
}
}
group_step = {"_id": "$project"}
time_thresh = datetime.utcnow() - timedelta(hours=24)
for state in cls.visibility_states:
if specific_state and state != specific_state:
continue
cond = (
cls.archived_tasks_cond
if state == EntityVisibility.archived
else {"$not": cls.archived_tasks_cond}
)
group_step[state.value] = runtime_subquery(cond)
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
cond, time_thresh=time_thresh
)
group_step[f"{state.value}_max_task_started"] = max_started_subquery(cond)
def get_state_filter() -> dict:
if not specific_state:
return {}
if specific_state == EntityVisibility.archived:
return {"system_tags": {"$eq": EntityVisibility.archived.value}}
return {"system_tags": {"$ne": EntityVisibility.archived.value}}
runtime_pipeline = [
# only count run time for these types of tasks
{
"$match": {
**cls.get_match_conditions(
company=company_id, project_ids=project_ids, filter_=filter_
),
**get_state_filter(),
}
},
ensure_valid_fields(),
{
# for each project
"$group": group_step
},
]
return status_count_pipeline, runtime_pipeline
T = TypeVar("T")
@staticmethod
def aggregate_project_data(
func: Callable[[T, T], T],
project_ids: Sequence[str],
child_projects: Mapping[str, Sequence[Project]],
data: Mapping[str, T],
) -> Dict[str, T]:
"""
Given a list of project ids and data collected over these projects and their subprojects
For each project aggregates the data from all of its subprojects
"""
aggregated = {}
if not data:
return aggregated
for pid in project_ids:
relevant_projects = {p.id for p in child_projects.get(pid, [])} | {pid}
relevant_data = [data for p, data in data.items() if p in relevant_projects]
if not relevant_data:
continue
aggregated[pid] = reduce(func, relevant_data)
return aggregated
@classmethod
def get_project_stats(
cls,
company: str,
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
include_children: bool = True,
return_hidden_children: bool = False,
filter_: Mapping[str, Any] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
if not project_ids:
return {}, {}
child_projects = (
_get_sub_projects(project_ids, _only=("id", "name", "system_tags"))
if include_children
else {}
)
project_ids_with_children = set(project_ids) | {
c.id for c in itertools.chain.from_iterable(child_projects.values())
}
status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines(
company,
project_ids=list(project_ids_with_children),
specific_state=specific_state,
filter_=filter_,
)
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
def set_default_count(entry):
return dict(default_counts, **entry)
status_count = defaultdict(lambda: {})
key = itemgetter(EntityVisibility.archived.value)
for result in Task.aggregate(status_count_pipeline):
for k, group in groupby(sorted(result["counts"], key=key), key):
section = (
EntityVisibility.archived if k else EntityVisibility.active
).value
status_count[result["_id"]][section] = set_default_count(
{
count_entry["status"]: count_entry["count"]
for count_entry in group
}
)
def sum_status_count(
a: Mapping[str, Mapping], b: Mapping[str, Mapping]
) -> Dict[str, dict]:
return {
section: {
status: nested_get(a, (section, status), default=0)
+ nested_get(b, (section, status), default=0)
for status in set(a.get(section, {})) | set(b.get(section, {}))
}
for section in set(a) | set(b)
}
status_count = cls.aggregate_project_data(
func=sum_status_count,
project_ids=project_ids,
child_projects=child_projects,
data=status_count,
)
runtime = {
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
for result in Task.aggregate(runtime_pipeline)
}
def sum_runtime(
a: Mapping[str, Mapping], b: Mapping[str, Mapping]
) -> Dict[str, dict]:
return {
section: a.get(section, 0) + b.get(section, 0)
if not section.endswith("max_task_started")
else max(a.get(section) or datetime.min, b.get(section) or datetime.min)
for section in set(a) | set(b)
}
runtime = cls.aggregate_project_data(
func=sum_runtime,
project_ids=project_ids,
child_projects=child_projects,
data=runtime,
)
def get_status_counts(project_id, section):
project_runtime = runtime.get(project_id, {})
project_section_statuses = nested_get(
status_count, (project_id, section), default=default_counts
)
def get_time_or_none(value):
return value if value != datetime.min else None
return {
"status_count": project_section_statuses,
"total_tasks": sum(project_section_statuses.values()),
"total_runtime": project_runtime.get(section, 0),
"completed_tasks_24h": project_runtime.get(
f"{section}_recently_completed", 0
),
"last_task_run": get_time_or_none(
project_runtime.get(f"{section}_max_task_started", datetime.min)
),
}
report_for_states = [
s
for s in cls.visibility_states
if not specific_state or specific_state == s
]
stats = {
project: {
task_state.value: get_status_counts(project, task_state.value)
for task_state in report_for_states
}
for project in project_ids
}
def filter_child_projects(project: str) -> Sequence[Project]:
non_filtered_children = child_projects.get(project, [])
if not non_filtered_children or return_hidden_children:
return non_filtered_children
return [
c
for c in non_filtered_children
if not c.system_tags
or EntityVisibility.hidden.value not in c.system_tags
]
children = {
project: sorted(
[
{"id": c.id, "name": c.name}
for c in filter_child_projects(project)
],
key=itemgetter("name"),
)
for project in project_ids
}
return stats, children
@classmethod
def get_active_users(
cls,
company,
project_ids: Sequence[str],
user_ids: Optional[Sequence[str]] = None,
) -> Set[str]:
"""
Get the set of user ids that created tasks/models in the given projects
If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned
"""
with TimingContext("mongo", "active_users_in_projects"):
query = Q(company=company)
if user_ids:
query &= Q(user__in=user_ids)
projects_query = query
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
projects_query &= Q(id__in=project_ids)
res = set(Project.objects(projects_query).distinct(field="user"))
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="user"))
return res
@classmethod
def get_project_tags(
cls,
company_id: str,
include_system: bool,
projects: Sequence[str] = None,
filter_: Dict[str, Sequence[str]] = None,
) -> Tuple[Sequence[str], Sequence[str]]:
with TimingContext("mongo", "get_tags_from_db"):
query = Q(company=company_id)
if filter_:
for name, vals in filter_.items():
if vals:
query &= GetMixin.get_list_field_query(name, vals)
if projects:
query &= Q(id__in=_ids_with_children(projects))
tags = Project.objects(query).distinct("tags")
system_tags = (
Project.objects(query).distinct("system_tags") if include_system else []
)
return tags, system_tags
@classmethod
def get_projects_with_active_user(
cls,
company: str,
users: Sequence[str],
project_ids: Optional[Sequence[str]] = None,
allow_public: bool = True,
) -> Sequence[str]:
"""
Get the projects ids where user created any tasks including all the parents of these projects
If project ids are specified then filter the results by these project ids
"""
query = Q(user__in=users)
if allow_public:
query &= get_company_or_none_constraint(company)
else:
query &= Q(company=company)
user_projects_query = query
if project_ids:
ids_with_children = _ids_with_children(project_ids)
query &= Q(project__in=ids_with_children)
user_projects_query &= Q(id__in=ids_with_children)
res = {p.id for p in Project.objects(user_projects_query).only("id")}
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="project"))
res = list(res)
if not res:
return res
ids_with_parents = _ids_with_parents(res)
if project_ids:
return [pid for pid in ids_with_parents if pid in project_ids]
return ids_with_parents
@classmethod
def get_task_parents(
cls,
company_id: str,
projects: Sequence[str],
include_subprojects: bool,
state: Optional[EntityVisibility] = None,
) -> Sequence[dict]:
"""
Get list of unique parent tasks sorted by task name for the passed company projects
If projects is None or empty then get parents for all the company tasks
"""
query = Q(company=company_id)
if projects:
if include_subprojects:
projects = _ids_with_children(projects)
query &= Q(project__in=projects)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
elif state == EntityVisibility.active:
query &= Q(system_tags__nin=[EntityVisibility.archived.value])
parent_ids = set(Task.objects(query).distinct("parent"))
if not parent_ids:
return []
parents = Task.get_many_with_join(
company_id,
query=Q(id__in=parent_ids),
allow_public=True,
override_projection=("id", "name", "project.name"),
)
return sorted(parents, key=itemgetter("name"))
@classmethod
def get_task_types(cls, company, project_ids: Optional[Sequence]) -> set:
"""
Return the list of unique task types used by company and public tasks
If project ids passed then only tasks from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
@classmethod
def get_model_frameworks(cls, company, project_ids: Optional[Sequence]) -> Sequence:
"""
Return the list of unique frameworks used by company and public models
If project ids passed then only models from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
return Model.objects(query).distinct(field="framework")
@staticmethod
def get_match_conditions(
company: str, project_ids: Sequence[str], filter_: Mapping[str, Any]
):
conditions = {
"company": {"$in": [None, "", company]},
"project": {"$in": project_ids},
}
if not filter_:
return conditions
for field in ("tags", "system_tags"):
field_filter = filter_.get(field)
if not field_filter:
continue
if not isinstance(field_filter, list) or not all(
isinstance(t, str) for t in field_filter
):
raise errors.bad_request.ValidationError(
f"List of strings expected for the field: {field}"
)
conditions[field] = {"$in": field_filter}
return conditions
@classmethod
def calc_own_contents(
cls, company: str, project_ids: Sequence[str], filter_: Mapping[str, Any] = None
) -> Dict[str, dict]:
"""
Returns the amount of task/models per requested project
Use separate aggregation calls on Task/Model instead of lookup
aggregation on projects in order not to hit memory limits on large tasks
"""
if not project_ids:
return {}
pipeline = [
{
"$match": cls.get_match_conditions(
company=company, project_ids=project_ids, filter_=filter_
)
},
{"$project": {"project": 1}},
{"$group": {"_id": "$project", "count": {"$sum": 1}}},
]
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
with TimingContext("mongo", "get_security_groups"):
tasks = get_agrregate_res(Task)
models = get_agrregate_res(Model)
return {
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
for pid in project_ids
}

View File

@@ -0,0 +1,176 @@
from typing import Tuple, Set, Sequence
import attr
from apiserver.apierrors import errors
from apiserver.bll.event import EventBLL
from apiserver.bll.task.task_cleanup import (
collect_debug_image_urls,
collect_plot_image_urls,
TaskUrls,
)
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes
from apiserver.timing_context import TimingContext
from .sub_projects import _ids_with_children
log = config.logger(__file__)
event_bll = EventBLL()
@attr.s(auto_attribs=True)
class DeleteProjectResult:
deleted: int = 0
disassociated_tasks: int = 0
deleted_models: int = 0
deleted_tasks: int = 0
urls: TaskUrls = None
def validate_project_delete(company: str, project_id: str):
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
project_ids = _ids_with_children([project_id])
ret = {}
for cls in (Task, Model):
ret[f"{cls.__name__.lower()}s"] = cls.objects(
project__in=project_ids,
).count()
for cls in (Task, Model):
ret[f"non_archived_{cls.__name__.lower()}s"] = cls.objects(
project__in=project_ids,
system_tags__nin=[EntityVisibility.archived.value],
).count()
return ret
def delete_project(
company: str, project_id: str, force: bool, delete_contents: bool
) -> Tuple[DeleteProjectResult, Set[str]]:
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
project_ids = _ids_with_children([project_id])
if not force:
for cls, error in (
(Task, errors.bad_request.ProjectHasTasks),
(Model, errors.bad_request.ProjectHasModels),
):
non_archived = cls.objects(
project__in=project_ids,
system_tags__nin=[EntityVisibility.archived.value],
).only("id")
if non_archived:
raise error("use force=true to delete", id=project_id)
if not delete_contents:
with TimingContext("mongo", "update_children"):
for cls in (Model, Task):
updated_count = cls.objects(project__in=project_ids).update(
project=None
)
res = DeleteProjectResult(disassociated_tasks=updated_count)
else:
deleted_models, model_urls = _delete_models(projects=project_ids)
deleted_tasks, event_urls, artifact_urls = _delete_tasks(
company=company, projects=project_ids
)
res = DeleteProjectResult(
deleted_tasks=deleted_tasks,
deleted_models=deleted_models,
urls=TaskUrls(
model_urls=list(model_urls),
event_urls=list(event_urls),
artifact_urls=list(artifact_urls),
),
)
affected = {*project_ids, *(project.path or [])}
res.deleted = Project.objects(id__in=project_ids).delete()
return res, affected
def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]:
"""
Delete only the task themselves and their non published version.
Child models under the same project are deleted separately.
Children tasks should be deleted in the same api call.
If any child entities are left in another projects then updated their parent task to None
"""
tasks = Task.objects(project__in=projects).only("id", "execution__artifacts")
if not tasks:
return 0, set(), set()
task_ids = {t.id for t in tasks}
with TimingContext("mongo", "delete_tasks_update_children"):
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
event_urls, artifact_urls = set(), set()
for task in tasks:
event_urls.update(collect_debug_image_urls(company, task.id))
event_urls.update(collect_plot_image_urls(company, task.id))
if task.execution and task.execution.artifacts:
artifact_urls.update(
{
a.uri
for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri
}
)
event_bll.delete_multi_task_events(company, list(task_ids))
deleted = tasks.delete()
return deleted, event_urls, artifact_urls
def _delete_models(projects: Sequence[str]) -> Tuple[int, Set[str]]:
"""
Delete project models and update the tasks from other projects
that reference them to reference None.
"""
with TimingContext("mongo", "delete_models"):
models = Model.objects(project__in=projects).only("task", "id", "uri")
if not models:
return 0, set()
model_ids = list({m.id for m in models})
Task._get_collection().update_many(
filter={
"project": {"$nin": projects},
"models.input.model": {"$in": model_ids},
},
update={"$set": {"models.input.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
model_tasks = list({m.task for m in models if m.task})
if model_tasks:
Task._get_collection().update_many(
filter={
"_id": {"$in": model_tasks},
"project": {"$nin": projects},
"models.output.model": {"$in": model_ids},
},
update={"$set": {"models.output.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
urls = {m.uri for m in models if m.uri}
deleted = models.delete()
return deleted, urls

View File

@@ -0,0 +1,370 @@
import json
from collections import OrderedDict
from datetime import datetime
from typing import (
Sequence,
Optional,
Tuple,
)
from redis import StrictRedis
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .sub_projects import _ids_with_children
log = config.logger(__file__)
class ProjectQueries:
def __init__(self, redis=None):
self.redis: StrictRedis = redis or redman.connection("apiserver")
@staticmethod
def _get_project_constraint(
project_ids: Sequence[str], include_subprojects: bool
) -> dict:
"""
If passed projects is None means top level projects
If passed projects is empty means no project filtering
"""
if include_subprojects:
if not project_ids:
return {}
project_ids = _ids_with_children(project_ids)
if project_ids is None:
project_ids = [None]
if not project_ids:
return {}
return {"project": {"$in": project_ids}}
@staticmethod
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
if allow_public:
return {"company": {"$in": [None, "", company_id]}}
return {"company": company_id}
@classmethod
def get_aggregated_project_parameters(
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
"hyperparams": {"$exists": True, "$gt": {}},
}
},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
{
"section": ParameterKeyEscaper.unescape(
nested_get(r, ("_id", "section"))
),
"name": ParameterKeyEscaper.unescape(
nested_get(r, ("_id", "name"))
),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
ParamValues = Tuple[int, Sequence[str]]
def _get_cached_param_values(
self, key: str, last_update: datetime, allowed_delta_sec=0
) -> Optional[ParamValues]:
try:
cached = self.redis.get(key)
if not cached:
return
data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update).total_seconds() <= allowed_delta_sec:
return data["total"], data["values"]
except Exception as ex:
log.error(f"Error retrieving params cached values: {str(ex)}")
def get_task_hyperparam_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
section: str,
name: str,
include_subprojects: bool,
allow_public: bool = True,
) -> ParamValues:
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
)
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
last_updated_task = (
Task.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_param_values(
key=redis_key,
last_update=last_update,
allowed_delta_sec=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
),
)
if cached_res:
return cached_res
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values
@classmethod
def get_unique_metric_variants(
cls, company_id, project_ids: Sequence[str], include_subprojects: bool
):
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
}
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
{"$unwind": "$metrics"},
{
"$project": {
"metric": "$metrics.k",
"variants": {"$objectToArray": "$metrics.v"},
}
},
{"$unwind": "$variants"},
{
"$group": {
"_id": {
"metric": "$variants.v.metric",
"variant": "$variants.v.variant",
},
"metrics": {
"$addToSet": {
"metric": "$variants.v.metric",
"metric_hash": "$metric",
"variant": "$variants.v.variant",
"variant_hash": "$variants.k",
}
},
}
},
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@classmethod
def get_model_metadata_keys(
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
"metadata": {"$exists": True, "$gt": {}},
}
},
{"$project": {"metadata": {"$objectToArray": "$metadata"}}},
{"$unwind": "$metadata"},
{"$group": {"_id": "$metadata.k"}},
{"$sort": {"_id": 1}},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Model.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
ParameterKeyEscaper.unescape(r.get("_id"))
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
def get_model_metadata_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
key: str,
include_subprojects: bool,
allow_public: bool = True,
) -> ParamValues:
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
)
key_path = f"metadata.{ParameterKeyEscaper.escape(key)}"
last_updated_model = (
Model.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_model:
return 0, []
redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}"
last_update = last_updated_model.last_update or datetime.utcnow()
cached_res = self._get_cached_param_values(
key=redis_key, last_update=last_update
)
if cached_res:
return cached_res
max_values = config.get("services.models.metadata_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Model.aggregate(pipeline, collation=Model._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.models.metadata_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values

View File

@@ -0,0 +1,178 @@
import itertools
from datetime import datetime
from typing import Tuple, Optional, Sequence, Mapping
from apiserver import database
from apiserver.apierrors import errors
from apiserver.database.model.project import Project
name_separator = "/"
def _get_project_depth(project_name: str) -> int:
return len(list(filter(None, project_name.split(name_separator))))
def _validate_project_name(project_name: str) -> Tuple[str, str]:
"""
Remove redundant '/' characters. Ensure that the project name is not empty
Return the cleaned up project name and location
"""
name_parts = list(filter(None, project_name.split(name_separator)))
if not name_parts:
raise errors.bad_request.InvalidProjectName(name=project_name)
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
def _ensure_project(
company: str, user: str, name: str, creation_params: dict = None
) -> Optional[Project]:
"""
Makes sure that the project with the given name exists
If needed auto-create the project and all the missing projects in the path to it
Return the project
"""
name = name.strip(name_separator)
if not name:
return None
project = _get_writable_project_from_name(company, name)
if project:
return project
now = datetime.utcnow()
name, location = _validate_project_name(name)
project = Project(
id=database.utils.id(),
user=user,
company=company,
created=now,
last_update=now,
name=name,
**(creation_params or dict(description="")),
)
parent = _ensure_project(company, user, location, creation_params=creation_params)
_save_under_parent(project=project, parent=parent)
if parent:
parent.update(last_update=now)
return project
def _save_under_parent(project: Project, parent: Optional[Project]):
"""
Save the project under the given parent project or top level (parent=None)
Check that the project location matches the parent name
"""
location, _, _ = project.name.rpartition(name_separator)
if not parent:
if location:
raise ValueError(
f"Project location {location} does not match empty parent name"
)
project.parent = None
project.path = []
project.save()
return
if location != parent.name:
raise ValueError(
f"Project location {location} does not match parent name {parent.name}"
)
project.parent = parent.id
project.path = [*(parent.path or []), parent.id]
project.save()
def _get_writable_project_from_name(
company,
name,
_only: Optional[Sequence[str]] = ("id", "name", "path", "company", "parent"),
) -> Optional[Project]:
"""
Return a project from name. If the project not found then return None
"""
qs = Project.objects(company=company, name=name)
if _only:
qs = qs.only(*_only)
return qs.first()
def _get_sub_projects(
project_ids: Sequence[str], _only: Sequence[str] = ("id", "path")
) -> Mapping[str, Sequence[Project]]:
"""
Return the list of child projects of all the levels for the parent project ids
"""
qs = Project.objects(path__in=project_ids)
if _only:
_only = set(_only) | {"path"}
qs = qs.only(*_only)
subprojects = list(qs)
return {
pid: [s for s in subprojects if pid in (s.path or [])] for pid in project_ids
}
def _ids_with_parents(project_ids: Sequence[str]) -> Sequence[str]:
"""
Return project ids with all the parent projects
"""
projects = Project.objects(id__in=project_ids).only("id", "path")
parent_ids = set(itertools.chain.from_iterable(p.path for p in projects if p.path))
return list({*(p.id for p in projects), *parent_ids})
def _ids_with_children(project_ids: Sequence[str]) -> Sequence[str]:
"""
Return project ids with the ids of all the subprojects
"""
subprojects = Project.objects(path__in=project_ids).only("id")
return list({*project_ids, *(child.id for child in subprojects)})
def _update_subproject_names(
project: Project,
children: Sequence[Project],
old_name: str,
update_path: bool = False,
old_path: Sequence[str] = None,
) -> int:
"""
Update sub project names when the base project name changes
Optionally update the paths
"""
updated = 0
for child in children:
child_suffix = name_separator.join(
child.name.split(name_separator)[len(old_name.split(name_separator)) :]
)
updates = {"name": name_separator.join((project.name, child_suffix))}
if update_path:
updates["path"] = project.path + child.path[len(old_path) :]
updated += child.update(upsert=False, **updates)
return updated
def _reposition_project_with_children(
project: Project, children: Sequence[Project], parent: Project
) -> int:
new_location = parent.name if parent else None
old_name = project.name
old_path = project.path
project.name = name_separator.join(
filter(None, (new_location, project.name.split(name_separator)[-1]))
)
_save_under_parent(project, parent=parent)
moved = 1 + _update_subproject_names(
project=project,
children=children,
old_name=old_name,
update_path=True,
old_path=old_path,
)
return moved

View File

@@ -32,6 +32,7 @@ class QueueBLL(object):
name: str,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
metadata: Optional[dict] = None,
) -> Queue:
"""Creates a queue"""
with translate_errors_context():
@@ -43,6 +44,7 @@ class QueueBLL(object):
name=name,
tags=tags or [],
system_tags=system_tags or [],
metadata=metadata,
last_update=now,
)
queue.save()
@@ -124,14 +126,27 @@ class QueueBLL(object):
)
queue.delete()
def get_all(self, company_id: str, query_dict: dict) -> Sequence[dict]:
def get_all(
self,
company_id: str,
query_dict: dict,
ret_params: dict = None,
) -> Sequence[dict]:
"""Get all the queues according to the query"""
with translate_errors_context():
return Queue.get_many(
company=company_id, parameters=query_dict, query_dict=query_dict
company=company_id,
parameters=query_dict,
query_dict=query_dict,
ret_params=ret_params,
)
def get_queue_infos(self, company_id: str, query_dict: dict) -> Sequence[dict]:
def get_queue_infos(
self,
company_id: str,
query_dict: dict,
ret_params: dict = None,
) -> Sequence[dict]:
"""
Get infos on all the company queues, including queue tasks and workers
"""
@@ -141,6 +156,7 @@ class QueueBLL(object):
company=company_id,
query_dict=query_dict,
override_projection=projection,
ret_params=ret_params,
)
queue_workers = defaultdict(list)
@@ -171,13 +187,15 @@ class QueueBLL(object):
if any(e.task == task_id for e in queue.entries):
raise errors.bad_request.TaskAlreadyQueued(task=task_id)
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
entry = Entry(added=datetime.utcnow(), task=task_id)
query = dict(id=queue_id, company=company_id)
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
push__entries=entry, last_update=datetime.utcnow(), upsert=False
)
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
if not res:
raise errors.bad_request.InvalidQueueOrTaskNotQueued(
task=task_id, **query
@@ -217,7 +235,6 @@ class QueueBLL(object):
queue = self.get_queue_with_task(
company_id=company_id, queue_id=queue_id, task_id=task_id
)
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
entries_to_remove = [e for e in queue.entries if e.task == task_id]
query = dict(id=queue_id, company=company_id)
@@ -225,6 +242,9 @@ class QueueBLL(object):
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
)
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
return len(entries_to_remove) if res else 0
def reposition_task(

View File

@@ -49,6 +49,21 @@ class RedisCacheManager(Generic[T]):
def _get_redis_key(self, state_id):
return f"{self.state_class}/{state_id}"
def get_or_create_state_core(
self,
state_id=None,
init_state: Callable[[T], None] = _do_nothing,
validate_state: Callable[[T], None] = _do_nothing,
) -> T:
state = self.get_state(state_id) if state_id else None
if state:
validate_state(state)
else:
state = self.state_class(id=database.utils.id())
init_state(state)
return state
@contextmanager
def get_or_create_state(
self,
@@ -66,12 +81,9 @@ class RedisCacheManager(Generic[T]):
:param validate_state: user callback to validate the state if retrieved from cache
Should throw an exception if the state is not valid. If not passed then no validation is done
"""
state = self.get_state(state_id) if state_id else None
if state:
validate_state(state)
else:
state = self.state_class(id=database.utils.id())
init_state(state)
state = self.get_or_create_state_core(
state_id=state_id, init_state=init_state, validate_state=validate_state
)
try:
yield state

View File

@@ -45,7 +45,7 @@ class StatisticsReporter:
def start_reporter(cls):
"""
Periodically send statistics reports for companies who have opted in.
Note: in trains we usually have only a single company
Note: in clearml we usually have only a single company
"""
if not cls.supported:
return

View File

@@ -3,5 +3,4 @@ from .utils import (
ChangeStatusRequest,
update_project_time,
validate_status_change,
split_by,
)

View File

@@ -1,10 +1,10 @@
from hashlib import md5
from operator import itemgetter
from typing import Sequence
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.database.utils import hash_field_name
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
@@ -15,7 +15,7 @@ def get_artifact_id(artifact: dict):
Calculate id from 'key' and 'mode' fields
Return hash on on the id so that it will not contain mongo illegal characters
"""
key_hash: str = md5(artifact["key"].encode()).hexdigest()
key_hash: str = hash_field_name(artifact["key"])
mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
return f"{key_hash}_{mode}"
@@ -40,7 +40,7 @@ def artifacts_unprepare_from_saved(fields):
nested_set(
fields,
artifacts_field,
value=sorted(artifacts.values(), key=itemgetter("key", "mode")),
value=sorted(artifacts.values(), key=itemgetter("key")),
)

View File

@@ -175,21 +175,23 @@ class HyperParams:
@classmethod
def get_configuration_names(
cls, company_id: str, task_ids: Sequence[str]
cls, company_id: str, task_ids: Sequence[str], skip_empty: bool
) -> Dict[str, list]:
with TimingContext("mongo", "get_configuration_names"):
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"_id": {"$in": task_ids},
}
},
{"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"},
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
skip_empty_condition = {"$match": {"items.v.value": {"$nin": [None, ""]}}}
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"_id": {"$in": task_ids},
}
},
{"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"},
*([skip_empty_condition] if skip_empty else []),
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
with TimingContext("mongo", "get_configuration_names"):
tasks = Task.aggregate(pipeline)
return {

View File

@@ -1,11 +1,10 @@
import itertools
from typing import Sequence, Tuple
from typing import Sequence, Tuple, Optional
import dpath
from apiserver.apierrors import errors
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get, nested_delete, nested_set
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
@@ -14,7 +13,7 @@ hyperparams_legacy_type = "legacy"
tf_define_section = "TF_DEFINE"
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
def split_param_name(full_name: str, default_section: str) -> Tuple[Optional[str], str]:
"""
Return parameter section and name. The section is either TF_DEFINE or the default one
"""
@@ -62,7 +61,7 @@ def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
return removed
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[dict]:
"""
Remove the legacy params from the data dict and return the number of removed params
If the path not found then return 0
@@ -71,8 +70,10 @@ def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]
return []
if with_sections:
return itertools.chain.from_iterable(
_get_legacy_params(section_data) for section_data in data.values()
return list(
itertools.chain.from_iterable(
_get_legacy_params(section_data) for section_data in data.values()
)
)
return [
@@ -86,15 +87,15 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
Escape all the section and param names for hyper params and configuration to make it mongo sage
"""
for old_params_field, new_params_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None),
(("execution", "parameters"), "hyperparams", hyperparams_default_section),
(("execution", "model_desc"), "configuration", None),
):
legacy_params = safe_get(fields, old_params_field)
legacy_params = nested_get(fields, old_params_field)
if legacy_params is None:
continue
if (
not safe_get(fields, new_params_field)
not fields.get(new_params_field)
and previous_task
and previous_task[new_params_field]
):
@@ -117,11 +118,11 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
if section is not None:
new_param["section"] = section
dpath.new(fields, new_path, new_param)
dpath.delete(fields, old_params_field)
nested_set(fields, new_path, new_param)
nested_delete(fields, old_params_field)
for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field)
params = fields.get(param_field)
if params:
escaped_params = {
ParameterKeyEscaper.escape(key): {
@@ -131,7 +132,7 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
else value
for key, value in params.items()
}
dpath.set(fields, param_field, escaped_params)
fields[param_field] = escaped_params
def params_unprepare_from_saved(fields, copy_to_legacy=False):
@@ -140,7 +141,7 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
"""
for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field)
params = fields.get(param_field)
if params:
unescaped_params = {
ParameterKeyEscaper.unescape(key): {
@@ -150,18 +151,18 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
else value
for key, value in params.items()
}
dpath.set(fields, param_field, unescaped_params)
fields[param_field] = unescaped_params
if copy_to_legacy:
for new_params_field, old_params_field, use_sections in (
(f"hyperparams", "execution/parameters", True),
(f"configuration", "execution/model_desc", False),
("hyperparams", ("execution", "parameters"), True),
("configuration", ("execution", "model_desc"), False),
):
legacy_params = _get_legacy_params(
safe_get(fields, new_params_field), with_sections=use_sections
fields.get(new_params_field), with_sections=use_sections
)
if legacy_params:
dpath.new(
nested_set(
fields,
old_params_field,
{_get_full_param_name(p): p["value"] for p in legacy_params},
@@ -174,7 +175,7 @@ def _process_path(path: str):
Need to unescape and apply a full mongo escaping
"""
parts = path.split(".")
if len(parts) < 2 or len(parts) > 3:
if len(parts) < 2 or len(parts) > 4:
raise errors.bad_request.ValidationError("invalid task field", path=path)
return ".".join(
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
@@ -184,7 +185,8 @@ def _process_path(path: str):
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
for old_prefix, new_prefix in (
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
("execution.model_desc", f"configuration"),
("execution.model_desc", "configuration"),
("execution.docker_cmd", "container")
):
path: str
paths = [path.replace(old_prefix, new_prefix) for path in paths]

View File

@@ -1,14 +1,14 @@
from collections import OrderedDict
from datetime import datetime
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
import dpath
import six
from mongoengine import Q
from redis import StrictRedis
from six import string_types
import apiserver.database.utils as dbutils
from apiserver.apierrors import errors
from apiserver.apimodels.tasks import TaskInputModel
from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
@@ -21,21 +21,28 @@ from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
Task,
TaskStatus,
TaskStatusMessage,
TaskSystemTags,
ArtifactModes,
external_task_types,
ModelItem,
Models,
DEFAULT_ARTIFACT_MODE,
TaskModelNames,
TaskModelTypes,
)
from apiserver.database.model import EntityVisibility
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import ChangeStatusRequest, validate_status_change, update_project_time
from .utils import (
ChangeStatusRequest,
update_project_time,
deleted_prefix,
)
log = config.logger(__file__)
org_bll = OrgBLL()
@@ -44,22 +51,9 @@ project_bll = ProjectBLL()
class TaskBLL:
def __init__(self, events_es=None):
self.events_es = (
events_es if events_es is not None else es_factory.connect("events")
)
@classmethod
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
"""
Return the list of unique task types used by company and public tasks
If project ids passed then only tasks from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
query &= Q(project__in=project_ids)
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
def __init__(self, events_es=None, redis=None):
self.events_es = events_es or es_factory.connect("events")
self.redis: StrictRedis = redis or redman.connection("apiserver")
@staticmethod
def get_task_with_access(
@@ -151,19 +145,20 @@ class TaskBLL:
)
@staticmethod
def validate_execution_model(task, allow_only_public=False):
if not task.execution or not task.execution.model:
def validate_input_models(task, allow_only_public=False):
if not task.models.input:
return
company = None if allow_only_public else task.company
model_id = task.execution.model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(company)
).first()
if not model:
raise errors.bad_request.InvalidModelId(model=model_id)
model_ids = set(m.model for m in task.models.input)
models = Model.objects(
Q(id__in=model_ids) & get_company_or_none_constraint(company)
).only("id")
missing = model_ids - {m.id for m in models}
if missing:
raise errors.bad_request.InvalidModelId(models=missing)
return model
return
@classmethod
def clone_task(
@@ -179,7 +174,9 @@ class TaskBLL:
system_tags: Optional[Sequence[str]] = None,
hyperparams: Optional[dict] = None,
configuration: Optional[dict] = None,
container: Optional[dict] = None,
execution_overrides: Optional[dict] = None,
input_models: Optional[Sequence[TaskInputModel]] = None,
validate_references: bool = False,
new_project_name: str = None,
) -> Tuple[Task, dict]:
@@ -195,10 +192,29 @@ class TaskBLL:
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
now = datetime.utcnow()
if input_models:
input_models = [
ModelItem(model=m.model, name=m.name, updated=now) for m in input_models
]
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
if execution_overrides:
execution_model_overriden = execution_overrides.get("model") is not None
execution_model = execution_overrides.pop("model", None)
if not input_models and execution_model:
input_models = [
ModelItem(
model=execution_model,
name=TaskModelNames[TaskModelTypes.input],
updated=now,
)
]
docker_cmd = execution_overrides.pop("docker_cmd", None)
if not container and docker_cmd:
image, _, arguments = docker_cmd.partition(" ")
container = {"image": image, "arguments": arguments}
artifacts_prepare_for_save({"execution": execution_overrides})
params_dict["execution"] = {}
@@ -207,6 +223,8 @@ class TaskBLL:
if legacy_value is not None:
params_dict["execution"] = legacy_value
escape_dict_field(execution_overrides, "model_labels")
execution_dict.update(execution_overrides)
params_prepare_for_save(params_dict, previous_task=task)
@@ -216,7 +234,7 @@ class TaskBLL:
execution_dict["artifacts"] = {
k: a
for k, a in artifacts.items()
if a.get("mode") != ArtifactModes.output
if a.get("mode", DEFAULT_ARTIFACT_MODE) != ArtifactModes.output
}
execution_dict.pop("queue", None)
@@ -227,12 +245,10 @@ class TaskBLL:
project_name=new_project_name,
user=user_id,
company=company_id,
description="Auto-generated while cloning",
description="",
)
new_project_data = {"id": project, "name": new_project_name}
now = datetime.utcnow()
def clean_system_tags(input_tags: Sequence[str]) -> Sequence[str]:
if not input_tags:
return input_tags
@@ -240,10 +256,16 @@ class TaskBLL:
return [
tag
for tag in input_tags
if tag not in [TaskSystemTags.development, EntityVisibility.archived.value]
if tag
not in [TaskSystemTags.development, EntityVisibility.archived.value]
]
with TimingContext("mongo", "clone task"):
parent_task = (
task.parent
if task.parent and not task.parent.startswith(deleted_prefix)
else None
)
new_task = Task(
id=create_id(),
user=user_id,
@@ -253,7 +275,7 @@ class TaskBLL:
last_change=now,
name=name or task.name,
comment=comment or task.comment,
parent=parent or task.parent,
parent=parent or parent_task,
project=project or task.project,
tags=tags or task.tags,
system_tags=system_tags or clean_system_tags(task.system_tags),
@@ -262,13 +284,15 @@ class TaskBLL:
output=Output(destination=task.output.destination)
if task.output
else None,
models=Models(input=input_models or task.models.input),
container=escape_dict(container) or task.container,
execution=execution_dict,
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)
cls.validate(
new_task,
validate_model=validate_references or execution_model_overriden,
validate_models=validate_references or input_models,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
@@ -295,7 +319,7 @@ class TaskBLL:
def validate(
cls,
task: Task,
validate_model=True,
validate_models=True,
validate_parent=True,
validate_project=True,
):
@@ -307,6 +331,7 @@ class TaskBLL:
if (
validate_parent
and task.parent
and not task.parent.startswith(deleted_prefix)
and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
)
@@ -318,49 +343,8 @@ class TaskBLL:
if validate_project and not project:
raise errors.bad_request.InvalidProjectId(id=task.project)
if validate_model:
cls.validate_execution_model(task)
@staticmethod
def get_unique_metric_variants(company_id, project_ids=None):
pipeline = [
{
"$match": dict(
company={"$in": [None, "", company_id]},
**({"project": {"$in": project_ids}} if project_ids else {}),
)
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
{"$unwind": "$metrics"},
{
"$project": {
"metric": "$metrics.k",
"variants": {"$objectToArray": "$metrics.v"},
}
},
{"$unwind": "$variants"},
{
"$group": {
"_id": {
"metric": "$variants.v.metric",
"variant": "$variants.v.variant",
},
"metrics": {
"$addToSet": {
"metric": "$variants.v.metric",
"metric_hash": "$metric",
"variant": "$variants.v.variant",
"variant_hash": "$variants.k",
}
},
}
},
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
with translate_errors_context():
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
if validate_models:
cls.validate_input_models(task)
@staticmethod
def set_last_update(
@@ -372,6 +356,7 @@ class TaskBLL:
tasks = Task.objects(id__in=task_ids, company=company_id).only(
"status", "started"
)
count = 0
for task in tasks:
updates = extra_updates
if task.status == TaskStatus.in_progress and task.started:
@@ -381,12 +366,13 @@ class TaskBLL:
).total_seconds(),
**extra_updates,
}
Task.objects(id=task.id, company=company_id).update(
count += Task.objects(id=task.id, company=company_id).update(
upsert=False,
last_update=last_update,
last_change=last_update,
**updates,
)
return count
@staticmethod
def update_statistics(
@@ -449,226 +435,13 @@ class TaskBLL:
}
extra_updates["metric_stats"] = metric_stats
TaskBLL.set_last_update(
return TaskBLL.set_last_update(
task_ids=[task_id],
company_id=company_id,
last_update=last_update,
**extra_updates,
)
@classmethod
def model_set_ready(
cls,
model_id: str,
company_id: str,
publish_task: bool,
force_publish_task: bool = False,
) -> tuple:
with translate_errors_context():
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
elif model.ready:
raise errors.bad_request.ModelIsReady(**query)
published_task_data = {}
if model.task and publish_task:
task = (
Task.objects(id=model.task, company=company_id)
.only("id", "status")
.first()
)
if task and task.status != TaskStatus.published:
published_task_data["data"] = cls.publish_task(
task_id=model.task,
company_id=company_id,
publish_model=False,
force=force_publish_task,
)
published_task_data["id"] = model.task
updated = model.update(upsert=False, ready=True)
return updated, published_task_data
@classmethod
def publish_task(
cls,
task_id: str,
company_id: str,
publish_model: bool,
force: bool,
status_reason: str = "",
status_message: str = "",
) -> dict:
task = cls.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
if not force:
validate_status_change(task.status, TaskStatus.published)
previous_task_status = task.status
output = task.output or Output()
publish_failed = False
try:
# set state to publishing
task.status = TaskStatus.publishing
task.save()
# publish task models
if task.output.model and publish_model:
output_model = (
Model.objects(id=task.output.model)
.only("id", "task", "ready")
.first()
)
if output_model and not output_model.ready:
cls.model_set_ready(
model_id=task.output.model,
company_id=company_id,
publish_task=False,
)
# set task status to published, and update (or set) it's new output (view and models)
return ChangeStatusRequest(
task=task,
new_status=TaskStatus.published,
force=force,
status_reason=status_reason,
status_message=status_message,
).execute(published=datetime.utcnow(), output=output)
except Exception as ex:
publish_failed = True
raise ex
finally:
if publish_failed:
task.status = previous_task_status
task.save()
@classmethod
def stop_task(
cls,
task_id: str,
company_id: str,
user_name: str,
status_reason: str,
force: bool,
) -> dict:
"""
Stop a running task. Requires task status 'in_progress' and
execution_progress 'running', or force=True. Development task or
task that has no associated worker is stopped immediately.
For a non-development task with worker only the status message
is set to 'stopping' to allow the worker to stop the task and report by itself
:return: updated task fields
"""
task = cls.get_task_with_access(
task_id,
company_id=company_id,
only=(
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
),
requires_write_access=True,
)
def is_run_by_worker(t: Task) -> bool:
"""Checks if there is an active worker running the task"""
update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
return (
t.last_worker
and t.last_update
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
)
if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task.status
status_message = TaskStatusMessage.stopping
return ChangeStatusRequest(
task=task,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force,
).execute()
@staticmethod
def get_aggregated_project_parameters(
company_id,
project_ids: Sequence[str] = None,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"hyperparams": {"$exists": True, "$gt": {}},
**({"project": {"$in": project_ids}} if project_ids else {}),
}
},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
{
"$project": {
"total": 1,
"results": {"$slice": ["$results", page * page_size, page_size]},
}
},
]
with translate_errors_context():
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
{
"section": ParameterKeyEscaper.unescape(
dpath.get(r, "_id/section")
),
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
@classmethod
def dequeue_and_change_status(
cls, task: Task, company_id: str, status_message: str, status_reason: str,
@@ -677,10 +450,10 @@ class TaskBLL:
return ChangeStatusRequest(
task=task,
new_status=TaskStatus.created,
new_status=task.enqueue_status or TaskStatus.created,
status_reason=status_reason,
status_message=status_message,
).execute(unset__execution__queue=1)
).execute(enqueue_status=None)
@classmethod
def dequeue(cls, task: Task, company_id: str, silent_fail=False):

View File

@@ -0,0 +1,278 @@
from itertools import chain
from operator import attrgetter
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
import attr
from boltons.iterutils import partition
from mongoengine import QuerySet, Document
from apiserver.apierrors import errors
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_bll import PlotFields
from apiserver.bll.event.event_common import EventType
from apiserver.bll.task.utils import deleted_prefix
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
from apiserver.timing_context import TimingContext
event_bll = EventBLL()
T = TypeVar("T", bound=Document)
class DocumentGroup(List[T]):
"""
Operate on a list of documents as if they were a query result
"""
def __init__(self, document_type: Type[T], documents: Iterable[T]):
super(DocumentGroup, self).__init__(documents)
self.type = document_type
@property
def ids(self) -> Set[str]:
return {obj.id for obj in self}
def objects(self, *args, **kwargs) -> QuerySet:
return self.type.objects(id__in=self.ids, *args, **kwargs)
class TaskOutputs(Generic[T]):
"""
Split task outputs of the same type by the ready state
"""
published: DocumentGroup[T]
draft: DocumentGroup[T]
def __init__(
self,
is_published: Callable[[T], bool],
document_type: Type[T],
children: Iterable[T],
):
"""
:param is_published: predicate returning whether items is considered published
:param document_type: type of output
:param children: output documents
"""
self.published, self.draft = map(
lambda x: DocumentGroup(document_type, x),
partition(children, key=is_published),
)
@attr.s(auto_attribs=True)
class TaskUrls:
model_urls: Sequence[str]
event_urls: Sequence[str]
artifact_urls: Sequence[str]
def __add__(self, other: "TaskUrls"):
if not other:
return self
return TaskUrls(
model_urls=list(set(self.model_urls) | set(other.model_urls)),
event_urls=list(set(self.event_urls) | set(other.event_urls)),
artifact_urls=list(set(self.artifact_urls) | set(other.artifact_urls)),
)
@attr.s(auto_attribs=True)
class CleanupResult:
"""
Counts of objects modified in task cleanup operation
"""
updated_children: int
updated_models: int
deleted_models: int
urls: TaskUrls = None
def __add__(self, other: "CleanupResult"):
if not other:
return self
return CleanupResult(
updated_children=self.updated_children + other.updated_children,
updated_models=self.updated_models + other.updated_models,
deleted_models=self.deleted_models + other.deleted_models,
urls=self.urls + other.urls if self.urls else other.urls,
)
def collect_plot_image_urls(company: str, task: str) -> Set[str]:
urls = set()
next_scroll_id = None
with TimingContext("es", "collect_plot_image_urls"):
while True:
events, next_scroll_id = event_bll.get_plot_image_urls(
company_id=company, task_id=task, scroll_id=next_scroll_id
)
if not events:
break
for event in events:
event_urls = event.get(PlotFields.source_urls)
if event_urls:
urls.update(set(event_urls))
return urls
def collect_debug_image_urls(company: str, task: str) -> Set[str]:
"""
Return the set of unique image urls
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
"""
metrics = event_bll.get_metrics_and_variants(
company_id=company, task_id=task, event_type=EventType.metrics_image
)
if not metrics:
return set()
task_metrics = {task: {m: [] for m in metrics}}
scroll_id = None
urls = set()
while True:
res = event_bll.debug_images_iterator.get_task_events(
company_id=company,
task_metrics=task_metrics,
iter_count=10,
state_id=scroll_id,
)
if not res.metric_events or not any(
iterations for _, iterations in res.metric_events
):
break
scroll_id = res.next_scroll_id
for task, iterations in res.metric_events:
urls.update(ev.get("url") for it in iterations for ev in it["events"])
urls.discard({None})
return urls
def cleanup_task(
task: Task,
force: bool = False,
update_children=True,
return_file_urls=False,
delete_output_models=True,
) -> CleanupResult:
"""
Validate task deletion and delete/modify all its output.
:param task: task object
:param force: whether to delete task with published outputs
:return: count of delete and modified items
"""
models = verify_task_children_and_ouptuts(task, force)
event_urls, artifact_urls, model_urls = set(), set(), set()
if return_file_urls:
event_urls = collect_debug_image_urls(task.company, task.id)
event_urls.update(collect_plot_image_urls(task.company, task.id))
if task.execution and task.execution.artifacts:
artifact_urls = {
a.uri
for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri
}
model_urls = {m.uri for m in models.draft.objects().only("uri") if m.uri}
deleted_task_id = f"{deleted_prefix}{task.id}"
if update_children:
with TimingContext("mongo", "update_task_children"):
updated_children = Task.objects(parent=task.id).update(
parent=deleted_task_id
)
else:
updated_children = 0
if models.draft and delete_output_models:
with TimingContext("mongo", "delete_models"):
deleted_models = models.draft.objects().delete()
else:
deleted_models = 0
if models.published and update_children:
with TimingContext("mongo", "update_task_models"):
updated_models = models.published.objects().update(task=deleted_task_id)
else:
updated_models = 0
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
return CleanupResult(
deleted_models=deleted_models,
updated_children=updated_children,
updated_models=updated_models,
urls=TaskUrls(
event_urls=list(event_urls),
artifact_urls=list(artifact_urls),
model_urls=list(model_urls),
)
if return_file_urls
else None,
)
def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Model]:
if not force:
with TimingContext("mongo", "count_published_children"):
published_children_count = Task.objects(
parent=task.id, status=TaskStatus.published
).count()
if published_children_count:
raise errors.bad_request.TaskCannotBeDeleted(
"has children, use force=True",
task=task.id,
children=published_children_count,
)
with TimingContext("mongo", "get_task_models"):
models = TaskOutputs(
attrgetter("ready"),
Model,
Model.objects(task=task.id).only("id", "task", "ready"),
)
if not force and models.published:
raise errors.bad_request.TaskCannotBeDeleted(
"has output models, use force=True",
task=task.id,
models=len(models.published),
)
if task.models and task.models.output:
with TimingContext("mongo", "get_task_output_model"):
model_ids = [m.model for m in task.models.output]
for output_model in Model.objects(id__in=model_ids):
if output_model.ready:
if not force:
raise errors.bad_request.TaskCannotBeDeleted(
"has output model, use force=True",
task=task.id,
model=output_model.id,
)
models.published.append(output_model)
else:
models.draft.append(output_model)
if models.draft:
with TimingContext("mongo", "get_execution_models"):
model_ids = models.draft.ids
dependent_tasks = Task.objects(models__input__model__in=model_ids).only(
"id", "models"
)
input_models = {
m.model
for m in chain.from_iterable(
t.models.input for t in dependent_tasks if t.models
)
}
if input_models:
models.draft = DocumentGroup(
Model, (m for m in models.draft if m.id not in input_models)
)
return models

View File

@@ -0,0 +1,411 @@
from datetime import datetime
from typing import Callable, Any, Tuple, Union
from apiserver.apierrors import errors, APIError
from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import (
TaskBLL,
validate_status_change,
ChangeStatusRequest,
update_project_time,
)
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
TaskStatus,
Task,
TaskSystemTags,
TaskStatusMessage,
ArtifactModes,
Execution,
DEFAULT_LAST_ITERATION,
)
from apiserver.utilities.dicts import nested_set
queue_bll = QueueBLL()
def archive_task(
task: Union[str, Task], company_id: str, status_message: str, status_reason: str,
) -> int:
"""
Deque and archive task
Return 1 if successful
"""
if isinstance(task, str):
task = TaskBLL.get_task_with_access(
task,
company_id=company_id,
only=(
"id",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
),
requires_write_access=True,
)
try:
TaskBLL.dequeue_and_change_status(
task, company_id, status_message, status_reason,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
return task.update(
status_message=status_message,
status_reason=status_reason,
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
)
def unarchive_task(
task: str, company_id: str, status_message: str, status_reason: str,
) -> int:
"""
Unarchive task. Return 1 if successful
"""
task = TaskBLL.get_task_with_access(
task, company_id=company_id, only=("id",), requires_write_access=True,
)
return task.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
)
def dequeue_task(
task_id: str,
company_id: str,
status_message: str,
status_reason: str,
) -> Tuple[int, dict]:
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
res = TaskBLL.dequeue_and_change_status(
task,
company_id,
status_message=status_message,
status_reason=status_reason,
)
return 1, res
def enqueue_task(
task_id: str,
company_id: str,
queue_id: str,
status_message: str,
status_reason: str,
validate: bool = False,
force: bool = False,
) -> Tuple[int, dict]:
if not queue_id:
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
if validate:
TaskBLL.validate(task)
res = ChangeStatusRequest(
task=task,
new_status=TaskStatus.queued,
status_reason=status_reason,
status_message=status_message,
allow_same_state_transition=False,
force=force,
).execute(enqueue_status=task.status)
try:
queue_bll.add_task(company_id=company_id, queue_id=queue_id, task_id=task.id)
except Exception:
# failed enqueueing, revert to previous state
ChangeStatusRequest(
task=task,
current_status_override=TaskStatus.queued,
new_status=task.status,
force=True,
status_reason="failed enqueueing",
).execute(enqueue_status=None)
raise
# set the current queue ID in the task
if task.execution:
Task.objects(**query).update(execution__queue=queue_id, multi=False)
else:
Task.objects(**query).update(execution=Execution(queue=queue_id), multi=False)
nested_set(res, ("fields", "execution.queue"), queue_id)
return 1, res
def delete_task(
task_id: str,
company_id: str,
move_to_trash: bool,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
status_message: str,
status_reason: str,
) -> Tuple[int, Task, CleanupResult]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
if (
task.status != TaskStatus.created
and EntityVisibility.archived.value not in task.system_tags
and not force
):
raise errors.bad_request.TaskCannotBeDeleted(
"due to status, use force=True",
task=task.id,
expected=TaskStatus.created,
current=task.status,
)
try:
TaskBLL.dequeue_and_change_status(
task,
company_id=company_id,
status_message=status_message,
status_reason=status_reason,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
cleanup_res = cleanup_task(
task,
force=force,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
)
if move_to_trash:
collection_name = task._get_collection_name()
archived_collection = "{}__trash".format(collection_name)
task.switch_collection(archived_collection)
try:
# A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force
# an insert. However, if for some reason such an ID exists, let's make sure we'll keep going.
task.save(force_insert=True)
except Exception:
pass
task.switch_collection(collection_name)
task.delete()
update_project_time(task.project)
return 1, task, cleanup_res
def reset_task(
task_id: str,
company_id: str,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
clear_all: bool,
) -> Tuple[dict, CleanupResult, dict]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
if not force and task.status == TaskStatus.published:
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
dequeued = {}
updates = {}
try:
dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
cleaned_up = cleanup_task(
task,
force=force,
update_children=False,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
)
updates.update(
set__last_iteration=DEFAULT_LAST_ITERATION,
set__last_metrics={},
set__metric_stats={},
set__models__output=[],
set__runtime={},
unset__output__result=1,
unset__output__error=1,
unset__last_worker=1,
unset__last_worker_report=1,
)
if clear_all:
updates.update(
set__execution=Execution(), unset__script=1,
)
else:
updates.update(unset__execution__queue=1)
if task.execution and task.execution.artifacts:
updates.update(
set__execution__artifacts={
key: artifact
for key, artifact in task.execution.artifacts.items()
if artifact.mode == ArtifactModes.input
}
)
res = ChangeStatusRequest(
task=task,
new_status=TaskStatus.created,
force=force,
status_reason="reset",
status_message="reset",
).execute(
started=None,
completed=None,
published=None,
active_duration=None,
enqueue_status=None,
**updates,
)
return dequeued, cleaned_up, res
def publish_task(
task_id: str,
company_id: str,
force: bool,
publish_model_func: Callable[[str, str], Any] = None,
status_message: str = "",
status_reason: str = "",
) -> dict:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
if not force:
validate_status_change(task.status, TaskStatus.published)
previous_task_status = task.status
output = task.output or Output()
publish_failed = False
try:
# set state to publishing
task.status = TaskStatus.publishing
task.save()
# publish task models
if task.models and task.models.output and publish_model_func:
model_id = task.models.output[-1].model
model = (
Model.objects(id=model_id, company=company_id)
.only("id", "ready")
.first()
)
if model and not model.ready:
publish_model_func(model.id, company_id)
# set task status to published, and update (or set) it's new output (view and models)
return ChangeStatusRequest(
task=task,
new_status=TaskStatus.published,
force=force,
status_reason=status_reason,
status_message=status_message,
).execute(published=datetime.utcnow(), output=output)
except Exception as ex:
publish_failed = True
raise ex
finally:
if publish_failed:
task.status = previous_task_status
task.save()
def stop_task(
task_id: str, company_id: str, user_name: str, status_reason: str, force: bool,
) -> dict:
"""
Stop a running task. Requires task status 'in_progress' and
execution_progress 'running', or force=True. Development task or
task that has no associated worker is stopped immediately.
For a non-development task with worker only the status message
is set to 'stopping' to allow the worker to stop the task and report by itself
:return: updated task fields
"""
task = TaskBLL.get_task_with_access(
task_id,
company_id=company_id,
only=(
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
"execution.queue",
),
requires_write_access=True,
)
def is_run_by_worker(t: Task) -> bool:
"""Checks if there is an active worker running the task"""
update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
return (
t.last_worker
and t.last_update
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
)
is_queued = task.status == TaskStatus.queued
set_stopped = (
is_queued
or TaskSystemTags.development in task.system_tags
or not is_run_by_worker(task)
)
if set_stopped:
if is_queued:
try:
TaskBLL.dequeue(task, company_id=company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task.status
status_message = TaskStatusMessage.stopping
return ChangeStatusRequest(
task=task,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force,
).execute()

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import TypeVar, Callable, Tuple, Sequence, Union
from typing import Sequence, Union
import attr
import six
@@ -13,6 +13,7 @@ from apiserver.timing_context import TimingContext
from apiserver.utilities.attrs import typed_attrs
valid_statuses = get_options(TaskStatus)
deleted_prefix = "__DELETED__"
@typed_attrs
@@ -105,7 +106,7 @@ def validate_status_change(current_status, new_status):
state_machine = {
TaskStatus.created: {TaskStatus.queued, TaskStatus.in_progress},
TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress},
TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress, TaskStatus.stopped},
TaskStatus.in_progress: {
TaskStatus.stopped,
TaskStatus.failed,
@@ -116,6 +117,7 @@ state_machine = {
TaskStatus.closed,
TaskStatus.created,
TaskStatus.failed,
TaskStatus.queued,
TaskStatus.in_progress,
TaskStatus.published,
TaskStatus.publishing,
@@ -163,22 +165,6 @@ def update_project_time(project_ids: Union[str, Sequence[str]]):
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
T = TypeVar("T")
def split_by(
condition: Callable[[T], bool], items: Sequence[T]
) -> Tuple[Sequence[T], Sequence[T]]:
"""
split "items" to two lists by "condition"
"""
applied = zip(map(condition, items), items)
return (
[item for cond, item in applied if cond],
[item for cond, item in applied if not cond],
)
def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task:

View File

@@ -1,33 +1,25 @@
import functools
import itertools
from concurrent.futures.thread import ThreadPoolExecutor
from operator import itemgetter
from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set, Iterable
from typing import (
Optional,
Callable,
Dict,
Any,
Set,
Iterable,
Tuple,
Sequence,
TypeVar,
)
from boltons import iterutils
from apiserver.apierrors import APIError
from apiserver.database.model import AttributedDocument
from apiserver.database.model.settings import Settings
def extract_properties_to_lists(
key_names: Sequence[str],
data: Sequence[dict],
extract_func: Optional[Callable[[dict], Tuple]] = None,
) -> dict:
"""
Given a list of dictionaries and names of dictionary keys
builds a dictionary with the requested keys and values lists
:param key_names: names of the keys in the resulting dictionary
:param data: sequence of dictionaries to extract values from
:param extract_func: the optional callable that extracts properties
from a dictionary and put them in a tuple in the order corresponding to
key_names. If not specified then properties are extracted according to key_names
"""
value_sequences = zip(*map(extract_func or itemgetter(*key_names), data))
return dict(zip(key_names, map(list, value_sequences)))
class SetFieldsResolver:
"""
The class receives set fields dictionary
@@ -115,3 +107,28 @@ def parallel_chunked_decorator(func: Callable = None, chunk_size: int = 100):
)
return wrapper
T = TypeVar("T")
def run_batch_operation(
func: Callable[[str], T], ids: Sequence[str]
) -> Tuple[Sequence[Tuple[str, T]], Sequence[dict]]:
results = list()
failures = list()
for _id in ids:
try:
results.append((_id, func(_id)))
except APIError as err:
failures.append(
{
"id": _id,
"error": {
"codes": [err.code, err.subcode],
"msg": err.msg,
"data": err.error_data,
},
}
)
return results, failures

View File

@@ -113,7 +113,7 @@ class WorkerBLL:
res = self.redis.delete(
company_id, self._get_worker_key(company_id, user_id, worker)
)
if not res:
if not res and not config.get("apiserver.workers.auto_unregister", False):
raise bad_request.WorkerNotRegistered(worker=worker)
def status_report(
@@ -258,7 +258,7 @@ class WorkerBLL:
tasks_info = {
task.id: task
for task in Task.objects(id__in=task_ids).only(
"name", "started", "last_iteration"
"name", "started", "last_iteration", "active_duration"
)
}
@@ -283,11 +283,7 @@ class WorkerBLL:
if helper.task_id:
task = tasks_info.get(helper.task_id, None)
if task:
worker.task.running_time = (
int((datetime.utcnow() - task.started).total_seconds() * 1000)
if task.started
else 0
)
worker.task.running_time = (task.active_duration or 0) * 1000
worker.task.last_iteration = task.last_iteration
update_queue_entries(worker.queue)

View File

@@ -6,9 +6,10 @@ from functools import reduce
from os import getenv
from os.path import expandvars
from pathlib import Path
from typing import List, Any, TypeVar
from typing import List, Any, TypeVar, Sequence
from pyhocon import ConfigTree, ConfigFactory
from boltons.iterutils import first
from pyhocon import ConfigTree, ConfigFactory, ConfigValues
from pyparsing import (
ParseFatalException,
ParseException,
@@ -18,8 +19,8 @@ from pyparsing import (
from apiserver.utilities import json
EXTRA_CONFIG_PATHS = ("/opt/trains/config",)
EXTRA_CONFIG_PATH_OVERRIDE_VAR = "TRAINS_CONFIG_DIR"
EXTRA_CONFIG_PATHS = ("/opt/trains/config", "/opt/clearml/config")
DEFAULT_PREFIXES = ("clearml", "trains")
EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ";"
@@ -30,7 +31,10 @@ class BasicConfig:
default_config_dir = "default"
def __init__(
self, folder: str = None, verbose: bool = True, prefix: str = "trains"
self,
folder: str = None,
verbose: bool = True,
prefix: Sequence[str] = DEFAULT_PREFIXES,
):
folder = (
Path(folder)
@@ -41,8 +45,16 @@ class BasicConfig:
raise ValueError("Invalid configuration folder")
self.verbose = verbose
self.prefix = prefix
self.extra_config_values_env_key_prefix = f"{self.prefix.upper()}__"
self.extra_config_path_override_var = [
f"{p.upper()}_CONFIG_DIR" for p in prefix
]
self.prefix = prefix[0]
self.extra_config_values_env_key_prefix = [
f"{p.upper()}{self.extra_config_values_env_key_sep}"
for p in reversed(prefix)
]
self._paths = [folder, *self._get_paths()]
self._config = self._reload()
@@ -67,30 +79,32 @@ class BasicConfig:
def logger(self, name: str) -> logging.Logger:
if Path(name).is_file():
name = Path(name).stem
if name == "__init__" and Path(name).parent.stem:
name = Path(name).parent.stem
path = ".".join((self.prefix, name))
return logging.getLogger(path)
def _read_extra_env_config_values(self) -> ConfigTree:
""" Loads extra configuration from environment-injected values """
result = ConfigTree()
prefix = self.extra_config_values_env_key_prefix
keys = sorted(k for k in os.environ if k.startswith(prefix))
for key in keys:
path = (
key[len(prefix) :]
.replace(self.extra_config_values_env_key_sep, ".")
.lower()
)
result = ConfigTree.merge_configs(
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
)
for prefix in self.extra_config_values_env_key_prefix:
keys = sorted(k for k in os.environ if k.startswith(prefix))
for key in keys:
path = (
key[len(prefix) :]
.replace(self.extra_config_values_env_key_sep, ".")
.lower()
)
result = self._merge_configs(
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
)
return result
def _get_paths(self) -> List[Path]:
default_paths = EXTRA_CONFIG_PATH_SEP.join(EXTRA_CONFIG_PATHS)
value = getenv(EXTRA_CONFIG_PATH_OVERRIDE_VAR, default_paths)
value = first(map(getenv, self.extra_config_path_override_var), default_paths)
paths = [
Path(expandvars(v)).expanduser() for v in value.split(EXTRA_CONFIG_PATH_SEP)
@@ -100,7 +114,7 @@ class BasicConfig:
invalid = [path for path in paths if not path.is_dir()]
if invalid:
print(
f"WARNING: Invalid paths in {EXTRA_CONFIG_PATH_OVERRIDE_VAR} env var: {' '.join(map(str, invalid))}"
f"WARNING: Invalid paths in {self.extra_config_path_override_var} env var: {' '.join(map(str, invalid))}"
)
return [path for path in paths if path.is_dir()]
@@ -114,13 +128,40 @@ class BasicConfig:
configs = [self._read_recursive(path) for path in self._paths]
return reduce(
lambda last, config: ConfigTree.merge_configs(
lambda last, config: self._merge_configs(
last, config, copy_trees=True
),
configs + [extra_config_values],
ConfigTree(),
)
@classmethod
def _merge_configs(cls, a, b, copy_trees=False, override_prefix="-"):
"""Based on pyhocon.ConfigTree.merge_configs, with dict override support using a `-` key prefix"""
for key, value in b.items():
override = key.startswith(override_prefix)
if override:
key = key[len(override_prefix):]
# if key is in both a and b and both values are dictionary then merge it otherwise override it
if not override and key in a and isinstance(a[key], ConfigTree) and isinstance(b[key], ConfigTree):
if copy_trees:
a[key] = a[key].copy()
cls._merge_configs(a[key], b[key], copy_trees=copy_trees)
else:
if isinstance(value, ConfigValues):
value.parent = a
value.key = key
if key in a:
value.overriden_value = a[key]
a[key] = value
if a.root:
if b.root:
a.history[key] = a.history.get(key, []) + b.history.get(key, [value])
else:
a.history[key] = a.history.get(key, []) + [value]
return a
def _read_recursive(self, conf_root) -> ConfigTree:
conf = ConfigTree()

View File

@@ -3,7 +3,7 @@
debug: false # Debug mode
pretty_json: false # prettify json response
return_stack: true # return stack trace on error
log_calls: true # Log API Calls
return_stack_to_caller: true # top-level control on whether to return stack trace in an API response
# if 'return_stack' is true and error contains a status code, return stack trace only for these status codes
# valid values are:
@@ -69,7 +69,7 @@
default_expiration_sec: 2592000
# cookie containing auth token, for requests arriving from a web-browser
session_auth_cookie_name: "trains_token_basic"
session_auth_cookie_name: "clearml_token_basic"
# cookie configuration for authorization cookies generated by auth.login
cookies {
@@ -79,9 +79,16 @@
max_age: 99999999999
}
# provide a cookie domain override per company
# cookies_domain_override {
# <company-id>: <domain>
# }
# # A list of fixed users
# # Note: password may be bcrypt-hashed (generate using `python -c 'import bcrypt; print(bcrypt.hashpw("password", bcrypt.gensalt()))'`)
# fixed_users {
# enabled: true
# pass_hashed: false
# users: [
# {
# username: "john"
@@ -105,6 +112,8 @@
workers {
# Auto-register unknown workers on status reports and other calls
auto_register: true
# Assume unknow workers have unregistered (i.e. do not raise unregistered error)
auto_unregister: true
# Timeout in seconds on task status update. If exceeded
# then task can be stopped without communicating to the worker
task_update_timeout: 600
@@ -116,9 +125,9 @@
# Check for updates every 24 hours
check_interval_sec: 86400
url: "https://updates.trains.allegro.ai/updates"
url: "https://updates.clear.ml/updates"
component_name: "trains-server"
component_name: "clearml-server"
# GET request timeout
request_timeout_sec: 3.0
@@ -128,7 +137,7 @@
# Note: statistics are sent ONLY if the user has actively opted-in
supported: true
url: "https://updates.trains.allegro.ai/stats"
url: "https://updates.clear.ml/stats"
report_interval_hours: 24
agent_relevant_threshold_days: 30

View File

@@ -16,7 +16,7 @@
backupCount: 3
maxBytes: 10240000,
class: "logging.handlers.RotatingFileHandler",
filename: "/var/log/trains/apiserver.log"
filename: "/var/log/clearml/apiserver.log"
}
}
root {

View File

@@ -28,6 +28,7 @@
display_name: "Default User"
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
revoke_in_fixed_mode: true
}
}
}

View File

@@ -0,0 +1,4 @@
max_page_size: 500
# expiration time in seconds for the redis scroll states in get_many family of apis
scroll_state_expiration_seconds: 600

View File

@@ -17,6 +17,10 @@ events_retrieval {
# the max amount of variants to aggregate on
max_variants_count: 100
max_raw_scalars_size: 200000
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
}
# if set then plot str will be checked for the valid json on plot add

View File

@@ -0,0 +1,7 @@
metadata_values {
# maximal amount of distinct model values to retrieve
max_count: 100
# cache ttl sec
cache_ttl_sec: 86400
}

View File

@@ -10,4 +10,9 @@ featured {
# default featured index for public projects not specified in the order
public_default: 9999
}
sub_projects {
# the max sub project depth
max_depth: 10
}

View File

@@ -9,3 +9,14 @@ non_responsive_tasks_watchdog {
}
multi_task_histogram_limit: 100
hyperparam_values {
# maximal amount of distinct hyperparam values to retrieve
max_count: 100
# max allowed outdate time for the cashed result
cache_allowed_outdate_sec: 60
# cache ttl sec
cache_ttl_sec: 86400
}

View File

@@ -2,6 +2,8 @@ from functools import lru_cache
from os import getenv
from pathlib import Path
from boltons.iterutils import first
from apiserver.config_repo import config
from apiserver.version import __version__
@@ -9,7 +11,9 @@ root = Path(__file__).parent.parent
def _get(prop_name, env_suffix=None, default=""):
value = getenv(f"TRAINS_SERVER_{env_suffix or prop_name}")
suffix = env_suffix or prop_name
keys = [f"{p}_SERVER_{suffix}" for p in ("CLEARML", "TRAINS")]
value = first(map(getenv, keys))
if value:
return value

View File

@@ -17,11 +17,18 @@ log = config.logger("database")
strict = config.get("apiserver.mongo.strict", True)
OVERRIDE_HOST_ENV_KEY = (
"CLEARML_MONGODB_SERVICE_HOST",
"TRAINS_MONGODB_SERVICE_HOST",
"MONGODB_SERVICE_HOST",
"MONGODB_SERVICE_SERVICE_HOST",
)
OVERRIDE_PORT_ENV_KEY = ("TRAINS_MONGODB_SERVICE_PORT", "MONGODB_SERVICE_PORT")
OVERRIDE_PORT_ENV_KEY = (
"CLEARML_MONGODB_SERVICE_PORT",
"TRAINS_MONGODB_SERVICE_PORT",
"MONGODB_SERVICE_PORT",
)
OVERRIDE_CONNECTION_STRING_ENV_KEY = "CLEARML_MONGODB_SERVICE_CONNECTION_STRING"
class DatabaseEntry(models.Base):
@@ -32,45 +39,57 @@ class DatabaseEntry(models.Base):
class DatabaseFactory:
_entries = []
@classmethod
def _create_db_entry(cls, alias: str, settings: dict) -> DatabaseEntry:
return DatabaseEntry(alias=alias, **settings)
@classmethod
def initialize(cls):
db_entries = config.get("hosts.mongo", {})
missing = []
log.info("Initializing database connections")
override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY)
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
if override_hostname:
log.info(f"Using override mongodb host {override_hostname}")
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
if override_port:
log.info(f"Using override mongodb port {override_port}")
if override_connection_string:
log.info(f"Using override mongodb connection string {override_connection_string}")
else:
if override_hostname:
log.info(f"Using override mongodb host {override_hostname}")
if override_port:
log.info(f"Using override mongodb port {override_port}")
for key, alias in get_items(Database).items():
if key not in db_entries:
missing.append(key)
continue
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
if override_hostname:
entry.host = furl(entry.host).set(host=override_hostname).url
if override_port:
entry.host = furl(entry.host).set(port=override_port).url
if override_connection_string:
entry.host = override_connection_string
else:
if override_hostname:
entry.host = furl(entry.host).set(host=override_hostname).url
if override_port:
entry.host = furl(entry.host).set(port=override_port).url
try:
entry.validate()
log.info(
"Registering connection to %(alias)s (%(host)s)" % entry.to_struct()
)
register_connection(alias=alias, host=entry.host)
register_connection(**entry.to_struct())
cls._entries.append(entry)
except ValidationError as ex:
raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0]))
if missing:
raise ValueError("Missing database configuration for %s" % ", ".join(missing))
raise ValueError(
"Missing database configuration for %s" % ", ".join(missing)
)
@classmethod
def get_entries(cls):
@@ -91,7 +110,7 @@ class DatabaseFactory:
# reconnection from work so workaround this
# get_connection(entry.alias, reconnect=True)
disconnect(entry.alias)
register_connection(alias=entry.alias, host=entry.host)
register_connection(**entry.to_struct())
get_connection(entry.alias)

View File

@@ -176,6 +176,13 @@ class SafeMapField(MapField, DictValidationMixin):
self.error("Empty keys are not allowed in a MapField")
class NullableStringField(StringField):
def validate(self, value):
if value is None:
return
super(NullableStringField, self).validate(value)
class SafeDictField(DictField, DictValidationMixin):
def validate(self, value):
self._safe_validate(value)

View File

@@ -60,3 +60,4 @@ def validate_id(cls, company, **kwargs):
class EntityVisibility(Enum):
active = "active"
archived = "archived"
hidden = "hidden"

View File

@@ -48,7 +48,9 @@ class Credentials(EmbeddedDocument):
meta = {"strict": False}
key = StringField(required=True)
secret = StringField(required=True)
label = StringField()
last_used = DateTimeField()
last_used_from = StringField()
class User(DbModelMixin, AuthDocument):

View File

@@ -1,26 +1,41 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence, Union, Optional, Type, Tuple
from functools import reduce, partial
from typing import (
Collection,
Sequence,
Union,
Optional,
Type,
Tuple,
Mapping,
Any,
Callable,
Dict,
List,
)
from boltons.iterutils import first, bucketize, partition
from boltons.iterutils import first, partition
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField
from mongoengine import Q, Document, ListField, StringField, IntField
from pymongo.command_cursor import CommandCursor
from apiserver.apierrors import errors
from apiserver.apierrors.base import BaseError
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database import Database
from apiserver.database.errors import MakeGetAllQueryError
from apiserver.database.projection import project_dict, ProjectionHelper
from apiserver.database.props import PropsMixin
from apiserver.database.query import RegexQ, RegexWrapper
from apiserver.database.query import RegexQ, RegexWrapper, RegexQCombination
from apiserver.database.utils import (
get_company_or_none_constraint,
get_fields_choices,
field_does_not_exist,
field_exists,
)
from apiserver.redis_manager import redman
log = config.logger("dbmodel")
@@ -70,6 +85,9 @@ class GetMixin(PropsMixin):
_ordering_key = "order_by"
_search_text_key = "search_text"
_start_key = "start"
_size_key = "size"
_multi_field_param_sep = "__"
_multi_field_param_prefix = {
("_any_", "_or_"): lambda a, b: a | b,
@@ -77,6 +95,7 @@ class GetMixin(PropsMixin):
}
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {}
class QueryParameterOptions(object):
@@ -86,6 +105,7 @@ class GetMixin(PropsMixin):
list_fields=("tags", "system_tags", "id"),
datetime_fields=None,
fields=None,
range_fields=None,
):
"""
:param pattern_fields: Fields for which a "string contains" condition should be generated
@@ -97,49 +117,111 @@ class GetMixin(PropsMixin):
self.fields = fields
self.datetime_fields = datetime_fields
self.list_fields = list_fields
self.range_fields = range_fields
self.pattern_fields = pattern_fields
class ListFieldBucketHelper:
op_prefix = "__$"
legacy_exclude_prefix = "-"
_legacy_exclude_prefix = "-"
_legacy_exclude_mongo_op = "nin"
_default = "in"
default_mongo_op = "in"
_ops = {
# op -> (mongo_op, sticky)
"not": ("nin", False),
"nop": (default_mongo_op, False),
"all": ("all", True),
"and": ("all", True),
"any": (default_mongo_op, True),
"or": (default_mongo_op, True),
}
_next = _default
_sticky = False
def __init__(self, legacy=False):
self._legacy = legacy
self._current_op = None
self._sticky = False
self._support_legacy = legacy
self.allow_empty = False
def key(self, v):
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
op = (
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
)
if translate:
tup = self._ops.get(op, None)
return tup[0] if tup else None
return op
def _key(self, v) -> Optional[Union[str, bool]]:
if v is None:
self._next = self._default
return self._default
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
self._next = self._default
return self._ops["not"][0]
elif v.startswith(self.op_prefix):
self._next, self._sticky = self._ops.get(
v[len(self.op_prefix) :], (self._default, self._sticky)
)
self.allow_empty = True
return None
next_ = self._next
if not self._sticky:
self._next = self._default
return next_
op = self._get_op(v)
if op is not None:
# operator - set state and return None
self._current_op, self._sticky = self._ops.get(
op, (self.default_mongo_op, self._sticky)
)
return None
elif self._current_op:
current_op = self._current_op
if not self._sticky:
self._current_op = None
return current_op
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
self._current_op = None
return False
def value_transform(self, v):
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
return v[len(self.legacy_exclude_prefix) :]
return v
return self.default_mongo_op
def get_global_op(self, data: Sequence[str]) -> int:
op_to_res = {
"in": Q.OR,
"all": Q.AND,
}
data = (x for x in data if x is not None)
first_op = (
self._get_op(next(data, ""), translate=True) or self.default_mongo_op
)
return op_to_res.get(first_op, self.default_mongo_op)
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
actions = {}
for val in data:
key = self._key(val)
if key is None:
continue
elif self._support_legacy and key is False:
key = self._legacy_exclude_mongo_op
val = val[len(self._legacy_exclude_prefix) :]
actions.setdefault(key, []).append(val)
return actions
get_all_query_options = QueryParameterOptions()
class GetManyScrollState(ProperDictMixin, Document):
meta = {"db_alias": Database.backend, "strict": False}
id = StringField(primary_key=True)
position = IntField(default=0)
_cache_manager = None
@classmethod
def get_cache_manager(cls):
if not cls._cache_manager:
cls._cache_manager = RedisCacheManager(
state_class=cls.GetManyScrollState,
redis=redman.connection("apiserver"),
expiration_interval=config.get(
"services._mongo.scroll_state_expiration_seconds", 600
),
)
return cls._cache_manager
@classmethod
def get(
cls: Union["GetMixin", Document],
@@ -183,6 +265,53 @@ class GetMixin(PropsMixin):
parameters, parameters_options
) & cls._prepare_perm_query(company, allow_public=allow_public)
@staticmethod
def _pop_matching_params(
patterns: Sequence[str], parameters: dict
) -> Mapping[str, Any]:
"""
Pop the parameters that match the specified patterns and return
the dictionary of matching parameters
Pop None parameters since they are not the real queries
"""
if not patterns:
return {}
fields = set()
for pattern in patterns:
if pattern.endswith("*"):
prefix = pattern[:-1]
fields.update(
{field for field in parameters if field.startswith(prefix)}
)
elif pattern in parameters:
fields.add(pattern)
pairs = ((field, parameters.pop(field, None)) for field in fields)
return {k: v for k, v in pairs if v is not None}
@classmethod
def _try_convert_to_numeric(cls, value: Union[str, Sequence[str]]):
def convert_str(val: str) -> Union[float, str]:
try:
return float(val)
except ValueError:
return val
if isinstance(value, str):
return convert_str(value)
if isinstance(value, (list, tuple)):
return [convert_str(v) if isinstance(v, str) else v for v in value]
return value
@classmethod
def _get_fixed_field_value(cls, field: str, value):
if field.startswith("last_metrics."):
return cls._try_convert_to_numeric(value)
return value
@classmethod
def _prepare_query_no_company(
cls, parameters=None, parameters_options=QueryParameterOptions()
@@ -191,7 +320,9 @@ class GetMixin(PropsMixin):
Prepare a query object based on the provided query dictionary and various fields.
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
IMPLEMENTATION NOTE: Make sure that inside this function or the functions it depends on RegexQ is always
used instead of Q. Otherwise we can and up with some combination that is not processed according to
RegexQ rules
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
Supported parameters:
@@ -205,22 +336,32 @@ class GetMixin(PropsMixin):
dict_query = {}
query = RegexQ()
if parameters:
parameters = parameters.copy()
parameters = {
k: cls._get_fixed_field_value(k, v) for k, v in parameters.items()
}
opts = parameters_options
for field in opts.pattern_fields:
pattern = parameters.pop(field, None)
if pattern:
dict_query[field] = RegexWrapper(pattern)
for field in tuple(opts.list_fields or ()):
data = parameters.pop(field, None)
if data:
query &= cls.get_list_field_query(field, data)
for field, data in cls._pop_matching_params(
patterns=opts.list_fields, parameters=parameters
).items():
query &= cls.get_list_field_query(field, data)
for field in opts.fields or []:
data = parameters.pop(field, None)
if data is not None:
dict_query[field] = data
for field, data in cls._pop_matching_params(
patterns=opts.range_fields, parameters=parameters
).items():
query &= cls.get_range_field_query(field, data)
for field, data in cls._pop_matching_params(
patterns=opts.fields or [], parameters=parameters
).items():
if "._" in field or "_." in field:
query &= RegexQ(__raw__={field: data})
else:
dict_query[field.replace(".", "__")] = data
for field in opts.datetime_fields or []:
data = parameters.pop(field, None)
@@ -250,17 +391,64 @@ class GetMixin(PropsMixin):
raise MakeGetAllQueryError("incorrect field format", field)
if not data.fields:
break
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
)
if any("._" in f for f in data.fields):
q = reduce(
lambda a, x: func(
a,
RegexQ(
__raw__={
x: {"$regex": data.pattern, "$options": "i"}
}
),
),
data.fields,
RegexQ(),
)
else:
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})),
sep_fields,
RegexQ(),
)
query = query & q
return query & RegexQ(**dict_query)
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
"""
Return a range query for the provided field. The data should contain min and max values
Both intervals are included. For open range queries either min or max can be None
In case the min value is None the records with missing or None value from db are included
"""
if not isinstance(data, (list, tuple)) or len(data) != 2:
raise errors.bad_request.ValidationError(
f"Min and max values should be specified for range field {field}"
)
min_val, max_val = data
if min_val is None and max_val is None:
raise errors.bad_request.ValidationError(
f"At least one of min or max values should be provided for field {field}"
)
mongoengine_field = field.replace(".", "__")
query = {}
if min_val is not None:
query[f"{mongoengine_field}__gte"] = min_val
if max_val is not None:
query[f"{mongoengine_field}__lte"] = max_val
q = RegexQ(**query)
if min_val is None:
q |= RegexQ(**{mongoengine_field: None})
return q
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
"""
Get a proper mongoengine Q object that represents an "or" query for the provided values
with respect to the given list field, with support for "none of empty" in case a None value
@@ -271,34 +459,32 @@ class GetMixin(PropsMixin):
- AND can be achieved using a preceding "__$all" or "__$and" value (operator)
"""
if not isinstance(data, (list, tuple)):
raise MakeGetAllQueryError("expected list", field)
data = [data]
# TODO: backwards compatibility only for older API versions
helper = cls.ListFieldBucketHelper(legacy=True)
actions = bucketize(
data, key=helper.key, value_transform=helper.value_transform
)
global_op = helper.get_global_op(data)
actions = helper.get_actions(data)
allow_empty = None in actions.get("in", {})
mongoengine_field = field.replace(".", "__")
q = RegexQ()
for action in filter(None, actions):
q &= RegexQ(
**{
f"{mongoengine_field}__{action}": list(
set(filter(None, actions[action]))
)
}
)
queries = [
RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))})
for action in filter(None, actions)
]
if not allow_empty:
if not queries:
q = RegexQ()
else:
q = RegexQCombination(operation=global_op, children=queries)
if not helper.allow_empty:
return q
return (
q
| Q(**{f"{mongoengine_field}__exists": False})
| Q(**{mongoengine_field: []})
| RegexQ(**{f"{mongoengine_field}__exists": False})
| RegexQ(**{mongoengine_field: []})
| RegexQ(**{mongoengine_field: None})
)
@classmethod
@@ -326,27 +512,41 @@ class GetMixin(PropsMixin):
return order_by
@classmethod
def validate_paging(
cls, parameters=None, default_page=None, default_page_size=None
):
""" Validate and extract paging info from from the provided dictionary. Supports default values. """
if parameters is None:
parameters = {}
default_page = parameters.get("page", default_page)
if default_page is None:
return None, None
default_page_size = parameters.get("page_size", default_page_size)
if not default_page_size:
raise errors.bad_request.MissingRequiredFields(
"page_size is required when page is requested", field="page_size"
)
elif default_page < 0:
def validate_paging(cls, parameters=None, default_page=0, default_page_size=None):
"""
Validate and extract paging info from from the provided dictionary. Supports default values.
If page is specified then it should be non-negative, if page size is specified then it should be positive
If page size is specified and page is not then 0 page is assumed
If page is specified then page size should be specified too
"""
parameters = parameters or {}
start = parameters.get(cls._start_key)
if start is not None:
return start, cls.validate_scroll_size(parameters)
max_page_size = config.get("services._mongo.max_page_size", 500)
page = parameters.get("page", default_page)
if page is not None and page < 0:
raise errors.bad_request.ValidationError("page must be >=0", field="page")
elif default_page_size < 1:
page_size = parameters.get("page_size", default_page_size or max_page_size)
if page_size is not None and page_size < 1:
raise errors.bad_request.ValidationError(
"page_size must be >0", field="page_size"
)
return default_page, default_page_size
if page_size is not None:
page = page or 0
page_size = min(page_size, max_page_size)
return page * page_size, page_size
if page is not None:
raise errors.bad_request.MissingRequiredFields(
"page_size is required when page is requested", field="page_size"
)
return None, None
@classmethod
def get_projection(cls, parameters, override_projection=None, **__):
@@ -390,6 +590,54 @@ class GetMixin(PropsMixin):
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
cls.set_ordering(parameters, cls.get_ordering(parameters) or value)
@classmethod
def validate_scroll_size(cls, query_dict: dict) -> int:
size = query_dict.get(cls._size_key)
if not size or not isinstance(size, int) or size < 1:
raise errors.bad_request.ValidationError(
"Integer size parameter greater than 1 should be provided when working with scroll"
)
return size
@classmethod
def get_data_with_scroll_support(
cls,
query_dict: dict,
data_getter: Callable[[], Sequence[dict]],
ret_params: dict,
) -> Sequence[dict]:
"""
Retrieves the data by calling the provided data_getter api
If scroll parameters are specified then put the query_dict 'start' parameter to the last
scroll position and continue retrievals from that position
If refresh_scroll is requested then bring once more the data from the beginning
till the current scroll position
In the end the scroll position is updated and accumulated frames are returned
"""
query_dict = query_dict or {}
state: Optional[cls.GetManyScrollState] = None
if "scroll_id" in query_dict:
size = cls.validate_scroll_size(query_dict)
state = cls.get_cache_manager().get_or_create_state_core(
query_dict.get("scroll_id")
)
if query_dict.get("refresh_scroll"):
query_dict[cls._size_key] = max(state.position, size)
state.position = 0
query_dict[cls._start_key] = state.position
data = data_getter()
if cls._start_key in query_dict:
query_dict[cls._start_key] = query_dict[cls._start_key] + len(data)
if state:
state.position = query_dict[cls._start_key]
cls.get_cache_manager().set_state(state)
if ret_params is not None:
ret_params["scroll_id"] = state.id
return data
@classmethod
def get_many_with_join(
cls,
@@ -400,6 +648,7 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection=None,
expand_reference_ids=True,
ret_params: dict = None,
):
"""
Fetch all documents matching a provided query with support for joining referenced documents according to the
@@ -435,6 +684,7 @@ class GetMixin(PropsMixin):
query=query,
query_options=query_options,
allow_public=allow_public,
ret_params=ret_params,
)
def projection_func(doc_type, projection, ids):
@@ -448,6 +698,12 @@ class GetMixin(PropsMixin):
return helper.project(results, projection_func)
@classmethod
def _get_collation_override(cls, field: str) -> Optional[dict]:
return first(
v for k, v in cls._field_collation_overrides.items() if field.startswith(k)
)
@classmethod
def get_many(
cls,
@@ -459,6 +715,7 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection: Collection[str] = None,
return_dicts=True,
ret_params: dict = None,
):
"""
Fetch all documents matching a provided query. Supported several built-in options
@@ -485,6 +742,13 @@ class GetMixin(PropsMixin):
:param allow_public: If True, objects marked as public (no associated company) are also queried.
:return: A list of objects matching the query.
"""
override_collation = None
if query_dict:
for field in query_dict:
override_collation = cls._get_collation_override(field)
if override_collation:
break
if query_dict is not None:
q = cls.prepare_query(
parameters=query_dict,
@@ -497,14 +761,22 @@ class GetMixin(PropsMixin):
_query = (q & query) if query else q
if return_dicts:
return cls._get_many_override_none_ordering(
data_getter = partial(
cls._get_many_override_none_ordering,
query=_query,
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
)
return cls.get_data_with_scroll_support(
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
)
return cls._get_many_no_company(
query=_query, parameters=parameters, override_projection=override_projection
query=_query,
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
)
@classmethod
@@ -528,6 +800,7 @@ class GetMixin(PropsMixin):
query: Q,
parameters=None,
override_projection=None,
override_collation=None,
):
"""
Fetch all documents matching a provided query.
@@ -547,12 +820,16 @@ class GetMixin(PropsMixin):
parameters = parameters or {}
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
if order_by and not override_collation:
override_collation = cls._get_collation_override(order_by[0])
start, size = cls.validate_paging(parameters=parameters)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
qs = cls.objects(query)
if override_collation:
qs = qs.collation(collation=override_collation)
if search_text:
qs = qs.search_text(search_text)
if order_by:
@@ -566,18 +843,47 @@ class GetMixin(PropsMixin):
if exclude:
qs = qs.exclude(*exclude)
if page is not None and page_size:
if start is not None and size:
# add paging
qs = qs.skip(page * page_size).limit(page_size)
qs = qs.skip(start).limit(size)
return qs
@classmethod
def _get_queries_for_order_field(
cls, query: Q, order_field: str
) -> Union[None, Tuple[Q, Q]]:
"""
In case the order_field is one of the cls fields and the sorting is ascending
then return the tuple of 2 queries:
1. original query with not empty constraint on the order_by field
2. original query with empty constraint on the order_by field
"""
if not order_field or order_field.startswith("-") or "[" in order_field:
return
mongo_field_name = order_field.replace(".", "__")
mongo_field = first(
v for k, v in cls.get_all_fields_with_instance() if k == mongo_field_name
)
if isinstance(mongo_field, ListField):
params = {"is_list": True}
elif isinstance(mongo_field, StringField):
params = {"empty_value": ""}
else:
params = {}
non_empty = query & field_exists(mongo_field_name, **params)
empty = query & field_does_not_exist(mongo_field_name, **params)
return non_empty, empty
@classmethod
def _get_many_override_none_ordering(
cls: Union[Document, "GetMixin"],
query: Q = None,
parameters: dict = None,
override_projection: Collection[str] = None,
override_collation: dict = None,
) -> Sequence[dict]:
"""
Fetch all documents matching a provided query. For the first order by field
@@ -600,7 +906,10 @@ class GetMixin(PropsMixin):
parameters = parameters or {}
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
start, size = cls.validate_paging(parameters=parameters)
if size is not None and size <= 0:
return []
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
@@ -610,32 +919,17 @@ class GetMixin(PropsMixin):
order_field = first(
field for field in order_by if not field.startswith("$")
)
if (
order_field
and not order_field.startswith("-")
and "[" not in order_field
):
params = {}
mongo_field = order_field.replace(".", "__")
if mongo_field in cls.get_field_names_for_type(of_type=ListField):
params["is_list"] = True
elif mongo_field in cls.get_field_names_for_type(of_type=StringField):
params["empty_value"] = ""
non_empty = query & field_exists(mongo_field, **params)
empty = query & field_does_not_exist(mongo_field, **params)
query_sets = [cls.objects(non_empty), cls.objects(empty)]
res = cls._get_queries_for_order_field(query, order_field)
if res:
query_sets = [cls.objects(q) for q in res]
query_sets = [qs.order_by(*order_by) for qs in query_sets]
if order_field:
collation_override = first(
v
for k, v in cls._field_collation_overrides.items()
if order_field.startswith(k)
)
if collation_override:
query_sets = [
qs.collation(collation=collation_override) for qs in query_sets
]
if order_field and not override_collation:
override_collation = cls._get_collation_override(order_field)
if override_collation:
query_sets = [
qs.collation(collation=override_collation) for qs in query_sets
]
if search_text:
query_sets = [qs.search_text(search_text) for qs in query_sets]
@@ -647,25 +941,28 @@ class GetMixin(PropsMixin):
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
if page is None or not page_size:
if start is None or not size:
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
# add paging
ret = []
start = page * page_size
for qs in query_sets:
qs_size = qs.count()
if qs_size < start:
start -= qs_size
continue
last_set = len(query_sets) - 1
for i, qs in enumerate(query_sets):
last_size = len(ret)
ret.extend(
obj.to_proper_dict(only=include)
for obj in qs.skip(start).limit(page_size)
for obj in (qs.skip(start) if start else qs).limit(size)
)
if len(ret) >= page_size:
added = len(ret) - last_size
if added > 0:
start = 0
size = max(0, size - added)
elif i != last_set:
start -= min(start, qs.count())
if size <= 0:
break
start = 0
page_size -= len(ret)
return ret

View File

@@ -0,0 +1,44 @@
from typing import Sequence, Type
from mongoengine import EmbeddedDocument, StringField, Document
from pymongo import UpdateOne
from pymongo.collection import Collection
from apiserver.database.model.base import ProperDictMixin
class MetadataItem(EmbeddedDocument, ProperDictMixin):
key = StringField(required=True)
type = StringField(required=True)
value = StringField(required=True)
def metadata_add_or_update(cls: Type[Document], _id: str, items: Sequence[dict]) -> int:
collection: Collection = cls._get_collection()
res = collection.update_one(
filter={"_id": _id},
update={
"$set": {f"metadata.$[elem{idx}]": item for idx, item in enumerate(items)}
},
array_filters=[
{f"elem{idx}.key": item["key"]} for idx, item in enumerate(items)
],
upsert=False,
)
if len(items) == 1 and res.modified_count == 1:
return res.modified_count
requests = [
UpdateOne(
filter={"_id": _id, "metadata.key": {"$ne": item["key"]}},
update={"$push": {"metadata": item}},
)
for item in items
]
res = collection.bulk_write(requests)
return 1 if res.modified_count else 0
def metadata_delete(cls: Type[Document], _id: str, keys: Sequence[str]) -> int:
return cls.objects(id=_id).update_one(pull__metadata__key__in=keys)

View File

@@ -1,17 +1,30 @@
from mongoengine import Document, StringField, DateTimeField, BooleanField
from mongoengine import (
StringField,
DateTimeField,
BooleanField,
EmbeddedDocumentField,
)
from apiserver.database import Database, strict
from apiserver.database.fields import StrippedStringField, SafeDictField, SafeSortedListField
from apiserver.database.model import DbModelMixin
from apiserver.database.fields import (
StrippedStringField,
SafeDictField,
SafeSortedListField,
SafeMapField,
)
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import GetMixin
from apiserver.database.model.metadata import MetadataItem
from apiserver.database.model.model_labels import ModelLabels
from apiserver.database.model.company import Company
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.database.model.user import User
class Model(DbModelMixin, Document):
class Model(AttributedDocument):
_field_collation_overrides = {
"metadata.": AttributedDocument._numeric_locale,
}
meta = {
"db_alias": Database.backend,
"strict": strict,
@@ -19,6 +32,7 @@ class Model(DbModelMixin, Document):
"parent",
"project",
"task",
"last_update",
("company", "framework"),
("company", "name"),
("company", "user"),
@@ -50,14 +64,14 @@ class Model(DbModelMixin, Document):
"project",
"task",
"parent",
"metadata.*",
),
datetime_fields=("last_update",),
)
id = StringField(primary_key=True)
name = StrippedStringField(user_set_allowed=True, min_length=3)
parent = StringField(reference_field="Model", required=False)
user = StringField(required=True, reference_field=User)
company = StringField(required=True, reference_field=Company)
project = StringField(reference_field=Project, user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True)
task = StringField(reference_field=Task)
@@ -69,7 +83,11 @@ class Model(DbModelMixin, Document):
design = SafeDictField()
labels = ModelLabels()
ready = BooleanField(required=True)
last_update = DateTimeField()
ui_cache = SafeDictField(
default=dict, user_set_allowed=True, exclude_by_default=True
)
company_origin = StringField(exclude_by_default=True)
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)

View File

@@ -1,4 +1,4 @@
from mongoengine import StringField, DateTimeField, IntField
from mongoengine import StringField, DateTimeField, IntField, ListField
from apiserver.database import Database, strict
from apiserver.database.fields import StrippedStringField, SafeSortedListField
@@ -10,13 +10,16 @@ class Project(AttributedDocument):
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "description"),
list_fields=("tags", "system_tags", "id"),
list_fields=("tags", "system_tags", "id", "parent", "path"),
range_fields=("last_update",),
)
meta = {
"db_alias": Database.backend,
"strict": strict,
"indexes": [
"parent",
"path",
("company", "name"),
{
"name": "%s.project.main_text_index" % Database.backend,
@@ -34,7 +37,7 @@ class Project(AttributedDocument):
min_length=3,
sparse=True,
)
description = StringField(required=True)
description = StringField()
created = DateTimeField(required=True)
tags = SafeSortedListField(StringField(required=True))
system_tags = SafeSortedListField(StringField(required=True))
@@ -44,3 +47,5 @@ class Project(AttributedDocument):
logo_url = StringField()
logo_blob = StringField(exclude_by_default=True)
company_origin = StringField(exclude_by_default=True)
parent = StringField(reference_field="Project")
path = ListField(StringField(required=True), exclude_by_default=True)

View File

@@ -4,34 +4,43 @@ from mongoengine import (
StringField,
DateTimeField,
EmbeddedDocumentListField,
EmbeddedDocumentField,
)
from apiserver.database import Database, strict
from apiserver.database.fields import StrippedStringField, SafeSortedListField
from apiserver.database.model import DbModelMixin
from apiserver.database.fields import (
StrippedStringField,
SafeSortedListField,
SafeMapField,
)
from apiserver.database.model import DbModelMixin, AttributedDocument
from apiserver.database.model.base import ProperDictMixin, GetMixin
from apiserver.database.model.company import Company
from apiserver.database.model.metadata import MetadataItem
from apiserver.database.model.task.task import Task
class Entry(EmbeddedDocument, ProperDictMixin):
""" Entry representing a task waiting in the queue """
task = StringField(required=True, reference_field=Task)
''' Task ID '''
""" Task ID """
added = DateTimeField(required=True)
''' Added to the queue '''
""" Added to the queue """
class Queue(DbModelMixin, Document):
_field_collation_overrides = {
"metadata.": AttributedDocument._numeric_locale,
}
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name",),
list_fields=("tags", "system_tags", "id"),
pattern_fields=("name",), list_fields=("tags", "system_tags", "id", "metadata.*"),
)
meta = {
'db_alias': Database.backend,
'strict': strict,
"db_alias": Database.backend,
"strict": strict,
}
id = StringField(primary_key=True)
@@ -40,7 +49,12 @@ class Queue(DbModelMixin, Document):
)
company = StringField(required=True, reference_field=Company)
created = DateTimeField(required=True)
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
tags = SafeSortedListField(
StringField(required=True), default=list, user_set_allowed=True
)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
entries = EmbeddedDocumentListField(Entry, default=list)
last_update = DateTimeField()
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)

View File

@@ -11,6 +11,5 @@ class Result(object):
class Output(EmbeddedDocument):
destination = StrippedStringField()
model = StringField(reference_field='Model')
error = StringField(user_set_allowed=True)
result = StringField(choices=get_options(Result))

View File

@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Sequence
from mongoengine import (
StringField,
@@ -17,6 +17,8 @@ from apiserver.database.fields import (
SafeDictField,
UnionField,
SafeSortedListField,
EmbeddedDocumentListField,
NullableStringField,
)
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import ProperDictMixin, GetMixin
@@ -79,7 +81,9 @@ DEFAULT_ARTIFACT_MODE = ArtifactModes.output
class Artifact(EmbeddedDocument):
key = StringField(required=True)
type = StringField(required=True)
mode = StringField(choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE)
mode = StringField(
choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE
)
uri = StringField()
hash = StringField()
content_size = LongField()
@@ -103,17 +107,37 @@ class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
description = StringField()
class TaskModelTypes:
input = "input"
output = "output"
TaskModelNames = {
TaskModelTypes.input: "Input Model",
TaskModelTypes.output: "Output Model",
}
class ModelItem(EmbeddedDocument, ProperDictMixin):
name = StringField(required=True)
model = StringField(required=True, reference_field="Model")
updated = DateTimeField()
class Models(EmbeddedDocument, ProperDictMixin):
input: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
output: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
class Execution(EmbeddedDocument, ProperDictMixin):
meta = {"strict": strict}
test_split = IntField(default=0)
parameters = SafeDictField(default=dict)
model = StringField(reference_field="Model")
model_desc = SafeMapField(StringField(default=""))
model_labels = ModelLabels()
framework = StringField()
artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact))
docker_cmd = StringField()
queue = StringField()
queue = StringField(reference_field="Queue")
""" Queue ID where task was queued """
@@ -135,12 +159,10 @@ external_task_types = set(get_options(TaskType))
class Task(AttributedDocument):
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {
"execution.parameters.": _numeric_locale,
"last_metrics.": _numeric_locale,
"hyperparams.": _numeric_locale,
"configuration.": _numeric_locale,
"execution.parameters.": AttributedDocument._numeric_locale,
"last_metrics.": AttributedDocument._numeric_locale,
"hyperparams.": AttributedDocument._numeric_locale,
}
meta = {
@@ -153,6 +175,9 @@ class Task(AttributedDocument):
"active_duration",
"parent",
"project",
"last_update",
"status_changed",
"models.input.model",
("company", "name"),
("company", "user"),
("company", "status", "type"),
@@ -160,14 +185,18 @@ class Task(AttributedDocument):
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
{
"fields": ["company", "project"],
"collation": AttributedDocument._numeric_locale,
},
{
"name": "%s.task.main_text_index" % Database.backend,
"fields": [
"$name",
"$id",
"$comment",
"$execution.model",
"$output.model",
"$models.input.model",
"$models.output.model",
"$script.repository",
"$script.entry_point",
],
@@ -176,8 +205,8 @@ class Task(AttributedDocument):
"name": 10,
"id": 10,
"comment": 10,
"execution.model": 2,
"output.model": 2,
"models.output.model": 2,
"models.input.model": 2,
"script.repository": 1,
"script.entry_point": 1,
},
@@ -185,8 +214,19 @@ class Task(AttributedDocument):
],
}
get_all_query_options = GetMixin.QueryParameterOptions(
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project", "parent"),
datetime_fields=("status_changed",),
list_fields=(
"id",
"user",
"tags",
"system_tags",
"type",
"status",
"project",
"parent",
"hyperparams.*",
),
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"),
pattern_fields=("name", "comment"),
)
@@ -198,7 +238,7 @@ class Task(AttributedDocument):
type = StringField(required=True, choices=get_options(TaskType))
status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
status_reason = StringField()
status_message = StringField()
status_message = StringField(user_set_allowed=True)
status_changed = DateTimeField()
comment = StringField(user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True)
@@ -225,7 +265,11 @@ class Task(AttributedDocument):
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
runtime = SafeDictField(default=dict)
docker_init_script = StringField()
models: Models = EmbeddedDocumentField(Models, default=Models)
container = SafeMapField(field=NullableStringField())
enqueue_status = StringField(
choices=get_options(TaskStatus), exclude_by_default=True
)
def get_index_company(self) -> str:
"""

View File

@@ -1,12 +1,16 @@
from collections import OrderedDict, defaultdict
from itertools import chain
from collections import OrderedDict
from operator import attrgetter
from threading import Lock
from typing import Sequence
import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine.base import get_document, BaseField
from mongoengine import (
EmbeddedDocumentField,
EmbeddedDocumentListField,
EmbeddedDocument,
Document,
)
from mongoengine.base import get_document
from apiserver.database.fields import (
LengthRangeEmbeddedDocumentListField,
@@ -21,11 +25,18 @@ class PropsMixin(object):
__cached_reference_fields = None
__cached_exclude_fields = None
__cached_fields_with_instance = None
__cached_field_names_per_type = None
__cached_all_fields_with_instance = None
__cached_dpath_computed_fields_lock = Lock()
__cached_dpath_computed_fields = None
_document_classes = {}
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if issubclass(cls, (Document, EmbeddedDocument)):
PropsMixin._document_classes[cls._class_name] = cls
@classmethod
def get_fields(cls):
if cls.__cached_fields is None:
@@ -33,37 +44,12 @@ class PropsMixin(object):
return cls.__cached_fields
@classmethod
def get_field_names_for_type(cls, of_type=BaseField):
"""
Return field names per type including subfields
The fields of derived types are also returned
"""
assert issubclass(of_type, BaseField)
if cls.__cached_field_names_per_type is None:
fields = defaultdict(list)
for name, field in get_fields(cls, return_instance=True, subfields=True):
fields[type(field)].append(name)
for type_ in fields:
fields[type_].extend(
chain.from_iterable(
fields[other_type]
for other_type in fields
if other_type != type_ and issubclass(other_type, type_)
)
)
cls.__cached_field_names_per_type = fields
if of_type not in cls.__cached_field_names_per_type:
names = list(
chain.from_iterable(
field_names
for type_, field_names in cls.__cached_field_names_per_type.items()
if issubclass(type_, of_type)
)
def get_all_fields_with_instance(cls):
if cls.__cached_all_fields_with_instance is None:
cls.__cached_all_fields_with_instance = get_fields(
cls, return_instance=True, subfields=True
)
cls.__cached_field_names_per_type[of_type] = names
return cls.__cached_field_names_per_type[of_type]
return cls.__cached_all_fields_with_instance
@classmethod
def get_fields_with_instance(cls, doc_cls):
@@ -83,8 +69,14 @@ class PropsMixin(object):
def resolve_doc(v):
if not isinstance(v, six.string_types):
return v
if v == 'self':
if v == "self":
return cls_.owner_document
doc_cls = PropsMixin._document_classes.get(v)
if doc_cls:
return doc_cls
return get_document(v)
fields = {k: resolve_doc(v) for k, v in res.items()}
@@ -98,7 +90,7 @@ class PropsMixin(object):
).document_type
fields.update(
{
'.'.join((field, subfield)): doc
".".join((field, subfield)): doc
for subfield, doc in PropsMixin._get_fields_with_attr(
embedded_doc_cls, attr
).items()
@@ -106,10 +98,10 @@ class PropsMixin(object):
)
collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))
collect_embedded_docs(EmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter("field"))
return fields
@@ -120,7 +112,7 @@ class PropsMixin(object):
for depth, part in enumerate(parts):
if current_cls is None:
raise ValueError(
'Invalid path (non-document encountered at %s)' % parts[: depth - 1]
"Invalid path (non-document encountered at %s)" % parts[: depth - 1]
)
try:
field_name, field = next(
@@ -129,7 +121,7 @@ class PropsMixin(object):
if k == part
)
except StopIteration:
raise ValueError('Invalid field path %s' % parts[:depth])
raise ValueError("Invalid field path %s" % parts[:depth])
translated_parts.append(part)
@@ -145,7 +137,7 @@ class PropsMixin(object):
),
):
current_cls = field.field.document_type
translated_parts.append('*')
translated_parts.append("*")
else:
current_cls = None
@@ -154,7 +146,7 @@ class PropsMixin(object):
@classmethod
def get_reference_fields(cls):
if cls.__cached_reference_fields is None:
fields = cls._get_fields_with_attr(cls, 'reference_field')
fields = cls._get_fields_with_attr(cls, "reference_field")
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_reference_fields
@@ -169,12 +161,12 @@ class PropsMixin(object):
@classmethod
def get_exclude_fields(cls):
if cls.__cached_exclude_fields is None:
fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
fields = cls._get_fields_with_attr(cls, "exclude_by_default")
cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_exclude_fields
@classmethod
def get_dpath_translated_path(cls, path, separator='.'):
def get_dpath_translated_path(cls, path, separator="."):
if cls.__cached_dpath_computed_fields is None:
cls.__cached_dpath_computed_fields = {}
if path not in cls.__cached_dpath_computed_fields:

View File

@@ -5,7 +5,7 @@ Apply elasticsearch mappings to given hosts.
import argparse
import json
from pathlib import Path
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple
from elasticsearch import Elasticsearch
@@ -13,7 +13,7 @@ HERE = Path(__file__).resolve().parent
def apply_mappings_to_cluster(
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
hosts: Sequence, key: Optional[str] = None, es_args: dict = None, http_auth: Tuple = None
):
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
@@ -30,7 +30,7 @@ def apply_mappings_to_cluster(
else:
files = p.glob("**/*.json")
es = Elasticsearch(hosts=hosts, **(es_args or {}))
es = Elasticsearch(hosts=hosts, http_auth=http_auth, **(es_args or {}))
return [_send_template(f) for f in files]

View File

@@ -82,7 +82,11 @@ def check_elastic_empty() -> bool:
es_logger.addFilter(log_filter)
for retry in range(max_retries):
try:
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
es = Elasticsearch(
hosts=cluster_conf.get("hosts", None),
http_auth=es_factory.get_credentials("events", cluster_conf),
**cluster_conf.get("args", {})
)
return not es.indices.get_template(name="events*")
except exceptions.NotFoundError as ex:
log.error(ex)
@@ -109,5 +113,7 @@ def init_es_data():
log.info(f"Applying mappings to ES host: {hosts_config}")
args = cluster_conf.get("args", {})
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
http_auth = es_factory.get_credentials(name)
res = apply_mappings_to_cluster(hosts_config, name, es_args=args, http_auth=http_auth)
log.info(res)

View File

@@ -1,7 +1,8 @@
{
"index_patterns": "events-*",
"settings": {
"number_of_shards": 1
"number_of_shards": 1,
"number_of_replicas": 0
},
"mappings": {
"_source": {

View File

@@ -1,7 +1,8 @@
{
"index_patterns": "queue_metrics_*",
"settings": {
"number_of_shards": 1
"number_of_shards": 1,
"number_of_replicas": 0
},
"mappings": {
"_source": {

View File

@@ -1,7 +1,8 @@
{
"index_patterns": "worker_stats_*",
"settings": {
"number_of_shards": 1
"number_of_shards": 1,
"number_of_replicas": 0
},
"mappings": {
"_source": {

View File

@@ -1,19 +1,30 @@
from datetime import datetime
from functools import lru_cache
from os import getenv
from typing import Tuple, Optional
from boltons.iterutils import first
from elasticsearch import Elasticsearch, Transport
from elasticsearch import Elasticsearch
from apiserver.config_repo import config
log = config.logger(__file__)
OVERRIDE_HOST_ENV_KEY = (
"CLEARML_ELASTIC_SERVICE_HOST",
"TRAINS_ELASTIC_SERVICE_HOST",
"ELASTIC_SERVICE_HOST",
"ELASTIC_SERVICE_SERVICE_HOST",
)
OVERRIDE_PORT_ENV_KEY = ("TRAINS_ELASTIC_SERVICE_PORT", "ELASTIC_SERVICE_PORT")
OVERRIDE_PORT_ENV_KEY = (
"CLEARML_ELASTIC_SERVICE_PORT",
"TRAINS_ELASTIC_SERVICE_PORT",
"ELASTIC_SERVICE_PORT",
)
OVERRIDE_USERNAME_ENV_KEY = ("CLEARML_ELASTIC_SERVICE_USERNAME",)
OVERRIDE_PASSWORD_ENV_KEY = ("CLEARML_ELASTIC_SERVICE_PASSWORD",)
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
if OVERRIDE_HOST:
@@ -23,6 +34,14 @@ OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
if OVERRIDE_PORT:
log.info(f"Using override elastic port {OVERRIDE_PORT}")
OVERRIDE_USERNAME = first(filter(None, map(getenv, OVERRIDE_USERNAME_ENV_KEY)))
if OVERRIDE_USERNAME:
log.info(f"Using override elastic username {OVERRIDE_USERNAME}")
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
if OVERRIDE_PASSWORD:
log.info("Using override elastic password ********")
_instances = {}
@@ -42,6 +61,10 @@ class InvalidClusterConfiguration(Exception):
pass
class MissingPasswordForElasticUser(Exception):
pass
class ESFactory:
@classmethod
def connect(cls, cluster_name):
@@ -59,18 +82,45 @@ class ESFactory:
if not hosts:
raise InvalidClusterConfiguration(cluster_name)
http_auth = cls.get_credentials(cluster_name)
args = cluster_config.get("args", {})
_instances[cluster_name] = Elasticsearch(
hosts=hosts, transport_class=Transport, **args
hosts=hosts, http_auth=http_auth, **args
)
return _instances[cluster_name]
@classmethod
def get_credentials(cls, cluster_name: str, cluster_config: dict = None) -> Optional[Tuple[str, str]]:
cluster_config = cluster_config or cls.get_cluster_config(cluster_name)
if not cluster_config.get("secure", True):
return None
elastic_user = OVERRIDE_USERNAME or config.get("secure.elastic.user", None)
if not elastic_user:
return None
elastic_password = OVERRIDE_PASSWORD or config.get(
"secure.elastic.password", None
)
if not elastic_password:
raise MissingPasswordForElasticUser(
f"cluster={cluster_name}, username={elastic_user}"
)
return elastic_user, elastic_password
@classmethod
def get_all_cluster_names(cls):
return list(config.get("hosts.elastic"))
@classmethod
def get_override_host(cls, cluster_name: str) -> Tuple[str, str]:
return OVERRIDE_HOST, OVERRIDE_PORT
@classmethod
@lru_cache()
def get_cluster_config(cls, cluster_name):
"""
Returns cluster config for the specified cluster path
@@ -84,14 +134,16 @@ class ESFactory:
raise MissingClusterConfiguration(cluster_name)
def set_host_prop(key, value):
for host in cluster_config.get("hosts", []):
host[key] = value
for entry in cluster_config.get("hosts", []):
entry[key] = value
if OVERRIDE_HOST:
set_host_prop("host", OVERRIDE_HOST)
host, port = cls.get_override_host(cluster_name)
if OVERRIDE_PORT:
set_host_prop("port", OVERRIDE_PORT)
if host:
set_host_prop("host", host)
if port:
set_host_prop("port", port)
return cluster_config
@@ -120,7 +172,9 @@ class ESFactory:
@classmethod
def get_es_timestamp_str(cls):
now = datetime.utcnow()
return now.strftime("%Y-%m-%dT%H:%M:%S") + ".%03d" % (now.microsecond / 1000) + "Z"
return (
now.strftime("%Y-%m-%dT%H:%M:%S") + ".%03d" % (now.microsecond / 1000) + "Z"
)
es_factory = ESFactory()

View File

@@ -4,19 +4,20 @@ from logging import Logger
from pathlib import Path
from mongoengine.connection import get_db
from semantic_version import Version
from packaging.version import Version, parse
from apiserver.database import utils
from apiserver.database import Database
from apiserver.database.model.version import Version as DatabaseVersion
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
_migrations = "migrations"
_parent_dir = Path(__file__).resolve().parents[1]
_migration_dir = _parent_dir / _migrations
def check_mongo_empty() -> bool:
return not all(
get_db(alias).collection_names()
for alias in utils.get_options(Database)
get_db(alias).collection_names() for alias in utils.get_options(Database)
)
@@ -41,8 +42,8 @@ def _apply_migrations(log: Logger):
log.info(f"Started mongodb migrations")
if not migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {migration_dir}")
if not _migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {_migration_dir}")
empty_dbs = check_mongo_empty()
last_version = get_last_server_version()
@@ -50,7 +51,10 @@ def _apply_migrations(log: Logger):
try:
new_scripts = {
ver: path
for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py"))
for ver, path in (
(parse(f.stem.replace("_", ".")), f)
for f in _migration_dir.glob("*.py")
)
if ver > last_version
}
except ValueError as ex:
@@ -64,7 +68,10 @@ def _apply_migrations(log: Logger):
if empty_dbs:
log.info(f"Skipping migration {script.name} (empty databases)")
else:
spec = importlib.util.spec_from_file_location(script.stem, str(script))
spec = importlib.util.spec_from_file_location(
".".join(("apiserver", _parent_dir.name, _migrations, script.stem)),
str(script),
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
@@ -83,7 +90,7 @@ def _apply_migrations(log: Logger):
DatabaseVersion(
id=utils.id(),
num=script.stem,
num=str(script_version),
created=datetime.utcnow(),
desc="Applied on server startup",
).save()

View File

@@ -21,11 +21,11 @@ from typing import (
Union,
Mapping,
IO,
Callable,
)
from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2
import dpath
import mongoengine
from boltons.iterutils import chunked_iter, first
from furl import furl
@@ -33,6 +33,7 @@ from mongoengine import Q
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType
from apiserver.bll.project import project_ids_with_children
from apiserver.bll.task.artifacts import get_artifact_id
from apiserver.bll.task.param_utils import (
split_param_name,
@@ -44,11 +45,17 @@ from apiserver.config.info import get_default_company
from apiserver.database.model import EntityVisibility, User
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
from apiserver.database.model.task.task import (
Task,
ArtifactModes,
TaskStatus,
TaskModelTypes,
TaskModelNames,
)
from apiserver.database.utils import get_options
from apiserver.tools import safe_get
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
class PrePopulate:
@@ -293,8 +300,9 @@ class PrePopulate:
if company_id is None:
company_id = ""
# Always use a public user for pre-populated data
cls.user_cls(id=user_id, name=user_name, company="").save()
existing_user = cls.user_cls.objects(id=user_id).only("id").first()
if not existing_user:
cls.user_cls(id=user_id, name=user_name, company=company_id).save()
cls._import(zfile, company_id, user_id, metadata)
@@ -343,31 +351,10 @@ class PrePopulate:
return upadated
@staticmethod
def _upgrade_task_data(task_data: dict):
for old_param_field, new_param_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None),
):
legacy = safe_get(task_data, old_param_field)
if not legacy:
continue
for full_name, value in legacy.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_param_field, section, name)))
if not safe_get(task_data, new_path):
new_param = dict(
name=name, type=hyperparams_legacy_type, value=str(value)
)
if section is not None:
new_param["section"] = section
dpath.new(task_data, new_path, new_param)
dpath.delete(task_data, old_param_field)
@classmethod
def _upgrade_tasks(cls, f: IO[bytes]) -> bytes:
"""
Build content array that contains fixed tasks from the passed file
Build content array that contains upgraded tasks from the passed file
For each task the old execution.parameters and model.design are
converted to the new structure.
The fix is done on Task objects (not the dictionary) so that
@@ -423,6 +410,24 @@ class PrePopulate:
items.append(results[0])
return items
@classmethod
def _check_projects_hierarchy(cls, projects: Set[Project]):
"""
For any exported project all its parents up to the root should be present
"""
if not projects:
return
project_ids = {p.id for p in projects}
orphans = [p.id for p in projects if p.parent and p.parent not in project_ids]
if not orphans:
return
print(
f"ERROR: the following projects are exported without their parents: {orphans}"
)
exit(1)
@classmethod
def _resolve_entities(
cls,
@@ -434,6 +439,7 @@ class PrePopulate:
if projects:
print("Reading projects...")
projects = project_ids_with_children(projects)
entities[cls.project_cls].update(
cls._resolve_type(cls.project_cls, projects)
)
@@ -463,12 +469,16 @@ class PrePopulate:
project_ids = {p.id for p in entities[cls.project_cls]}
entities[cls.project_cls].update(o for o in objs if o.id not in project_ids)
model_ids = {
model_id
cls._check_projects_hierarchy(entities[cls.project_cls])
task_models = chain.from_iterable(
models
for task in entities[cls.task_cls]
for model_id in (task.output.model, task.execution.model)
if model_id
}
if task.models
for models in (task.models.input, task.models.output)
if models
)
model_ids = {tm.model for tm in task_models}
if model_ids:
print("Reading models...")
entities[cls.model_cls] = set(cls.model_cls.objects(id__in=list(model_ids)))
@@ -634,11 +644,12 @@ class PrePopulate:
"""
Export the requested experiments, projects and models and return the list of artifact files
Always do the export on sorted items since the order of items influence hash
The projects should be sorted by name so that on import the hierarchy is correctly restored from top to bottom
"""
artifacts = []
now = datetime.utcnow()
for cls_ in sorted(entities, key=attrgetter("__name__")):
items = sorted(entities[cls_], key=attrgetter("id"))
items = sorted(entities[cls_], key=attrgetter("name", "id"))
if not items:
continue
base_filename = cls._get_base_filename(cls_)
@@ -735,6 +746,90 @@ class PrePopulate:
module = importlib.import_module(module_name)
return getattr(module, class_name)
@staticmethod
def _upgrade_model_data(model_data: dict) -> dict:
metadata_key = "metadata"
metadata = model_data.get(metadata_key)
if isinstance(metadata, list):
metadata = {
ParameterKeyEscaper.escape(item["key"]): item
for item in metadata
if isinstance(item, dict) and "key" in item
}
model_data[metadata_key] = metadata
return model_data
@staticmethod
def _upgrade_task_data(task_data: dict) -> dict:
"""
Migrate from execution/parameters and model_desc to hyperparams and configuration fiields
Upgrade artifacts list to dict
Migrate from execution.model and output.model to the new models field
Move docker_cmd contents into the container field
:param task_data: Upgraded in place
:return: The upgraded task data
"""
for old_param_field, new_param_field, default_section in (
("execution.parameters", "hyperparams", hyperparams_default_section),
("execution.model_desc", "configuration", None),
):
legacy_path = old_param_field.split(".")
legacy = nested_get(task_data, legacy_path)
if legacy:
for full_name, value in legacy.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_param_field, section, name)))
if not nested_get(task_data, new_path):
new_param = dict(
name=name, type=hyperparams_legacy_type, value=str(value)
)
if section is not None:
new_param["section"] = section
nested_set(task_data, path=new_path, value=new_param)
nested_delete(task_data, legacy_path)
artifacts_path = ("execution", "artifacts")
artifacts = nested_get(task_data, artifacts_path)
if isinstance(artifacts, list):
nested_set(
task_data,
path=artifacts_path,
value={get_artifact_id(a): a for a in artifacts},
)
models = task_data.get("models", {})
now = datetime.utcnow()
for old_field, type_ in (
("execution.model", TaskModelTypes.input),
("output.model", TaskModelTypes.output),
):
old_path = old_field.split(".")
old_model = nested_get(task_data, old_path)
new_models = models.get(type_, [])
name = TaskModelNames[type_]
if old_model and not any(
m
for m in new_models
if m.get("model") == old_model or m.get("name") == name
):
model_item = {"model": old_model, "name": name, "updated": now}
if type_ == TaskModelTypes.input:
new_models = [model_item, *new_models]
else:
new_models = [*new_models, model_item]
models[type_] = new_models
nested_delete(task_data, old_path)
task_data["models"] = models
docker_cmd_path = ("execution", "docker_cmd")
docker_cmd = nested_get(task_data, docker_cmd_path)
if docker_cmd and not task_data.get("container"):
image, _, arguments = docker_cmd.partition(" ")
task_data["container"] = {"image": image, "arguments": arguments}
nested_delete(task_data, docker_cmd_path)
return task_data
@classmethod
def _import_entity(
cls,
@@ -748,18 +843,14 @@ class PrePopulate:
print(f"Writing {cls_.__name__.lower()}s into database")
tasks = []
override_project_count = 0
data_upgrade_funcs: Mapping[Type, Callable] = {
cls.task_cls: cls._upgrade_task_data,
cls.model_cls: cls._upgrade_model_data,
}
for item in cls.json_lines(f):
if cls_ == cls.task_cls:
task_data = json.loads(item)
artifacts_path = ("execution", "artifacts")
artifacts = nested_get(task_data, artifacts_path)
if isinstance(artifacts, list):
nested_set(
task_data,
artifacts_path,
value={get_artifact_id(a): a for a in artifacts},
)
item = json.dumps(task_data)
upgrade_func = data_upgrade_funcs.get(cls_)
if upgrade_func:
item = json.dumps(upgrade_func(json.loads(item)))
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):

View File

@@ -5,8 +5,7 @@ from pymongo.database import Database, Collection
def migrate_auth(db: Database):
collection: Collection = db["user"]
if "name_1_company_1" in [doc["name"] for doc in collection.list_indexes()]:
collection.drop_index("name_1_company_1")
collection.drop_indexes()
def migrate_backend(db: Database):

View File

@@ -31,8 +31,8 @@ def migrate_auth(db: Database):
if not uuids:
return
collection = db["user"]
collection.drop_index("name_1_company_1")
collection: Collection = db["user"]
collection.drop_indexes()
_switch_uuid(collection=collection, uuid_field="_id", uuids=uuids)

View File

@@ -1,15 +1,6 @@
from collections import Collection
from typing import Sequence
from pymongo.database import Database
from pymongo.database import Database, Collection
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
for collection_name in db.list_collection_names():
if collection_name not in names:
continue
collection: Collection = db[collection_name]
collection.drop_indexes()
from .utils import _drop_all_indices_from_collections
def migrate_auth(db: Database):

View File

@@ -0,0 +1,128 @@
import os
import re
from datetime import datetime
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.errors import DuplicateKeyError
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
from apiserver.services.utils import escape_dict
from apiserver.utilities.dicts import nested_get
from .utils import _drop_all_indices_from_collections
def _migrate_task_models(db: Database):
"""
Move the execution and output models to new models.input and output lists
"""
tasks: Collection = db["task"]
models_field = "models"
now = datetime.utcnow()
fields = {
TaskModelTypes.input: "execution.model",
TaskModelTypes.output: "output.model",
}
query = {"$or": [{field: {"$exists": True}} for field in fields.values()]}
for doc in tasks.find(filter=query, projection=[*fields.values(), models_field]):
set_commands = {}
for mode, field in fields.items():
value = nested_get(doc, field.split("."))
if value:
name = TaskModelNames[mode]
model_item = {"model": value, "name": name, "updated": now}
existing_models = nested_get(doc, (models_field, mode), default=[])
existing_models = (
m
for m in existing_models
if m.get("name") != name and m.get("model") != value
)
if mode == TaskModelTypes.input:
updated_models = [model_item, *existing_models]
else:
updated_models = [*existing_models, model_item]
set_commands[f"{models_field}.{mode}"] = updated_models
tasks.update_one(
{"_id": doc["_id"]},
{
"$unset": {field: 1 for field in fields.values()},
**({"$set": set_commands} if set_commands else {}),
},
)
def _migrate_docker_cmd(db: Database):
tasks: Collection = db["task"]
docker_cmd_field = "execution.docker_cmd"
query = {docker_cmd_field: {"$exists": True}}
for doc in tasks.find(filter=query, projection=(docker_cmd_field,)):
set_commands = {}
docker_cmd = nested_get(doc, docker_cmd_field.split("."))
if docker_cmd:
image, _, arguments = docker_cmd.partition(" ")
set_commands["container"] = {"image": image, "arguments": arguments}
tasks.update_one(
{"_id": doc["_id"]},
{
"$unset": {docker_cmd_field: 1},
**({"$set": set_commands} if set_commands else {}),
},
)
def _migrate_model_labels(db: Database):
tasks: Collection = db["task"]
fields = ("execution.model_labels", "container")
query = {"$or": [{field: {"$nin": [None, {}]}} for field in fields]}
for doc in tasks.find(filter=query, projection=fields):
set_commands = {}
for field in fields:
data = nested_get(doc, field.split("."))
if not data:
continue
escaped = escape_dict(data)
if data == escaped:
continue
set_commands[field] = escaped
if set_commands:
tasks.update_one({"_id": doc["_id"]}, {"$set": set_commands})
def _migrate_project_names(db: Database):
projects: Collection = db["project"]
regx = re.compile("/", re.IGNORECASE)
for doc in projects.find(filter={"name": regx, "path": {"$in": [None, []]}}):
name = doc.get("name")
if not name:
continue
max_tries = int(os.getenv("CLEARML_MIGRATION_PROJECT_RENAME_MAX_TRIES", 10))
iteration = 0
for iteration in range(max_tries):
new_name = name.replace("/", "_" * (iteration + 1))
try:
projects.update_one({"_id": doc["_id"]}, {"$set": {"name": new_name}})
break
except DuplicateKeyError:
pass
if iteration >= max_tries - 1:
print(f"Could not upgrade the name {name} of the project {doc.get('_id')}")
def migrate_backend(db: Database):
_migrate_task_models(db)
_migrate_docker_cmd(db)
_migrate_model_labels(db)
_migrate_project_names(db)
_drop_all_indices_from_collections(db, ["task*"])

View File

@@ -0,0 +1,22 @@
from pymongo.collection import Collection
from pymongo.database import Database
def _migrate_project_description(db: Database):
projects: Collection = db["project"]
filter = {
"$or": [
{
"$expr": {"$lt": [{"$strLenCP": "$description"}, 100]},
"description": {"$regex": "^Auto-generated at ", "$options": "i"},
},
{"description": {"$regex": "^Auto-generated during move$", "$options": "i"}},
{"description": {"$regex": "^Auto-generated while cloning$", "$options": "i"}},
]
}
for doc in projects.find(filter=filter):
projects.update_one({"_id": doc["_id"]}, {"$unset": {"description": 1}})
def migrate_backend(db: Database):
_migrate_project_description(db)

View File

@@ -0,0 +1,29 @@
from pymongo.collection import Collection
from pymongo.database import Database
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .utils import _drop_all_indices_from_collections
def _convert_metadata(db: Database, name):
collection: Collection = db[name]
metadata_field = "metadata"
query = {metadata_field: {"$exists": True, "$type": 4}}
for doc in collection.find(filter=query, projection=(metadata_field,)):
metadata = {
ParameterKeyEscaper.escape(item["key"]): item
for item in doc.get(metadata_field, [])
if isinstance(item, dict) and "key" in item
}
collection.update_one(
{"_id": doc["_id"]}, {"$set": {"metadata": metadata}},
)
def migrate_backend(db: Database):
collections = ["model", "queue"]
for name in collections:
_convert_metadata(db, name)
_drop_all_indices_from_collections(db, collections)

View File

@@ -0,0 +1,20 @@
from typing import Sequence
from boltons.iterutils import partition
from pymongo.database import Database, Collection
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
"""
Drop all indices for the existing collections from the specified list
"""
prefixes, names = partition(names, key=lambda x: x.endswith("*"))
prefixes = {p.rstrip("*") for p in prefixes}
for collection_name in db.list_collection_names():
if not (
collection_name in names
or any(p for p in prefixes if collection_name.startswith(p))
):
continue
collection: Collection = db[collection_name]
collection.drop_indexes()

View File

@@ -1,18 +1,29 @@
import threading
from os import getenv
from time import sleep
from boltons.iterutils import first
from redis import StrictRedis
from redis.sentinel import Sentinel, SentinelConnectionPool
from rediscluster import RedisCluster
from apiserver.apierrors.errors.server_error import ConfigError, GeneralError
from apiserver.config_repo import config
log = config.logger(__file__)
OVERRIDE_HOST_ENV_KEY = ("TRAINS_REDIS_SERVICE_HOST", "REDIS_SERVICE_HOST")
OVERRIDE_PORT_ENV_KEY = ("TRAINS_REDIS_SERVICE_PORT", "REDIS_SERVICE_PORT")
OVERRIDE_HOST_ENV_KEY = (
"CLEARML_REDIS_SERVICE_HOST",
"TRAINS_REDIS_SERVICE_HOST",
"REDIS_SERVICE_HOST",
)
OVERRIDE_PORT_ENV_KEY = (
"CLEARML_REDIS_SERVICE_PORT",
"TRAINS_REDIS_SERVICE_PORT",
"REDIS_SERVICE_PORT",
)
OVERRIDE_PASSWORD_ENV_KEY = (
"CLEARML_REDIS_SERVICE_PASSWORD",
"TRAINS_REDIS_SERVICE_PASSWORD",
"REDIS_SERVICE_PASSWORD",
)
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
if OVERRIDE_HOST:
@@ -22,99 +33,7 @@ OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
if OVERRIDE_PORT:
log.info(f"Using override redis port {OVERRIDE_PORT}")
class MyPubSubWorkerThread(threading.Thread):
def __init__(self, sentinel, on_new_master, msg_sleep_time, daemon=True):
super(MyPubSubWorkerThread, self).__init__()
self.daemon = daemon
self.sentinel = sentinel
self.on_new_master = on_new_master
self.sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
self.msg_sleep_time = msg_sleep_time
self._running = False
self.pubsub = None
def subscribe(self):
if self.pubsub:
try:
self.pubsub.unsubscribe()
self.pubsub.punsubscribe()
except Exception:
pass
finally:
self.pubsub = None
subscriptions = {"+switch-master": self.on_new_master}
while not self.pubsub or not self.pubsub.subscribed:
try:
self.pubsub = self.sentinel.pubsub()
self.pubsub.subscribe(**subscriptions)
except Exception as ex:
log.warn(
f"Error while subscribing to sentinel at {self.sentinel_host} ({ex.args[0]}) Sleeping and retrying"
)
sleep(3)
log.info(f"Subscribed to sentinel {self.sentinel_host}")
def run(self):
if self._running:
return
self._running = True
self.subscribe()
while self.pubsub.subscribed:
try:
self.pubsub.get_message(
ignore_subscribe_messages=True, timeout=self.msg_sleep_time
)
except Exception as ex:
log.warn(
f"Error while getting message from sentinel {self.sentinel_host} ({ex.args[0]}) Resubscribing"
)
self.subscribe()
self.pubsub.close()
self._running = False
def stop(self):
# stopping simply unsubscribes from all channels and patterns.
# the unsubscribe responses that are generated will short circuit
# the loop in run(), calling pubsub.close() to clean up the connection
self.pubsub.unsubscribe()
self.pubsub.punsubscribe()
# todo,future - multi master clusters?
class RedisCluster(object):
def __init__(self, sentinel_hosts, service_name, **connection_kwargs):
self.service_name = service_name
self.sentinel = Sentinel(sentinel_hosts, **connection_kwargs)
self.master = None
self.master_host_port = None
self.reconfigure()
self.sentinel_threads = {}
self.listen()
def reconfigure(self):
try:
self.master_host_port = self.sentinel.discover_master(self.service_name)
self.master = self.sentinel.master_for(self.service_name)
log.info(f"Reconfigured master to {self.master_host_port}")
except Exception as ex:
log.error(f"Error while reconfiguring. {ex.args[0]}")
def listen(self):
def on_new_master(workerThread):
self.reconfigure()
for sentinel in self.sentinel.sentinels:
sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
self.sentinel_threads[sentinel_host] = MyPubSubWorkerThread(
sentinel, on_new_master, msg_sleep_time=0.001, daemon=True
)
self.sentinel_threads[sentinel_host].start()
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
class RedisManager(object):
@@ -123,6 +42,9 @@ class RedisManager(object):
for alias, alias_config in redis_config_dict.items():
alias_config = alias_config.as_plain_ordered_dict()
alias_config["password"] = config.get(
f"secure.redis.{alias}.password", None
)
is_cluster = alias_config.get("cluster", False)
@@ -134,34 +56,19 @@ class RedisManager(object):
if port:
alias_config["port"] = port
db = alias_config.get("db", 0)
password = OVERRIDE_PASSWORD or alias_config.get("password", None)
if password:
alias_config["password"] = password
sentinels = alias_config.get("sentinels", None)
service_name = alias_config.get("service_name", None)
if not is_cluster and sentinels:
raise ConfigError(
"Redis configuration is invalid. mixed regular and cluster mode",
alias=alias,
)
if is_cluster and (not sentinels or not service_name):
raise ConfigError(
"Redis configuration is invalid. missing sentinels or service_name",
alias=alias,
)
if not is_cluster and (not port or not host):
if not port or not host:
raise ConfigError(
"Redis configuration is invalid. missing port or host", alias=alias
)
if is_cluster:
# todo support all redis connection args via sentinel's connection_kwargs
del alias_config["sentinels"]
del alias_config["cluster"]
del alias_config["service_name"]
self.aliases[alias] = RedisCluster(
sentinels, service_name, **alias_config
)
del alias_config["db"]
self.aliases[alias] = RedisCluster(**alias_config)
else:
self.aliases[alias] = StrictRedis(**alias_config)
@@ -169,27 +76,21 @@ class RedisManager(object):
obj = self.aliases.get(alias)
if not obj:
raise GeneralError(f"Invalid Redis alias {alias}")
if isinstance(obj, RedisCluster):
obj.master.get("health")
return obj.master
else:
obj.get("health")
return obj
obj.get("health")
return obj
def host(self, alias):
r = self.connection(alias)
pool = r.connection_pool
if isinstance(pool, SentinelConnectionPool):
connections = pool.connection_kwargs[
"connection_pool"
]._available_connections
if isinstance(r, RedisCluster):
connections = first(r.connection_pool._available_connections.values())
else:
connections = pool._available_connections
connections = r.connection_pool._available_connections
if len(connections) > 0:
return connections[0].host
else:
if not connections:
return None
return connections[0].host
redman = RedisManager(config.get("hosts.redis"))

View File

@@ -1,8 +1,9 @@
attrs>=19.1.0
bcrypt>=3.1.4
boltons>=19.1.0
boto3==1.14.13
dpath>=1.4.2,<2.0
elasticsearch>=7.0.0,<8.0.0
elasticsearch==7.13.3
fastjsonschema>=2.8
flask-compress>=1.4.0
flask-cors>=3.0.5
@@ -11,21 +12,23 @@ funcsigs==1.0.2
furl>=2.0.0
gunicorn>=19.7.1
humanfriendly==4.18
jinja2==2.10
jinja2==2.11.3
jsonmodels>=2.3
jsonschema>=2.6.0
luqum>=0.10.0
mongoengine==0.19.1
mongoengine==0.23.1
nested_dict>=1.61
packaging==20.3
psutil>=5.6.5
pyhocon>=0.3.35
pyjwt<2.0.0
pymongo==3.10.1
pymongo[srv]==3.12.0
python-rapidjson>=0.6.3
redis>=2.10.5
redis==3.5.3
redis-py-cluster>=2.1.3
related>=0.7.2
requests>=2.13.0
semantic_version>=2.8.3,<3
six
tqdm
validators>=0.12.4
validators>=0.12.4

View File

@@ -1,3 +1,20 @@
metadata_item {
type: object
properties {
key {
type: string
description: The key uniquely identifying the metadata item inside the given entity
}
type {
type: string
description: The type of the metadata item
}
value {
type: string
description: The value stored in the metadata item
}
}
}
credentials {
type: object
properties {
@@ -9,5 +26,67 @@ credentials {
type: string
description: Credentials secret key
}
label {
type: string
description: Optional credentials label
}
}
}
batch_operation {
request {
type: object
required: [ids]
properties {
ids {
type: array
items {type: string}
}
}
}
response {
type: object
properties {
succeeded {
type: array
items {
type: object
properties {
id: {
description: ID of the succeeded entity
type: string
}
}
}
}
failed {
type: array
items {
type: object
properties {
id: {
description: ID of the failed entity
type: string
}
error: {
description: Error info
type: object
properties {
codes {
type: array
items {type: integer}
}
msg {
type: string
}
data {
type: object
additionalProperties: True
}
}
}
}
}
}
}
}
}

View File

@@ -15,11 +15,19 @@ _definitions {
type: string
description: ""
}
label {
type: string
description: Optional credentials label
}
last_used {
type: string
description: ""
format: "date-time"
}
last_used_from {
type: string
description: ""
}
}
}
role {
@@ -222,6 +230,12 @@ create_credentials {
}
}
}
"2.17": ${create_credentials."2.1"} {
request.properties.label {
type: string
description: Optional credentials label
}
}
}
get_credentials {

File diff suppressed because it is too large Load Diff

View File

@@ -85,7 +85,27 @@ supported_modes {
}
}
}
authenticated {
description: "Is user authenticated"
type: boolean
}
}
}
}
}
logout {
authorize: false
allow_roles = [ "*" ]
"2.13" {
description: """ Logout (including SSO, if used)) """
request {
type: object
additionalProperties: false
}
response {
type: object
additionalProperties: false
}
}
}

View File

@@ -1,5 +1,6 @@
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
_definitions {
include "_common.conf"
multi_field_pattern_data {
type: object
properties {
@@ -38,6 +39,11 @@ _definitions {
type: string
format: "date-time"
}
last_update {
description: "Model last update time"
type: string
format: "date-time"
}
task {
description: "Task ID of task in which the model was created"
type: string
@@ -55,14 +61,14 @@ _definitions {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
description: "System tags. This field is reserved for system use, please don't use it."
items { type: string }
}
framework {
description: "Framework on which the model is based. Should be identical to the framework of the task which created the model"
@@ -91,6 +97,39 @@ _definitions {
type: object
additionalProperties: true
}
metadata {
description: "Model metadata"
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
published_task_item {
description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing."
type: object
properties {
id {
description: "Task id"
type: string
}
data {
description: "Data returned from the task publishing operation."
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
}
}
@@ -151,6 +190,40 @@ get_by_id_ex {
get_all_ex {
internal: true
"2.1": ${get_all."2.1"}
"2.13": ${get_all_ex."2.1"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and project field is set then models from the subprojects are searched too"
type: boolean
default: false
}
}
}
}
"2.15": ${get_all_ex."2.13"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of models to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
get_all {
"2.1" {
@@ -254,6 +327,29 @@ get_all {
}
}
}
"2.15": ${get_all."2.1"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of models to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
}
get_frameworks {
"2.8" {
@@ -313,7 +409,7 @@ update_for_task {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
override_model_id {
description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required."
@@ -379,7 +475,7 @@ create {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
framework {
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
@@ -433,6 +529,15 @@ create {
}
}
}
"2.13": ${create."2.1"} {
metadata {
description: "Model metadata"
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
edit {
"2.1" {
@@ -467,7 +572,7 @@ edit {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
framework {
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
@@ -521,6 +626,15 @@ edit {
}
}
}
"2.13": ${edit."2.1"} {
metadata {
description: "Model metadata"
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
update {
"2.1" {
@@ -549,7 +663,7 @@ update {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
ready {
description: "Indication if the model is final and can be used by other tasks Default is false."
@@ -597,6 +711,42 @@ update {
}
}
}
"2.13": ${update."2.1"} {
metadata {
description: "Model metadata"
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
publish_many {
"2.13": ${_definitions.batch_operation} {
description: Publish models
request {
properties {
ids.description: "IDs of the models to publish"
force_publish_task {
description: "Publish the associated tasks (if exist) even if they are not in the 'stopped' state. Optional, the default value is False."
type: boolean
}
publish_tasks {
description: "Indicates that the associated tasks (if exist) should be published. Optional, the default value is True."
type: boolean
}
}
}
response {
properties {
succeeded.items.properties.updated {
description: "Indicates whether the model was updated"
type: boolean
}
succeeded.items.properties.published_task: ${_definitions.published_task_item}
}
}
}
}
set_ready {
"2.1" {
@@ -627,39 +777,68 @@ set_ready {
type: integer
enum: [0, 1]
}
published_task {
description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing."
type: object
properties {
id {
description: "Task id"
type: string
}
data {
description: "Data returned from the task publishing operation."
type: object
properties {
committed_versions_results {
description: "Committed versions results"
type: array
items {
type: object
additionalProperties: true
}
}
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
published_task: ${_definitions.published_task_item}
}
}
}
}
archive_many {
"2.13": ${_definitions.batch_operation} {
description: Archive models
request {
properties {
ids.description: "IDs of the models to archive"
}
}
response {
properties {
succeeded.items.properties.archived {
description: "Indicates whether the model was archived"
type: boolean
}
}
}
}
}
unarchive_many {
"2.13": ${_definitions.batch_operation} {
description: Unarchive models
request {
properties {
ids.description: "IDs of the models to unarchive"
}
}
response {
properties {
succeeded.items.properties.unarchived {
description: "Indicates whether the model was unarchived"
type: boolean
}
}
}
}
}
delete_many {
"2.13": ${_definitions.batch_operation} {
description: Delete models
request {
properties {
ids.description: "IDs of the models to delete"
force {
description: "Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published."
type: boolean
}
}
}
response {
properties {
succeeded.items.properties.deleted {
description: "Indicates whether the model was deleted"
type: boolean
}
succeeded.items.properties.url {
description: "The url of the model file"
type: string
}
}
}
@@ -697,6 +876,16 @@ delete {
}
}
}
"2.13": ${delete."2.1"} {
response {
properties {
url {
description: "The url of the model file"
type: string
}
}
}
}
}
make_public {
@@ -777,4 +966,68 @@ move {
}
}
}
add_or_update_metadata {
"2.13" {
description: "Add or update model metadata"
request {
type: object
required: [model, metadata]
properties {
model {
description: "ID of the model"
type: string
}
metadata {
type: array
description: "Metadata items to add or update"
items {"$ref": "#/definitions/metadata_item"}
}
replace_metadata {
description: "If set then the all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
type: boolean
default: false
}
}
}
response {
type: object
properties {
updated {
description: "Number of models updated (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}
delete_metadata {
"2.13" {
description: "Delete metadata from model"
request {
type: object
required: [ model, keys ]
properties {
model {
description: "ID of the model"
type: string
}
keys {
description: "The list of metadata keys to delete"
type: array
items {type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of models updated (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}

View File

@@ -0,0 +1,47 @@
_description: "Provides a management API for pipelines in the system."
_definitions {
}
start_pipeline {
"2.17" {
description: "Start a pipeline"
request {
type: object
required: [ task ]
properties {
task {
description: "ID of the task on which the pipeline will be based"
type: string
}
queue {
description: "Queue ID in which the created pipeline task will be enqueued"
type: string
}
args {
description: "Task arguments, name/value to be placed in the hyperparameters Args section"
type: array
items {
type: object
properties {
name: { type: string }
value: { type: [string, null] }
}
}
}
}
}
response {
type: object
properties {
pipeline {
description: "ID of the new pipeline task"
type: string
}
enqueued {
description: "True if the task was successfuly enqueued"
type: boolean
}
}
}
}
}

View File

@@ -42,15 +42,20 @@ _definitions {
type: string
format: "date-time"
}
last_update {
description: "Last update time"
type: string
format: "date-time"
}
tags {
type: array
description: "User-defined tags"
type: array
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
type: array
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@@ -70,6 +75,18 @@ _definitions {
description: "Total run time of all tasks in project (in seconds)"
type: integer
}
total_tasks {
description: "Number of tasks"
type: integer
}
completed_tasks_24h {
description: "Number of tasks completed in the last 24 hours"
type: integer
}
last_task_run {
description: "The most recent started time of a task"
type: integer
}
status_count {
description: "Status counts"
type: object
@@ -78,6 +95,10 @@ _definitions {
description: "Number of 'created' tasks in project"
type: integer
}
completed {
description: "Number of 'completed' tasks in project"
type: integer
}
queued {
description: "Number of 'queued' tasks in project"
type: integer
@@ -152,25 +173,47 @@ _definitions {
type: string
format: "date-time"
}
last_update {
description: "Last update time"
type: string
format: "date-time"
}
tags {
type: array
description: "User-defined tags"
type: array
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
type: array
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
type: string
}
// extra properties
stats: {
stats {
description: "Additional project stats"
"$ref": "#/definitions/stats"
}
sub_projects {
description: "The list of sub projects"
type: array
items {
type: object
properties {
id {
description: "Subproject ID"
type: string
}
name {
description: "Subproject name"
type: string
}
}
}
}
}
}
metric_variant_result {
@@ -242,6 +285,23 @@ _definitions {
}
}
}
urls {
type: object
properties {
model_urls {
type: array
items {type: string}
}
event_urls {
type: array
items {type: string}
}
artifact_urls {
type: array
items {type: string}
}
}
}
}
create {
@@ -249,28 +309,25 @@ create {
description: "Create a new project"
request {
type: object
required :[
name
description
]
required :[name]
properties {
name {
description: "Project name Unique within the company."
type: string
}
description {
description: "Project description. "
description: "Project description."
type: string
}
tags {
type: array
description: "User-defined tags"
type: array
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
type: array
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@@ -348,7 +405,7 @@ get_all {
items { type: string }
}
page {
description: "Page number, returns a specific page out of the resulting list of dataviews"
description: "Page number, returns a specific page out of the resulting list of projects"
type: integer
minimum: 0
}
@@ -383,11 +440,51 @@ get_all {
description: "Projects list"
type: array
items { "$ref": "#/definitions/projects_get_all_response_single" }
}
}
}
}
"2.13": ${get_all."2.1"} {
request {
properties {
shallow_search {
description: "If set to 'true' then the search with the specified criteria is performed among top level projects only (or if parents specified, among the direct children of the these parents). Otherwise the search is performed among all the company projects (or among all of the descendants of the specified parents)."
type: boolean
default: false
}
}
}
}
"2.14": ${get_all."2.13"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden projects are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all."2.14"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of projects to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
get_all_ex {
internal: true
@@ -413,6 +510,89 @@ get_all_ex {
}
}
}
"2.13": ${get_all_ex."2.1"} {
request {
properties {
active_users {
descritpion: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
type: array
items: {type: string}
}
shallow_search {
description: "If set to 'true' then the search with the specified criteria is performed among top level projects only (or if parents specified, among the direct children of the these parents). Otherwise the search is performed among all the company projects (or among all of the descendants of the specified parents)."
type: boolean
default: false
}
check_own_contents {
description: "If set to 'true' and project ids are passed to the query then for these projects their own tasks and models are counted"
type: boolean
default: false
}
}
}
response {
properties {
own_tasks {
description: "The amount of tasks under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
}
own_models {
description: "The amount of models under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
}
}
}
}
"2.14": ${get_all_ex."2.13"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden projects are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all_ex."2.14"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of projects to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
"2.16": ${get_all_ex."2.15"} {
request.properties.stats_with_children {
description: "If include_stats flag is set then this flag contols whether the child projects tasks are taken into statistics or not"
type: boolean
default: true
}
}
"2.17": ${get_all_ex."2.16"} {
request.properties.include_stats_filter {
description: The filter for selecting entities that participate in statistics calculation
type: object
properties {
system_tags {
description: The list of allowed system tags
type: array
items { type: string }
}
}
}
}
}
update {
"2.1" {
@@ -429,23 +609,19 @@ update {
description: "Project name. Unique within the company."
type: string
}
description {
description: "Project description. "
type: string
}
description {
description: "Project description"
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@@ -470,6 +646,102 @@ update {
}
}
}
move {
"2.13" {
description: "Moves a project and all of its subprojects under the different location"
request {
type: object
required: [project]
properties {
project {
description: "Project id"
type: string
}
new_location {
description: "The name location for the project"
type: string
}
}
}
response {
type: object
properties {
moved {
description: "The number of projects moved"
type: integer
}
}
}
}
}
merge {
"2.13" {
description: "Moves all the source project's contents to the destination project and remove the source project"
request {
type: object
required: [project]
properties {
project {
description: "Project id"
type: string
}
destination_project {
description: "The ID of the destination project"
type: string
}
}
}
response {
type: object
properties {
moved_entities {
description: "The number of tasks and models moved from the merged project into the destination"
type: integer
}
moved_projects {
description: "The number of child projects moved from the merged project into the destination"
type: integer
}
}
}
}
}
validate_delete {
"2.14" {
description: "Validates that the project existis and can be deleted"
request {
type: object
required: [ project ]
properties {
project {
description: "Project ID"
type: string
}
}
}
response {
type: object
properties {
tasks {
description: "The total number of tasks under the project and all its children"
type: integer
}
non_archived_tasks {
description: "The total number of non-archived tasks under the project and all its children"
type: integer
}
models {
description: "The total number of models under the project and all its children"
type: integer
}
non_archived_models {
description: "The total number of non-archived models under the project and all its children"
type: integer
}
}
}
}
}
delete {
"2.1" {
description: "Deletes a project"
@@ -478,7 +750,7 @@ delete {
required: [ project ]
properties {
project {
description: "Project id"
description: "Project ID"
type: string
}
force {
@@ -487,7 +759,6 @@ delete {
type: boolean
default: false
}
}
}
response {
@@ -504,6 +775,32 @@ delete {
}
}
}
"2.13": ${delete."2.1"} {
request {
properties {
delete_contents {
description: "If set to 'true' then the project tasks and models will be deleted. Otherwise their project property will be unassigned. Default value is 'false'"
type: boolean
}
}
}
response {
properties {
urls {
description: "The urls of the files that were uploaded by the project tasks and models. Returned if the 'delete_contents' was set to 'true'"
"$ref": "#/definitions/urls"
}
deleted_models {
description: "Number of models deleted"
type: integer
}
deleted_tasks {
description: "Number of tasks deleted"
type: integer
}
}
}
}
}
get_unique_metric_variants {
"2.1" {
@@ -530,12 +827,71 @@ get_unique_metric_variants {
}
}
}
"2.13": ${get_unique_metric_variants."2.1"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metrics/variants from the subproject tasks"
type: boolean
default: true
}
}
}
}
}
get_hyperparam_values {
"2.13" {
description: """Get a list of distinct values for the chosen hyperparameter"""
request {
type: object
required: [section, name]
properties {
projects {
description: "Project IDs"
type: array
items {type: string}
}
section {
description: "Hyperparameter section name"
type: string
}
name {
description: "Hyperparameter name"
type: string
}
allow_public {
description: "If set to 'true' then collect values from both company and public tasks otherwise company tasks only. The default is 'true'"
type: boolean
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes hyper parameters values from the subproject tasks"
type: boolean
default: true
}
}
}
response {
type: object
properties {
total {
description: "Total number of distinct parameter values"
type: integer
}
values {
description: "The list of the unique values for the parameter"
type: array
items {type: string}
}
}
}
}
}
get_hyper_parameters {
"2.9" {
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
request {
type: object
required: [project]
properties {
project {
description: "Project ID"
@@ -572,8 +928,117 @@ get_hyper_parameters {
}
}
}
"2.13": ${get_hyper_parameters."2.9"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes hyper parameters from the subproject tasks"
type: boolean
default: true
}
}
}
}
}
get_model_metadata_values {
"2.17" {
description: """Get a list of distinct values for the chosen model metadata key"""
request {
type: object
required: [key]
properties {
projects {
description: "Project IDs"
type: array
items {type: string}
}
key {
description: "Metadata key"
type: string
}
allow_public {
description: "If set to 'true' then collect values from both company and public models otherwise company modeels only. The default is 'true'"
type: boolean
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metadata values from the subproject models"
type: boolean
default: true
}
}
}
response {
type: object
properties {
total {
description: "Total number of distinct values"
type: integer
}
values {
description: "The list of the unique values"
type: array
items {type: string}
}
}
}
}
}
get_model_metadata_keys {
"2.17" {
description: """Get a list of all metadata keys used in models within the given project."""
request {
type: object
required: [project]
properties {
project {
description: "Project ID"
type: string
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metadate keys from the subproject models"
type: boolean
default: true
}
page {
description: "Page number"
default: 0
type: integer
}
page_size {
description: "Page size"
default: 500
type: integer
}
}
}
response {
type: object
properties {
keys {
description: "A list of model keys"
type: array
items {type: string}
}
remaining {
description: "Remaining results"
type: integer
}
total {
description: "Total number of results"
type: integer
}
}
}
}
}
get_project_tags {
"2.17" {
description: "Get user and system tags used for the specified projects and their children"
request = ${_definitions.tags_request}
response = ${_definitions.tags_response}
}
}
get_task_tags {
"2.8" {
description: "Get user and system tags used for the tasks under the specified projects"
@@ -581,7 +1046,6 @@ get_task_tags {
response = ${_definitions.tags_response}
}
}
get_model_tags {
"2.8" {
description: "Get user and system tags used for the models under the specified projects"
@@ -641,7 +1105,7 @@ make_private {
}
get_task_parents {
"2.12" {
description: "Get unique parent tasks for the tasks in the specified pprojects"
description: "Get unique parent tasks for the tasks in the specified projects"
request {
type: object
properties {
@@ -692,4 +1156,15 @@ get_task_parents {
}
}
}
}
"2.13": ${get_task_parents."2.12"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and the projects field is not empty then the result includes tasks parents from the subproject tasks"
type: boolean
default: true
}
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -26,6 +26,49 @@ _references {
}
_definitions {
include "_common.conf"
change_many_request: ${_definitions.batch_operation} {
request {
properties {
status_reason {
description: Reason for status change
type: string
}
status_message {
description: Extra information regarding status change
type: string
}
}
}
response {
properties {
succeeded.items.properties.updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
succeeded.items.properties.fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
update_response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
multi_field_pattern_data {
type: object
properties {
@@ -40,6 +83,24 @@ _definitions {
}
}
}
model_type_enum {
type: string
enum: ["input", "output"]
}
task_model_item {
type: object
required: [ name, model]
properties {
name {
description: "The task model name"
type: string
}
model {
description: "The model ID"
type: string
}
}
}
script {
type: object
properties {
@@ -207,6 +268,22 @@ _definitions {
}
}
}
task_models {
type: object
properties {
input {
description: "The list of task input models"
type: array
items {"$ref": "#/definitions/task_model_item"}
}
output {
description: "The list of task output models"
type: array
items {"$ref": "#/definitions/task_model_item"}
}
}
}
execution {
type: object
properties {
@@ -454,6 +531,15 @@ _definitions {
description: "Task execution params"
"$ref": "#/definitions/execution"
}
container {
description: "Docker container parameters"
type: object
additionalProperties { type: [string, null] }
}
models {
description: "Task models"
"$ref": "#/definitions/task_models"
}
// TODO: will be removed
script {
description: "Script info"
@@ -531,6 +617,28 @@ _definitions {
"$ref": "#/definitions/configuration_item"
}
}
runtime {
description: "Task runtime mapping"
type: object
additionalProperties: true
}
}
}
task_urls {
type: object
properties {
model_urls {
type: array
items {type: string}
}
event_urls {
type: array
items {type: string}
}
artifact_urls {
type: array
items {type: string}
}
}
}
}
@@ -566,6 +674,47 @@ get_by_id_ex {
get_all_ex {
internal: true
"2.1": ${get_all."2.1"}
"2.13": ${get_all_ex."2.1"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and project field is set then tasks from the subprojects are searched too"
type: boolean
default: false
}
}
}
}
"2.14": ${get_all_ex."2.13"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden tasks are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all_ex."2.14"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of tasks to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
get_all {
"2.1" {
@@ -680,6 +829,36 @@ get_all {
}
}
}
"2.14": ${get_all."2.1"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden tasks are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all."2.14"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of tasks to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
}
get_types {
"2.8" {
@@ -805,6 +984,106 @@ clone {
}
}
}
"2.13": ${clone."2.12"}{
request {
properties {
new_task_input_models {
description: "The list of input models for the cloned task. If not specifed then copied from the original task"
type: array
items {"$ref": "#/definitions/task_model_item"}
}
new_task_container {
description: "The docker container properties for the new task. If not provided then taken from the original task"
type: object
additionalProperties { type: [string, null] }
}
}
}
}
}
add_or_update_model {
"2.13" {
description: "Add or update task model"
request {
type: object
required: [task, name, model, type]
properties {
task {
description: "ID of the task"
type: string
}
name {
description: "The task model name"
type: string
}
model {
description: "The model ID"
type: string
}
type {
description: "The task model type"
"$ref": "#/definitions/model_type_enum"
}
iteration {
description: "Iteration (used to update task statistics)"
type: integer
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}
delete_models {
"2.13" {
description: "Delete models from task"
request {
type: object
required: [ task, models ]
properties {
task {
description: "ID of the task"
type: string
}
models {
description: "The list of models to delete"
type: array
items {
type: object
required: [name, type]
properties {
name {
description: "The task model name"
type: string
}
type {
description: "The task model type"
"$ref": "#/definitions/model_type_enum"
}
}
}
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}
create {
"2.1" {
@@ -884,6 +1163,21 @@ create {
}
}
}
"2.13": ${create."2.1"} {
request {
properties {
models {
description: "Task models"
"$ref": "#/definitions/task_models"
}
container {
description: "Docker container parameters"
type: object
additionalProperties { type: [string, null] }
}
}
}
}
}
validate {
"2.1" {
@@ -958,6 +1252,21 @@ validate {
additionalProperties: false
}
}
"2.13": ${validate."2.1"} {
request {
properties {
models {
description: "Task models"
"$ref": "#/definitions/task_models"
}
container {
description: "Docker container parameters"
type: object
additionalProperties { type: [string, null] }
}
}
}
}
}
update {
"2.1" {
@@ -1005,21 +1314,7 @@ update {
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
response: ${_definitions.update_response}
}
}
update_batch {
@@ -1117,16 +1412,22 @@ edit {
}
}
}
response {
type: object
response: ${_definitions.update_response}
}
"2.13": ${edit."2.1"} {
request {
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
models {
description: "Task models"
"$ref": "#/definitions/task_models"
}
fields {
description: "Updated fields names and values"
container {
description: "Docker container parameters"
type: object
additionalProperties { type: [string, null] }
}
runtime {
description: "Task runtime mapping"
type: object
additionalProperties: true
}
@@ -1151,24 +1452,13 @@ reset {
default: false
}
} ${_references.status_change_request}
response {
type: object
response: ${_definitions.update_response} {
properties {
deleted_indices {
description: "List of deleted ES indices that were removed as part of the reset process"
type: array
items { type: string }
}
dequeued {
description: "Response from queues.remove_task"
type: object
additionalProperties: true
}
frames {
description: "Response from frames.rollback"
type: object
additionalProperties: true
}
events {
description: "Response from events.delete_for_task"
type: object
@@ -1178,16 +1468,126 @@ reset {
description: "Number of output models deleted by the reset"
type: integer
}
updated {
}
}
}
"2.13": ${reset."2.1"} {
request {
properties {
return_file_urls {
description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'"
type: boolean
}
}
}
response {
properties {
urls {
description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' was set to 'true'"
"$ref": "#/definitions/task_urls"
}
}
}
}
}
reset_many {
"2.13": ${_definitions.batch_operation} {
description: Reset tasks
request {
properties {
ids.description: "IDs of the tasks to reset"
force = ${_references.force_arg} {
description: "If not true, call fails if the task status is 'completed'"
}
clear_all {
description: "Clear script and execution sections completely"
type: boolean
default: false
}
return_file_urls {
description: "If set to 'true' then return the urls of the files that were uploaded by the tasks. Default value is 'false'"
type: boolean
}
delete_output_models {
description: "If set to 'true' then delete output models of the tasks that are not referenced by other tasks. Default value is 'true'"
type: boolean
}
}
}
response {
properties {
succeeded.items.properties.dequeued {
description: "Indicates whether the task was dequeued"
type: boolean
}
succeeded.items.properties.updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
succeeded.items.properties.fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
succeeded.items.properties.deleted_models {
description: "Number of output models deleted by the reset"
type: integer
}
succeeded.items.properties.urls {
description: "The urls of the files that were uploaded by the task. Returned if the 'return_file_urls' was set to 'true'"
"$ref": "#/definitions/task_urls"
}
}
}
}
}
delete_many {
"2.13": ${_definitions.batch_operation} {
description: Delete tasks
request {
properties {
ids.description: "IDs of the tasks to delete"
move_to_trash {
description: "Move task to trash instead of deleting it. For internal use only, tasks in the trash are not visible from the API and cannot be restored!"
type: boolean
default: false
}
force = ${_references.force_arg} {
description: "If not true, call fails if the task status is 'in_progress'"
}
return_file_urls {
description: "If set to 'true' then return the urls of the files that were uploaded by the tasks. Default value is 'false'"
type: boolean
}
delete_output_models {
description: "If set to 'true' then delete output models of the tasks that are not referenced by other tasks. Default value is 'true'"
type: boolean
}
}
}
response {
properties {
succeeded.items.properties.deleted {
description: "Indicates whether the task was deleted"
type: boolean
}
succeeded.items.properties.updated_children {
description: "Number of child tasks whose parent property was updated"
type: integer
}
succeeded.items.properties.updated_models {
description: "Number of models whose task property was updated"
type: integer
}
succeeded.items.properties.deleted_models {
description: "Number of deleted output models"
type: integer
}
succeeded.items.properties.urls {
description: "The urls of the files that were uploaded by the task. Returned if the 'return_file_urls' was set to 'true'"
"$ref": "#/definitions/task_urls"
}
}
}
}
@@ -1229,15 +1629,6 @@ delete {
description: "Number of models whose task property was updated"
type: integer
}
updated_versions {
description: "Number of dataset versions whose task property was updated"
type: integer
}
frames {
description: "Response from frames.rollback"
type: object
additionalProperties: true
}
events {
description: "Response from events.delete_for_task"
type: object
@@ -1246,6 +1637,24 @@ delete {
}
}
}
"2.13": ${delete."2.1"} {
request {
properties {
return_file_urls {
description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'"
type: boolean
}
}
}
response {
properties {
urls {
description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' was set to 'true'"
"$ref": "#/definitions/task_urls"
}
}
}
}
}
archive {
"2.12" {
@@ -1284,6 +1693,58 @@ archive {
}
}
}
archive_many {
"2.13": ${_definitions.batch_operation} {
description: Archive tasks
request {
properties {
ids.description: "IDs of the tasks to archive"
status_reason {
description: Reason for status change
type: string
}
status_message {
description: Extra information regarding status change
type: string
}
}
response {
properties {
succeeded.items.properties.archived {
description: "Indicates whether the task was archived"
type: boolean
}
}
}
}
}
}
unarchive_many {
"2.13": ${_definitions.batch_operation} {
description: Unarchive tasks
request {
properties {
ids.description: "IDs of the tasks to unarchive"
status_reason {
description: Reason for status change
type: string
}
status_message {
description: Extra information regarding status change
type: string
}
}
}
response {
properties {
succeeded.items.properties.unarchived {
description: "Indicates whether the task was unarchived"
type: boolean
}
}
}
}
}
started {
"2.1" {
description: "Mark a task status as in_progress. Optionally allows to set the task's execution progress."
@@ -1296,24 +1757,13 @@ started {
description: "If not true, call fails if the task status is not 'not_started'"
}
} ${_references.status_change_request}
response {
type: object
response: ${_definitions.update_response} {
properties {
started {
description: "Number of tasks started (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
@@ -1330,18 +1780,17 @@ stop {
description: "If not true, call fails if the task status is not 'in_progress'"
}
} ${_references.status_change_request}
response {
type: object
response: ${_definitions.update_response}
}
}
stop_many {
"2.13": ${_definitions.change_many_request} {
description: "Request to stop running tasks"
request {
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
ids.description: "IDs of the tasks to stop"
force = ${_references.force_arg} {
description: "If not true, call fails if the task status is not 'in_progress'"
}
}
}
@@ -1359,21 +1808,7 @@ stopped {
description: "If not true, call fails if the task status is not 'stopped'"
}
} ${_references.status_change_request}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
response: ${_definitions.update_response}
}
}
failed {
@@ -1386,21 +1821,7 @@ failed {
]
properties.force = ${_references.force_arg}
} ${_references.status_change_request}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
response: ${_definitions.update_response}
}
}
close {
@@ -1413,21 +1834,7 @@ close {
]
properties.force = ${_references.force_arg}
} ${_references.status_change_request}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
response: ${_definitions.update_response}
}
}
publish {
@@ -1452,26 +1859,21 @@ publish {
}
}
} ${_references.status_change_request}
response {
type: object
response: ${_definitions.update_response}
}
}
publish_many {
"2.13": ${_definitions.change_many_request} {
description: Publish tasks
request {
properties {
committed_versions_results {
description: "Committed versions results"
type: array
items {
type: object
additionalProperties: true
}
ids.description: "IDs of the tasks to publish"
force = ${_references.force_arg} {
description: "If not true, call fails if the task status is not 'stopped'"
}
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
publish_model {
description: "Indicates that the task output model (if exists) should be published. Optional, the default value is True."
type: boolean
}
}
}
@@ -1502,23 +1904,39 @@ Fails if the following parameters in the task were not filled:
}
} ${_references.status_change_request}
response {
type: object
response: ${_definitions.update_response} {
properties {
queued {
description: "Number of tasks queued (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
}
}
}
enqueue_many {
"2.13": ${_definitions.change_many_request} {
description: Enqueue tasks
request {
properties {
ids.description: "IDs of the tasks to enqueue"
queue {
description: "Queue id. If not provided, tasks are added to the default queue."
type: string
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
validate_tasks {
description: "If set then tasks are validated before enqueue"
type: boolean
default: false
}
}
}
response {
properties {
succeeded.items.properties.queued {
description: "Indicates whether the task was queued"
type: boolean
}
}
}
@@ -1534,23 +1952,30 @@ dequeue {
task
]
} ${_references.status_change_request}
response {
type: object
response: ${_definitions.update_response} {
properties {
dequeued {
description: "Number of tasks dequeued (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
dequeue_many {
"2.13": ${_definitions.change_many_request} {
description: Dequeue tasks
request {
properties {
ids.description: "IDs of the tasks to dequeue"
}
}
response {
properties {
succeeded.items.properties.dequeued {
description: "Indicates whether the task was dequeued"
type: boolean
}
}
}
@@ -1576,21 +2001,7 @@ set_requirements {
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
response: ${_definitions.update_response}
}
}
@@ -1606,21 +2017,7 @@ completed {
description: "If not true, call fails if the task status is not in_progress/stopped"
}
} ${_references.status_change_request}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
response: ${_definitions.update_response}
}
}
@@ -1932,6 +2329,11 @@ get_configuration_names {
type: array
items { type: string }
}
skip_empty {
description: If set to 'true' then the names for configurations with missing values are not returned
type: boolean
default: true
}
}
}
response {

View File

@@ -4,7 +4,7 @@ from hashlib import md5
from flask import Flask
from flask_compress import Compress
from flask_cors import CORS
from semantic_version import Version
from packaging.version import Version
from apiserver.database import db
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
@@ -25,6 +25,7 @@ from apiserver.server_init.request_handlers import RequestHandlers
from apiserver.service_repo import ServiceRepo
from apiserver.sync import distributed_lock
from apiserver.updates import check_updates_thread
from apiserver.utilities.env import get_bool
from apiserver.utilities.threads_manager import ThreadsManager
log = config.logger(__file__)
@@ -46,10 +47,13 @@ class AppSequence:
def _attach_request_handlers(self, request_handlers: RequestHandlers):
self.app.before_first_request(request_handlers.before_app_first_request)
self.app.before_request(request_handlers.before_request)
self.app.after_request(request_handlers.after_request)
def _configure(self):
CORS(self.app, **config.get("apiserver.cors"))
Compress(self.app)
if get_bool("CLEARML_COMPRESS_RESP", default=True):
Compress(self.app)
self.app.config["SECRET_KEY"] = config.get(
"secure.http.session_secret.apiserver"

View File

@@ -1,18 +1,24 @@
from functools import partial
from flask import request, Response, redirect
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.exceptions import BadRequest
from apiserver.apierrors import APIError
from apiserver.apierrors.base import BaseError
from apiserver.config_repo import config
from apiserver.service_repo import ServiceRepo, APICall
from apiserver.service_repo.auth import AuthType
from apiserver.service_repo.auth import AuthType, Token
from apiserver.service_repo.errors import PathParsingError
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_set
log = config.logger(__file__)
class RequestHandlers:
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
_server_header = config.get("apiserver.response.headers.server", "clearml")
def before_app_first_request(self):
pass
@@ -23,9 +29,15 @@ class RequestHandlers:
if "/static/" in request.path:
return
if request.content_encoding:
return f"Content encoding is not supported ({request.content_encoding})", 415
try:
call = self._create_api_call(request)
content, content_type = ServiceRepo.handle_call(call)
load_data_callback = partial(self._load_call_data, req=request)
content, content_type, company = ServiceRepo.handle_call(
call, load_data_callback=load_data_callback
)
if call.result.redirect:
response = redirect(call.result.redirect.url, call.result.redirect.code)
@@ -45,20 +57,53 @@ class RequestHandlers:
if call.result.cookies:
for key, value in call.result.cookies.items():
kwargs = config.get("apiserver.auth.cookies")
kwargs = config.get("apiserver.auth.cookies").copy()
if value is None:
kwargs = kwargs.copy()
# Removing a cookie
kwargs["max_age"] = 0
kwargs["expires"] = 0
response.set_cookie(key, "", **kwargs)
else:
response.set_cookie(key, value, **kwargs)
value = ""
elif not company:
# Setting a cookie, let's try to figure out the company
# noinspection PyBroadException
try:
company = Token.decode_identity(value).company
except Exception:
pass
if company:
try:
# use no default value to allow setting a null domain as well
kwargs["domain"] = config.get(f"apiserver.auth.cookies_domain_override.{company}")
except KeyError:
pass
response.set_cookie(key, value, **kwargs)
return response
except Exception as ex:
log.exception(f"Failed processing request {request.url}: {ex}")
return f"Failed processing request {request.url}", 500
def after_request(self, response):
response.headers["server"] = self._server_header
return response
@staticmethod
def _apply_multi_dict(body: dict, md: ImmutableMultiDict):
def convert_value(v: str):
if v.replace(".", "", 1).isdigit():
return float(v) if "." in v else int(v)
if v in ("true", "True", "TRUE"):
return True
if v in ("false", "False", "FALSE"):
return False
return v
for k, v in md.lists():
v = [convert_value(x) for x in v] if (len(v) > 1 or k.endswith("[]")) else convert_value(v[0])
nested_set(body, k.rstrip("[]").split("."), v)
def _update_call_data(self, call, req):
""" Use request payload/form to fill call data or batched data """
if req.content_type == "application/json-lines":
@@ -76,23 +121,12 @@ class RequestHandlers:
req.on_json_loading_failed(msg)
call.batched_data = items
else:
json_body = req.get_json(force=True, silent=False) if req.data else None
# merge form and args
form = req.form.copy()
form.update(req.args)
form = form.to_dict()
# convert string numbers to floats
for key in form:
if form[key].replace(".", "", 1).isdigit():
if "." in form[key]:
form[key] = float(form[key])
else:
form[key] = int(form[key])
elif form[key].lower() == "true":
form[key] = True
elif form[key].lower() == "false":
form[key] = False
call.data = json_body or form or {}
body = (req.get_json(force=True, silent=False) if req.data else None) or {}
if req.args:
self._apply_multi_dict(body, req.args)
if req.form:
self._apply_multi_dict(body, req.form)
call.data = body
def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
call = call or APICall(
@@ -137,9 +171,6 @@ class RequestHandlers:
auth_cookie=auth_cookie,
)
# Update call data from request
self._update_call_data(call, req)
except PathParsingError as ex:
call = self._call_or_empty_with_error(call, req, ex.args[0], 400)
call.log_api = False
@@ -156,3 +187,18 @@ class RequestHandlers:
)
return call
def _load_call_data(self, call: APICall, req):
"""Update call data from request"""
try:
self._update_call_data(call, req)
except BadRequest as ex:
call.set_error_result(msg=ex.description, code=400)
except BaseError as ex:
call.set_error_result(msg=ex.msg, code=ex.code, subcode=ex.subcode)
except APIError as ex:
call.set_error_result(
msg=ex.msg, code=ex.code, subcode=ex.subcode, error_data=ex.error_data
)
except Exception as ex:
call.set_error_result(msg=ex.args[0] if ex.args else type(ex).__name__)

View File

@@ -1,6 +1,6 @@
from typing import Text, Sequence, Callable, Union, Type
from funcsigs import signature
from inspect import signature
from jsonmodels import models
from .apicall import APICall, APICallResult

Some files were not shown because too many files have changed in this diff Show More