Compare commits

519 Commits

Author SHA1 Message Date
clearml
1299ebfcf3 Version bump to v1.17.0 2024-12-05 22:38:25 +02:00
clearml
8c4932c7eb Model files are now deleted from the fileserver on models.delete call 2024-12-05 22:38:06 +02:00
clearml
e48e64a82f Do not throw internal error on invalid file paths 2024-12-05 22:37:15 +02:00
clearml
046a142f36 Do not return the last incomplete interval for worker stats chart 2024-12-05 22:36:33 +02:00
clearml
207b9e4746 Allow all users to access storage APIs 2024-12-05 22:35:16 +02:00
clearml
605fccdef1 Update ElasticSearch version 2024-12-05 22:34:23 +02:00
clearml
8b8d8d6e6f Change model input_size field to string 2024-12-05 22:33:52 +02:00
clearml
97b9bbc4a9 Return created_in_version property in users.get_current_user 2024-12-05 22:32:28 +02:00
clearml
ed60a27d1a Add mem used charts and cpu/gpu counts to model endpoints instance details
For the num of requests serving charts always take the max value from the interval
2024-12-05 22:31:45 +02:00
clearml
17fcaba2cb Add internal script to fix fileserver URLs in mongodb 2024-12-05 22:30:03 +02:00
clearml
83dbf0fcb8 Add age_sec field to loading serving models
Return serving instance charts sorted by instance name
2024-12-05 22:27:52 +02:00
clearml
a3b303fa28 Add support for OneOfEmbeddedField 2024-12-05 22:27:20 +02:00
clearml
543c579a2e Do not allow creating a project that has a name or part of the path matching the existing public project 2024-12-05 22:26:57 +02:00
clearml
41b003f328 Add an error for trying to duplicate a public project 2024-12-05 22:26:05 +02:00
clearml
606bf2c4be Fix mongodb connection when overridden connection string contains connection options 2024-12-05 22:25:35 +02:00
clearml
57ce9446b1 Add _any_/_all_ queries support for datetime fields 2024-12-05 22:25:08 +02:00
clearml
073cc96fb8 Optimize tasks.move 2024-12-05 22:24:40 +02:00
clearml
77e7fb5c13 Add reference field to serving models 2024-12-05 22:24:18 +02:00
clearml
0b61ec2a56 Workers statistics now return 0s for the periods where the worker did not report 2024-12-05 22:23:52 +02:00
clearml
7506a13fe8 Quote all non numeric fields in csv files 2024-12-05 22:23:22 +02:00
clearml
9dfb4b882a Fix tasks/models.edit_tags do not update the task/model last_changed time 2024-12-05 22:22:49 +02:00
clearml
2eee909364 Export csv files fixed for projects containing semicolon in their names 2024-12-05 22:22:12 +02:00
clearml
3bcbc38c4c Add storage service support 2024-12-05 22:21:12 +02:00
clearml
eb755be001 Add model endpoints support 2024-12-05 22:20:11 +02:00
clearml
9997dcc977 Sync API version 2024-12-05 22:18:27 +02:00
clearml
ee9f45ea61 Optimize MongoDB indices usage for large dbs 2024-12-05 22:17:13 +02:00
clearml
a1956cdd83 When removing a task from a queue change the task state only if the task does not think that it is enqueued in some other place 2024-12-05 22:16:14 +02:00
clearml
4b93f1f508 Add queues.clear_queue
Add new parameter 'update_task_status' to queues.remove_task
2024-12-05 22:15:43 +02:00
clearml
2752c4df54 Fixed schema for users.get_current_user 2024-12-05 22:14:37 +02:00
clearml
2332b8589b Update the task execution queue in queues.add_task 2024-12-05 22:14:03 +02:00
clearml
f94cda4e9d Fix user migration 2024-12-05 19:13:55 +02:00
clearml
a84e1ec0d6 Update licenses 2024-12-05 19:13:49 +02:00
clearml
4223fe73d1 Single task/model delete waits for events deletion in order to mitigate too many ES open scrolls due to repeated calls 2024-12-05 19:13:06 +02:00
clearml
f9577f9faa add update_execution_queue parameter to tasks.enqueue 2024-12-05 19:12:26 +02:00
clearml
58b748ddf3 Merge pipeline parameters with original task hyperparameters 2024-12-05 19:11:36 +02:00
clearml
fa41e14625 Allow enqueueing enqueued tasks 2024-12-05 19:10:34 +02:00
clearml
4df5687ecd Do not replace S3 links in data_tool export by default 2024-12-05 19:09:21 +02:00
clearml
9a69c21504 Fix model update for a deleted task 2024-12-05 19:08:26 +02:00
clearml
39c36527e2 Make sure that a task retrieved from a queue is not in aborted status 2024-12-05 19:07:55 +02:00
clearml
f59ef65fa6 Update API version to v2.31 2024-12-05 19:07:34 +02:00
clearml
8f942f0da2 Data tool can now export project trees not starting from the root 2024-12-05 19:06:56 +02:00
clearml
7b5679fd70 Optimize events deletion in tasks.delete_many/reset_many and models.delete_many operations 2024-12-05 19:06:25 +02:00
clearml
5a5f02cead Fix user credentials reset on the apiserver restart 2024-12-05 19:05:45 +02:00
clearml
cfcad6300a Add created to the range fields for tasks 2024-12-05 19:05:29 +02:00
clearml
fd46f3c6f3 Display only one debug image per iteration/metric and variant 2024-12-05 19:03:36 +02:00
clearml
e86b7fd24e Support for first and mean value for task last scalar metrics 2024-12-05 19:02:48 +02:00
clearml
50593f69f8 Allow enqueueing failed tasks 2024-12-05 18:57:06 +02:00
clearml
ba928854e0 MongoDB upgrade to v5.0 2024-12-05 18:54:23 +02:00
allegroai
83a0485518 Fix user credentials reset on apiserver restart 2024-07-17 11:22:52 +03:00
allegroai
f3491cc9b9 Update README 2024-07-07 13:28:40 +03:00
allegroai
7558426bc6 Fix max upload size limit 2024-06-26 11:21:53 +03:00
allegroai
ce01e37c66 Refactor docker compose files: remove legacy, add services agent initialization in Linux 2024-06-26 10:53:43 +03:00
allegroai
92b42d66b7 Remove default credentials and reset existing credentials if none were provided 2024-06-26 10:52:42 +03:00
allegroai
f7d36bea4f Use an auth token in async_urls_delete when contacting the fileserver 2024-06-20 18:00:19 +03:00
allegroai
f1c876089b Add worker_pattern parameter to workers.get_all and get_count endpoints 2024-06-20 17:59:28 +03:00
allegroai
dd0ecb712d Added fileserver.upload.max_upload_size_mb setting 2024-06-20 17:58:33 +03:00
allegroai
fcfc1e8998 Support a more granular distributed lock wait 2024-06-20 17:57:54 +03:00
allegroai
9c210bb4fa Fix fixed users creation/removal 2024-06-20 17:57:23 +03:00
allegroai
14547155cb Delete pipeline steps in pipelines.delete_runs 2024-06-20 17:55:52 +03:00
allegroai
3f34f83a91 Version bump to 1.16.0
API version bump to 2.30
Add missing endpoints to schema
2024-06-20 17:55:17 +03:00
allegroai
da3941e6f2 Upgrade pymongo dependency 2024-06-20 17:53:15 +03:00
allegroai
2e19a18ee4 Support automatic handling of pipeline steps if a pipeline controller task ID was passed to one of the tasks endpoints 2024-06-20 17:52:46 +03:00
allegroai
cdc668e3c8 Fileserver authorization is enabled by default 2024-06-20 17:50:02 +03:00
allegroai
7c9889605a Add token authorization to fileserver 2024-06-20 17:48:54 +03:00
allegroai
5456ee4ebf Data tool export projects by name now includes subprojects + option for exporting all projects added 2024-06-20 17:48:18 +03:00
allegroai
562cb77003 Support getting and clearing task logs using specific metrics 2024-06-20 17:47:39 +03:00
allegroai
91df2bb3b7 Use better token generation for the secret key 2024-06-20 17:46:23 +03:00
allegroai
cb9812caee Do not return any mongodb instructions as a result of task update operations 2024-06-20 17:44:17 +03:00
allegroai
0496582d96 Ensure min interval on workers history charts so that we do not get "saw like" chart due to the missing points in the intervals 2024-06-20 17:43:52 +03:00
allegroai
beff19e104 Fix do not return full file path on errors from the fileserver 2024-06-20 17:43:19 +03:00
pollfly
639b3d59a4 Update docstrings (#246)
Edit description so they can be rendered using MDX
2024-06-20 17:00:31 +03:00
allegroai
c0d687e2ef Fix missing git in Dockerfile for building webapp 2024-03-28 17:50:35 +02:00
allegroai
9c95c63ce0 Version bump to v1.15.0 2024-03-24 11:25:05 +02:00
allegroai
73179f53c2 Use latest patch versions for ES and Mongo 2024-03-24 11:24:51 +02:00
allegroai
ddc8a76279 Set API version to v2.29 2024-03-18 16:02:45 +02:00
allegroai
ac7ea0d477 Allow filtering task models.input.model field by array of ids 2024-03-18 16:01:45 +02:00
allegroai
3544ed19f8 Use latest patch versions for mongodb and ES 2024-03-18 15:59:15 +02:00
allegroai
5e68f053a0 Add widgets link in nginx configuration 2024-03-18 15:58:19 +02:00
allegroai
7bd5fdad59 Update webserver build: allow using external configuration from a file or from environment variables 2024-03-18 15:57:19 +02:00
allegroai
484c72aa0c Upgrade to Debian bookworm 2024-03-18 15:56:18 +02:00
allegroai
2027afbed5 Added missing ES index template for scalar events 2024-03-18 15:54:38 +02:00
allegroai
7d649f1964 Support controlling config value inheritance from the base folder 2024-03-18 15:53:58 +02:00
allegroai
8d237b3cae Upgrade Redis to v6.2 2024-03-18 15:53:07 +02:00
allegroai
e8ee6ce72e Code cleanup 2024-03-18 15:52:22 +02:00
allegroai
5749ff0454 Add security headers to webserver 2024-03-18 15:50:40 +02:00
allegroai
5189adf4f1 Improve handling of fixed users 2024-03-18 15:49:42 +02:00
allegroai
92a4e56c1f Add support for cookies extensions 2024-03-18 15:46:07 +02:00
allegroai
33528870ae Request cookies processing enhanced for more flexibility 2024-03-18 15:45:09 +02:00
allegroai
85f5b8b6f6 Fix last metrics for task are updated for events reported without variants 2024-03-18 15:44:28 +02:00
allegroai
6112910768 Make sure that legacy templates are deleted and empty db check is done on the new templates 2024-03-18 15:40:13 +02:00
allegroai
d3013ac285 Invalidate token on user logoff 2024-03-18 15:38:44 +02:00
allegroai
88abf28287 Add ElasticSearch 8.x support 2024-03-18 15:37:44 +02:00
allegroai
6a1fc04d1e Set cookies SameSite value to Lax 2024-02-13 16:18:21 +02:00
allegroai
ee8eb03698 Fix crash when importing events for public company tasks 2024-02-13 16:17:52 +02:00
allegroai
5799baae45 Make sure that APIs that aggregate task/model data from projects can be called for the root project 2024-02-13 16:17:33 +02:00
allegroai
801e536c5e Fix tasks.started to correctly handle null values in the started field 2024-02-13 16:17:02 +02:00
allegroai
6e484ea8f4 Fix missing region parameter when deleting files from minio server 2024-02-13 16:16:24 +02:00
allegroai
a47e65d974 Add input parameters check to multiple APIs 2024-02-13 16:15:55 +02:00
allegroai
702b6dc9c8 Version bump to v1.14.0 2024-01-10 15:31:11 +02:00
allegroai
db15f235e4 Make sure files downloaded from the apiserver are not cached by browsers 2024-01-10 15:31:01 +02:00
allegroai
8c347f8fa9 Fix include and exclude filters not processing "no tags" condition 2024-01-10 15:26:55 +02:00
allegroai
768c3d80ff Remove callback_url_prefix and state parameters from login.supported_modes and does not return urls 2024-01-10 15:26:22 +02:00
allegroai
a5c3ef6385 Fix query filter so that the default operator between different query operations on the same parameter is AND instead of OR 2024-01-10 15:24:37 +02:00
allegroai
11b7a384af Set API version 2.28 2024-01-10 15:23:54 +02:00
allegroai
9a70ade4a6 Support task models with missing model field in data_tool import 2024-01-10 15:18:58 +02:00
allegroai
91ce140901 Add "queue watched" indication to pipelines.start_pipeline 2024-01-10 15:15:43 +02:00
allegroai
49084a9c49 Optimize task statistics for projects dashboard and statistics reporter 2024-01-10 15:13:25 +02:00
allegroai
8a99eb6812 Fix model_metrics parameter name in get_multi_task_metrics schema 2024-01-10 15:12:56 +02:00
allegroai
811ab2bf4f Support exporting users with data tool 2024-01-10 15:12:07 +02:00
allegroai
3752db122b Add events.get_multi_task_metrics 2024-01-10 15:11:27 +02:00
allegroai
439911b84c Upgrade werkzeug and flask dependencies 2024-01-10 15:10:46 +02:00
allegroai
262a301e28 Check for dictionary type for some model and task fields 2024-01-10 15:10:41 +02:00
allegroai
a604451b01 Refactor check for tasks write permission 2024-01-10 15:08:20 +02:00
allegroai
88a7773621 Allow filtering on event metrics in multi-task endpoints get_task_single_value_metrics, multi_task_scalar_metrics_iter_histogram and get_multi_task_plots 2024-01-10 15:07:46 +02:00
allegroai
35c4061992 Support filtering by task or model ids in projects.get_unique_metric_variants 2024-01-10 15:06:21 +02:00
allegroai
4684fd5b74 Version bump to v1.13.0 2023-11-17 09:49:26 +02:00
allegroai
e08123fcc0 Fix workers.activity_report should return 0s for the time when no workers reported 2023-11-17 09:49:18 +02:00
allegroai
e713e876eb Upgrade urllib3 requirement 2023-11-17 09:48:19 +02:00
allegroai
c2cc788319 Added supported API versions doc 2023-11-17 09:47:44 +02:00
allegroai
da8315d0db Allow queries on the list of execution queue ids in tasks.get_all/get_all_ex 2023-11-17 09:47:19 +02:00
allegroai
4ac6f88278 Optimize Workers retrieval
Store worker statistics under worker id and not internal redis key
Fix unit tests
2023-11-17 09:46:44 +02:00
allegroai
a7865ccbec Turn on async task events deletion in case there are more than 100_000 events 2023-11-17 09:45:55 +02:00
allegroai
ec14f327c6 Optimize endpoints that do not require authorization by not validating JWT token 2023-11-17 09:45:22 +02:00
allegroai
a03b24d6b6 Add log info on caller IP if token validation fails 2023-11-17 09:43:59 +02:00
allegroai
cb71ef8e47 Fix missing scroll_id in events.get_scalar_metric_data 2023-11-17 09:43:11 +02:00
allegroai
8678fbc995 Fix properly unset Task fields on task reset 2023-11-17 09:42:39 +02:00
allegroai
58df8f201a Update API to 2.27 2023-11-17 09:40:34 +02:00
allegroai
f4bf16c156 Fix schema for swagger compatibility 2023-11-17 09:39:52 +02:00
allegroai
942f996237 Fix async_delete cannot be configured using configuration files 2023-11-17 09:39:22 +02:00
allegroai
c1e7f8f9c1 Optimize deletion of projects with many tasks 2023-11-17 09:38:32 +02:00
allegroai
274c487b37 Add update_tags api to tasks and models 2023-11-17 09:37:25 +02:00
allegroai
cc0129a800 Add filters parameter for passing user defined list filters for all get_all_ex apis 2023-11-17 09:36:58 +02:00
allegroai
388dd1b01f Fix regression issue with archive tasks display 2023-11-17 09:35:55 +02:00
allegroai
d62ecb5e6e Add last_change and last_change_by DB Model 2023-11-17 09:35:22 +02:00
allegroai
6d507616b3 Add pattern parameter to projects.get_hyperparam_values 2023-11-17 09:34:13 +02:00
allegroai
d0252a6dd9 Make sure that hyperparam/configuration/metadata keys that are contain only empty space are rejected 2023-11-17 09:32:22 +02:00
allegroai
2263e7cc1e Fix regression with archive tasks display 2023-07-31 14:16:08 +03:00
allegroai
81b93e6811 Updated dependency - dnspython is a required dependency of pymongo as of pymongo v4.3 (https://pymongo.readthedocs.io/en/stable/changelog.html#changes-in-version-4-3-4-3-2) 2023-07-27 11:49:40 +03:00
allegroai
491e83d0f1 Version bump to v1.12.0 2023-07-26 18:56:04 +03:00
allegroai
f84cc0a2cb Remove 10 metrics limit in multi-task plot comparison 2023-07-26 18:55:49 +03:00
allegroai
6c5f966ed4 Add new_status field to tasks.dequeue and dequeue_many endpoints 2023-07-26 18:55:05 +03:00
allegroai
4eff657810 Fix debug images not returned for tasks in new db 2023-07-26 18:54:19 +03:00
allegroai
74acaa31df Add explicit refresh interval to ES mappings
Fix queue tests
2023-07-26 18:54:02 +03:00
allegroai
21ed8559bf Fix worker keys not returned in queues.get_all_ex 2023-07-26 18:51:20 +03:00
allegroai
3927604648 Add task names to events.get_single_value_metrics endpoint response 2023-07-26 18:50:53 +03:00
allegroai
f7dcbd96ec Fix deleting model events
Add delete_external_artifacts parameter to projects.delete endpoint
2023-07-26 18:49:54 +03:00
allegroai
5950b81f0b Fix child tasks count for top level pipeline and dataset projects 2023-07-26 18:49:12 +03:00
allegroai
1e51e2e221 Allow projection of more than 500 items 2023-07-26 18:46:58 +03:00
allegroai
4c98b87554 Fix issues with new dependencies 2023-07-26 18:46:28 +03:00
allegroai
c196043d2a Add max_download_items to users.get_current_user endpoint response 2023-07-26 18:45:42 +03:00
allegroai
752020c66a Update API version to 2.26 2023-07-26 18:44:20 +03:00
allegroai
6885d07462 Write UTF-8 BOM into csv download file 2023-07-26 18:43:38 +03:00
allegroai
00552da1b0 Requests context is not needed any more 2023-07-26 18:43:09 +03:00
allegroai
eebe2eeffc Update requirements 2023-07-26 18:42:26 +03:00
allegroai
bc2fe28bdd Add field_mappings to organizations download endpoints 2023-07-26 18:39:41 +03:00
allegroai
ed86750b24 Add scalar field type to jsonmodels 2023-07-26 18:39:06 +03:00
allegroai
6df69afb25 Support "__$or" condition on projects children filtering 2023-07-26 18:38:41 +03:00
allegroai
3f22423c3f Support paging in projects.get_model_metadata_values and get_hyperparam_values endpoints 2023-07-26 18:38:11 +03:00
allegroai
3ad636c468 Exported csv file name now contains the project name (including non-ascii names) 2023-07-26 18:37:20 +03:00
allegroai
5c80336aa9 Project delete and validate_delete now analyses and presents info for datasets and pipelines 2023-07-26 18:36:45 +03:00
allegroai
5cd59ea6e3 Fix csv export handling "," in fields 2023-07-26 18:35:31 +03:00
allegroai
5d3ba4fa73 Fix events.get_multitask_plots to retrieve last iterations per each task metric separately 2023-07-26 18:34:30 +03:00
allegroai
42556c8dbb Pipelines children query now looks for pipeline projects and not tasks 2023-07-26 18:33:35 +03:00
allegroai
dbe1c6f00f Allow configuring multi-plots batch size 2023-07-26 18:33:10 +03:00
allegroai
a17485b1bd Allow dequeueing a deleted task 2023-07-26 18:32:32 +03:00
allegroai
a2b9fed92d Make sure that scroll parameters are ignored when downloading tasks 2023-07-26 18:31:56 +03:00
allegroai
ff34da3c88 Add organization.download_for_get_all endpoint 2023-07-26 18:31:20 +03:00
allegroai
5239755066 Support include_subprojects flag in reports.get_all_ex endpoint 2023-07-26 18:30:34 +03:00
allegroai
8061dfedbb Fix NewListBucketsHelper backwards compatibility 2023-07-26 18:27:51 +03:00
allegroai
011164ce9b Support __$and condition for excluded terms in get_all_ex endpoints list filters 2023-07-26 18:26:49 +03:00
allegroai
8135cf5258 Add include_subprojects to tasks/models.get_all endpoints
Fix escaping metadata for tasks, models and queues
2023-07-26 18:24:49 +03:00
allegroai
a83a932e84 Add pipelines.delete_runs endpoint 2023-07-26 18:23:05 +03:00
allegroai
db021f2863 Add workers.get_count endpoint 2023-07-26 18:21:52 +03:00
allegroai
1b650b1689 Add projects.get_user_names endpoint 2023-07-26 18:21:16 +03:00
allegroai
14d18a7aba Remove obsolete duration field 2023-07-26 18:19:41 +03:00
Olivier Girardot
a7ed46979f Fix handling of the subpaths with nginx templating (#204)
Co-authored-by: ogirardot <olivier.girardot@malt.com>
2023-07-02 16:12:29 +03:00
allegroai
452f606889 Version bump to v1.11 2023-05-25 19:40:07 +03:00
allegroai
fc47ccbf09 Add default services agent user 2023-05-25 19:39:53 +03:00
allegroai
0206811342 Improve empty database check during startup 2023-05-25 19:39:17 +03:00
allegroai
a3ac1049a3 Update ClearML SDK dependency 2023-05-25 19:38:48 +03:00
allegroai
8488f63a3a Add fileserver URL prefixes for async deletion 2023-05-25 19:38:07 +03:00
allegroai
9206a7c57d Schedule external file URLs for deletion on models deletion 2023-05-25 19:36:28 +03:00
allegroai
0c37ced2a1 Fix model Id handling when deleting models for tasks 2023-05-25 19:35:18 +03:00
allegroai
b22f26129e Update requirements 2023-05-25 19:34:19 +03:00
allegroai
d8b998ebd8 Bump API version to 2.25 2023-05-25 19:33:37 +03:00
allegroai
741fa84b52 Fix projects own_tasks does not take task state filter into account 2023-05-25 19:32:52 +03:00
allegroai
d9579891c8 Return only reports from the .reports projects in reports.get_all_ex 2023-05-25 19:31:05 +03:00
allegroai
900414d0de Add option to echo ping payload 2023-05-25 19:30:13 +03:00
allegroai
5449b332d2 Support reports from the root project in reports.get_all_ex 2023-05-25 19:29:46 +03:00
allegroai
875f4b9536 Fix task dequeue will changes status for un-queued/running tasks 2023-05-25 19:28:49 +03:00
allegroai
95b8f22899 Add CLEARML_FILES_HOST to async_delete in windows 2023-05-25 19:27:40 +03:00
allegroai
4058fb9ce5 Migrate to python 3.9 bullseye docker images
Update Mongo driver version
2023-05-25 19:27:14 +03:00
allegroai
cf8e847ed3 Switch to new redis version 2023-05-25 19:22:39 +03:00
allegroai
755cc803d9 Add remove_from_all_queues parameter to tasks.dequeue/dequeue_many endpoints 2023-05-25 19:22:10 +03:00
allegroai
3729afe014 Optimize queues.get_next_task to retrieve required task fields only 2023-05-25 19:21:24 +03:00
allegroai
dff2ed34e8 Support receiving mixed events for both locked and unlocked tasks and models events.add_batch 2023-05-25 19:20:35 +03:00
allegroai
de9651d761 Allow mixing Model and task events in the same events batch 2023-05-25 19:19:45 +03:00
allegroai
818496236b Support filtering by children tags in projects.get_all_ex 2023-05-25 19:19:10 +03:00
allegroai
e99817b28b Task reports can now return single value metrics 2023-05-25 19:18:24 +03:00
allegroai
58465fbc17 Model events are fully supported 2023-05-25 19:17:40 +03:00
allegroai
2e4e060a82 Task move forward/backwards in queue is now atomic 2023-05-25 19:16:33 +03:00
allegroai
5c5d9b6434 Fix numeric hyperparam values are not sorted lexicographically with descending sort order 2023-05-25 19:15:59 +03:00
allegroai
4291ad682a Support filtering by task name in projects.get_task_parent 2023-05-25 19:15:26 +03:00
allegroai
4c22757002 Fix task that is not in queue but has 'queued' status can't be dequeued 2023-05-25 19:14:25 +03:00
allegroai
6e777e80b8 Cleaned up unit tests 2023-05-25 19:13:10 +03:00
allegroai
c8e4d9eeac Fix Dockerfile uses deprecated base image 2023-04-18 10:50:13 +03:00
dependabot[bot]
b51aa5c29b Bump redis from 3.5.3 to 4.4.4 in /apiserver (#190)
Bumps [redis](https://github.com/redis/redis-py) from 3.5.3 to 4.4.4.
- [Release notes](https://github.com/redis/redis-py/releases)
- [Changelog](https://github.com/redis/redis-py/blob/master/CHANGES)
- [Commits](https://github.com/redis/redis-py/compare/3.5.3...v4.4.4)

---
updated-dependencies:
- dependency-name: redis
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-01 08:59:02 +03:00
allegroai
e7c9daa42b Fix get_task_events to correctly use last_iters for model events 2023-03-28 16:45:44 +03:00
allegroai
7357654249 Version bump to v1.10 2023-03-23 19:17:00 +02:00
allegroai
a6f671b46a Fix typo 2023-03-23 19:16:38 +02:00
allegroai
17a8b440bd Fix only last event of each type is stored per model (all should be stored) 2023-03-23 19:16:30 +02:00
allegroai
eb2b9cbd9a Fix project count for datasets and pipelines 2023-03-23 19:15:42 +02:00
allegroai
797e503e67 Update ES version 2023-03-23 19:14:33 +02:00
allegroai
30cfdac8f2 Fix project preview completed_tasks_24h should not count tasks that are marked as failed or running 2023-03-23 19:13:52 +02:00
allegroai
24bb87aaee Turn on mongo sorting using disk usage by default for sorting in *.get_all* apis 2023-03-23 19:12:52 +02:00
allegroai
dd49ba180a Improve statistics on projects children 2023-03-23 19:11:45 +02:00
allegroai
bda903d0d8 Set API version to 2.24 2023-03-23 19:11:13 +02:00
allegroai
9739eb2d5a Add report_assets field to report tasks 2023-03-23 19:09:03 +02:00
allegroai
cfbb37238f Add default workers timeout to the server's configuration 2023-03-23 19:08:11 +02:00
allegroai
6664c6237e Support querying by children_type in projects.get_all_ex 2023-03-23 19:07:42 +02:00
allegroai
74200a24bd Add filtering on child projects in projects.get_all_ex 2023-03-23 19:06:49 +02:00
john-zielke-snkeos
2fb9288a6c Add env switch to disable nginx ipv6 bind (#165) 2023-03-13 16:05:43 +02:00
shyallegro
5d014d81af Fix #184 and update docker build to include widgets (#185) 2023-03-07 11:26:12 +02:00
allegroai
3a2675abe1 Version bump to v1.9.2 2023-01-24 16:11:21 +02:00
allegroai
f0d68b1ce9 Make sure model label values are integer 2023-01-24 16:11:12 +02:00
allegroai
15db9cdaef Allow updating comments on published reports 2023-01-24 14:40:32 +02:00
Mal Miller
a45d47f5d7 Fix default value of CLEARML_AGENT_UPDATE_VERSION for agent-services (#114) 2023-01-03 13:45:52 +02:00
allegroai
b1a50c1370 Version bump to v1.9.1 2023-01-03 12:16:07 +02:00
allegroai
22a2a02760 Allow renaming published reports 2023-01-03 12:15:44 +02:00
allegroai
ab798e4170 Allow updating tags on published reports 2023-01-03 12:15:02 +02:00
allegroai
f09ac672d2 Add pipeline test 2023-01-03 12:14:12 +02:00
allegroai
2149b76f63 Fix crash when starting pipeline 2023-01-03 12:13:48 +02:00
allegroai
d96420aa67 Version bump to v1.9 2022-12-21 18:47:03 +02:00
allegroai
ed6c7b7bcb Fix Project time is not updated when moved or merged 2022-12-21 18:46:53 +02:00
allegroai
a392bc0bd7 Bump API version to 2.23 2022-12-21 18:46:12 +02:00
allegroai
7e97ec5555 Fix events.get_task_plots endpoint 2022-12-21 18:45:17 +02:00
allegroai
9c41124b81 Add support for moving objects to projects root 2022-12-21 18:43:45 +02:00
allegroai
14ff639bb0 Removed limit on event comparison for the same company tasks only 2022-12-21 18:42:40 +02:00
allegroai
e66257761a Add support for server-side delete for AWS S3, Google Storage and Azure Blob Storage 2022-12-21 18:41:16 +02:00
allegroai
0ffde24dc2 Add min and max value iteration to last metrics 2022-12-21 18:36:50 +02:00
allegroai
d4fdcd9b32 Upgrade mongoengine version 2022-12-21 18:35:23 +02:00
allegroai
18570bfccb Add project_id response field to reports.create endpoint 2022-12-21 18:35:14 +02:00
allegroai
54ce6c34c6 Fix bad field values might cause ugly server exception to be returned 2022-12-21 18:33:28 +02:00
allegroai
ae4c33fa0e Add support for allow_public flag in get_all_ex endpoint
Add `last_changed_by` field on task updates
Fix reports support
2022-12-21 18:32:56 +02:00
allegroai
c7cd949fd0 Add reports support
Fix schema
2022-12-21 18:30:54 +02:00
allegroai
1ce4058157 Change tasks comparison limit to 100 2022-12-21 18:29:49 +02:00
allegroai
7b6f24b24d Version bump to 1.8.0 2022-11-29 17:50:32 +02:00
allegroai
d03a931d84 Remove buggy debug code 2022-11-29 17:50:17 +02:00
allegroai
5cc7199661 Fix clearing ES scroll 2022-11-29 17:44:31 +02:00
allegroai
6537e9ef69 Add active_users and search_hidden options to get_entities_count endpoint 2022-11-29 17:44:19 +02:00
allegroai
930aaff791 Fix test 2022-11-29 17:43:43 +02:00
allegroai
1999fb2479 Add URL async delete prefixes setting 2022-11-29 17:43:08 +02:00
allegroai
9db14cc31d Update endpoint API versions 2022-11-29 17:42:19 +02:00
allegroai
e3cc689528 Enhance async_urls_delete feature with max_async_deleted_events_per_sec setting and fileserver timeout and prefixes 2022-11-29 17:41:49 +02:00
allegroai
9e0adc77dd Make sure id field is included in ordering 2022-11-29 17:40:09 +02:00
allegroai
58d9a64537 Fix endpoint name in schema 2022-11-29 17:39:35 +02:00
allegroai
d397d2ae20 Support optional async events deletion when deleting tasks 2022-11-29 17:38:41 +02:00
allegroai
2d711e1500 Collect model event URLs during task and project cleanup 2022-11-29 17:38:03 +02:00
allegroai
97992b0d9e Support returning multiple plots during history navigation 2022-11-29 17:37:30 +02:00
allegroai
bc23f1b0cf Add "queue watched" indication for tasks.enqueue and tasks.enqueue_many 2022-11-29 17:36:41 +02:00
allegroai
6b3eff1426 Support workers system tags 2022-11-29 17:35:25 +02:00
allegroai
caaf801cd0 Support plots navigation by iteration 2022-11-29 17:34:57 +02:00
allegroai
c23e8a90d0 Support model events 2022-11-29 17:34:06 +02:00
allegroai
fa5b28ca0e Support HTTP URLs when deleting fileserver files 2022-11-29 17:33:18 +02:00
allegroai
bfb55a9463 Support deleting external artifacts when deleting projects 2022-11-29 17:32:41 +02:00
allegroai
37e485e1f2 Upgrade API version to 2.22 2022-11-29 17:29:57 +02:00
allegroai
3451ff441f Fix projects.get_all_ex with active_users filtering did not work if the project id was passed as a string and not list 2022-11-29 17:27:54 +02:00
allegroai
53c9b5525e Add preliminary support for datasets under projects 2022-11-29 17:27:02 +02:00
PSL
e5230edac3 Add mongo username and password authentication (#162)
Co-authored-by: pangshaoliang <pangshaoliang@megvii.com>
2022-10-08 15:56:02 +03:00
allegroai
a54dd8030c Add async-delete to docker-compose 2022-09-29 19:44:05 +03:00
allegroai
482a5c34bc Version bump to 1.7.0 2022-09-29 19:40:20 +03:00
allegroai
ee2a72c70f Add company/uri index 2022-09-29 19:40:05 +03:00
allegroai
a0d8aaf3b9 Fix urls are not unquoted in batch_delete 2022-09-29 19:39:02 +03:00
allegroai
de1f823213 Removed stub timing context 2022-09-29 19:37:15 +03:00
allegroai
0c9e2f92ee Add server-side support for deleting files from fileserver on task delete 2022-09-29 19:34:24 +03:00
allegroai
6c49e96ff0 Update API version to 2.21 2022-09-29 19:31:42 +03:00
allegroai
81e3fc6577 Improve utilities 2022-09-29 19:30:57 +03:00
allegroai
e6dc4b7557 Set cloned task parent to original task if original task has no parent 2022-09-29 19:30:13 +03:00
allegroai
238a47a197 Add created field to backend.user 2022-09-29 19:29:36 +03:00
allegroai
04e7076628 Add support for specifying a specific task ID in queues.get_next_task 2022-09-29 19:27:42 +03:00
allegroai
0531612bf4 Fix deleting a queue should dequeue all enqueued tasks 2022-09-29 19:27:09 +03:00
allegroai
3ae410a1e9 Remove the ThreadsManager.terminating flag 2022-09-29 19:23:26 +03:00
allegroai
98ed3075dd Added exclude support when converting mongo objects to dictionary 2022-09-29 19:21:28 +03:00
allegroai
b871bf4224 Fix projects.get_all_ex 2022-09-29 19:20:50 +03:00
allegroai
8d4c02fc3c Add support for hidden internal queues 2022-09-29 19:20:24 +03:00
allegroai
b986980c75 Use correct attrs version, remove related dependency 2022-09-29 19:18:38 +03:00
allegroai
a4fa567be2 Fix task stats update 2022-09-29 19:18:22 +03:00
allegroai
ddb91f226a Add Task Unique Metrics to task object 2022-09-29 19:16:56 +03:00
allegroai
7772f47773 Support datetime ranges in field queries 2022-09-29 19:15:50 +03:00
allegroai
9c118d14e0 Add missing last_update field to models.get_all APIs 2022-09-29 19:14:13 +03:00
allegroai
efd56e085e Fix threaded jobs management (invoke only from AppSequence) 2022-09-29 19:13:22 +03:00
allegroai
4dff163af4 Improve examples pre-population code 2022-09-29 19:11:21 +03:00
allegroai
242a78a0fe Add request logging using the CLEARML_SERVER_DEBUG_REQUESTS env var 2022-09-29 19:10:55 +03:00
allegroai
78989fea91 Add better mongodb connection string verbosity 2022-09-29 19:10:13 +03:00
allegroai
5de7c12062 Version bump to v1.6 2022-07-08 18:05:43 +03:00
allegroai
3f79c19079 Add v1.6.0 mongodb migration 2022-07-08 18:05:33 +03:00
allegroai
fe29743c54 Add support for new IDs generation when importing projects 2022-07-08 18:04:40 +03:00
allegroai
d760cf5835 Remove use of dpath in query projection 2022-07-08 18:04:02 +03:00
allegroai
3695f25a5f Fix internal error returned to clients 2022-07-08 18:03:38 +03:00
allegroai
c6f1beafdd Update API version to 2.20 2022-07-08 18:02:44 +03:00
allegroai
68a54c34f3 Add user creation time to users.get_current_user 2022-07-08 17:59:45 +03:00
allegroai
ab495ae586 Fix archived projects handling 2022-07-08 17:55:02 +03:00
allegroai
b058770af1 Fix handling of empty hyperparam/configuration keys 2022-07-08 17:54:19 +03:00
allegroai
f7e833bf6f Fix loading expired worker entries from Redis 2022-07-08 17:53:12 +03:00
allegroai
36b9ab0453 Fix handling of empty keys in query 2022-07-08 17:51:49 +03:00
allegroai
ec0436d0da Fix task cleanup 2022-07-08 17:50:49 +03:00
allegroai
0f6c4e75b7 Fix debug images URL handling and task routing 2022-07-08 17:50:26 +03:00
allegroai
a41ae112a1 Fix backward compatibility when importing old projects 2022-07-08 17:49:36 +03:00
allegroai
c28f478ea8 Fix worker Id is used instead of worker key when processing report 2022-07-08 17:48:17 +03:00
allegroai
c18eb99d06 Return getting_started_info in users.get_current_user 2022-07-08 17:45:33 +03:00
allegroai
3a60f00d93 Add support for Dataset projects 2022-07-08 17:45:03 +03:00
allegroai
ee87778548 Better support for PyJWT 2.4 2022-07-08 17:44:17 +03:00
allegroai
52c0c4d438 Add model task cleanup on tasks.reset 2022-07-08 17:43:54 +03:00
allegroai
d117a4f022 Support max_task_entries option in queues.get_by_id/get_all
Add queues.peek_task and queues.get_num_entries
2022-07-08 17:42:20 +03:00
allegroai
6683d2d7a9 Fix task cleanup 2022-07-08 17:40:55 +03:00
allegroai
05357fe25e Support publish option in tasks.completed 2022-07-08 17:40:43 +03:00
allegroai
adc1825843 Add support for model statistics 2022-07-08 17:39:41 +03:00
allegroai
0c15169668 Improve tests 2022-07-08 17:38:31 +03:00
allegroai
123dc1dcfb Improve query error handling 2022-07-08 17:38:22 +03:00
allegroai
b2feafac09 Support workers filtering with tags 2022-07-08 17:37:33 +03:00
allegroai
b41ab8c550 Better support for queue metrics and queue metrics refresh on sample 2022-07-08 17:36:46 +03:00
allegroai
62d5779bd5 Count own tasks/models for projects 2022-07-08 17:35:01 +03:00
allegroai
f8b9d9802e Add support for organization.get_entities_count 2022-07-08 17:32:56 +03:00
allegroai
dd8a1503b0 Add support for navigate_current_metric in events.get_debug_image_sample 2022-07-08 17:31:44 +03:00
allegroai
cff98ae900 Add support for events.get_task_single_value_metrics, events.plots, events.get_plot_sample and events.next_plot_sample 2022-07-08 17:29:39 +03:00
allegroai
9b108740da Bump PyJWT version due to "Key confusion through non-blocklisted public key formats" vulnerability 2022-05-25 16:50:19 +03:00
allegroai
08a7bc7c9f Fix not all the event logs are returned from sharded ES 2022-05-20 15:11:05 +03:00
allegroai
fb256d7e5b Version bump to v1.5 2022-05-18 15:29:45 +03:00
allegroai
710443b078 Fix move task to trash is not thread-safe 2022-05-18 10:31:20 +03:00
allegroai
e0cde2f7c9 Add support for deleting pipeline projects 2022-05-18 10:30:21 +03:00
allegroai
60b9c8de14 Allow arbitrary task fields in project statistics filter 2022-05-18 10:29:36 +03:00
allegroai
ecffe26be4 Fix auth.edit_credentials 2022-05-18 10:28:58 +03:00
allegroai
2570bd9e26 Fix ES issue with capital letters in index name 2022-05-18 10:18:23 +03:00
allegroai
174f84514a Fix no destination when merging projects 2022-05-18 10:17:34 +03:00
allegroai
65cb8d7b43 Refactor method name 2022-05-18 10:16:41 +03:00
allegroai
5f8ef808a3 Fix ES issue with capital letters in index name 2022-05-18 10:16:19 +03:00
allegroai
4941ac70e0 Add events.clear_task_log 2022-05-17 16:09:23 +03:00
allegroai
67cd461145 Add auth.edit_credentials 2022-05-17 16:08:12 +03:00
allegroai
92b5fc6f9a Fix handling hidden sub-projects 2022-05-17 16:06:34 +03:00
allegroai
b90165b4e4 Support queue_name in tasks enqueue 2022-05-17 16:04:34 +03:00
allegroai
6c2dcb5c8a Improve error message 2022-05-17 15:56:18 +03:00
allegroai
3efed32934 Add X-Jwt-Payload to redacted headers 2022-05-17 15:55:41 +03:00
allegroai
69737308fe Version bump to v1.4.0 2022-04-18 16:38:22 +03:00
allegroai
a6dbea808a Add indices for task.last_update and task.status_changed 2022-04-18 16:37:22 +03:00
allegroai
5131b17901 Support not returning hidden sub-projects when include_stats is specified without search_hidden 2022-04-18 16:36:14 +03:00
allegroai
5f21c3a56d Add support for searching hidden projects and tasks 2022-04-18 16:34:18 +03:00
allegroai
2350ac64ed Fix internal error on count task events if there is no events index 2022-04-18 16:31:02 +03:00
allegroai
d146127c18 Add events.clear_scroll endpoint to clear event search scrolls 2022-04-18 16:29:57 +03:00
Mal Miller
abd65e103e Ensure agent-services waits for API server to be ready (#129) 2022-03-31 11:10:45 +03:00
pollfly
bf65ea7bd0 Resize admonitions (#126) 2022-03-27 15:04:43 +03:00
pollfly
73e278a8ed Add deprecation notes to legacy docs (#124) 2022-03-23 23:51:55 +02:00
Zied ANDOLSI
d92dfbbdb7 Allow ClearML to be served with a URL path prefix (#121)
* add server root url

* [Feature Request] Add proxy_pass for root url other than /

* [Feature Request] Add proxy_pass for root url other than /

* add support for web sub path

* add support for web sub path

* use default conf instead of created a custom one

* code reivew: move cp command in if block

* Add commented env var in the docker-compose file

Co-authored-by: Zied ANDOLSI <zandolsi@prophesee.ai>
2022-03-22 17:21:58 +02:00
Zied ANDOLSI
5c1e419eb5 Allow overriding clearml web git url on build (#122)
* add server root url

* [Feature Request] Add possibility to override clearml web git url

Co-authored-by: Zied ANDOLSI <zandolsi@prophesee.ai>
2022-03-17 14:35:50 +02:00
allegroai
124684f53f Version bump to v1.3.0 2022-03-15 16:34:35 +02:00
allegroai
455b5d6758 Fix pre-populate to convert model metadata from the old format 2022-03-15 16:30:14 +02:00
allegroai
c04e2e498b Support credentials label and last_used_from fields 2022-03-15 16:29:37 +02:00
allegroai
da8a45072f Add pipelines support 2022-03-15 16:28:59 +02:00
allegroai
e1992e2054 Fix queue metrics calculation 2022-03-15 16:28:49 +02:00
allegroai
c17cedd93a Support disabling response compression in fileserver 2022-03-15 16:27:31 +02:00
allegroai
b6ad8f8790 Add support for worker auto-unregister (instead of raising an error) 2022-03-15 16:25:14 +02:00
allegroai
5acc7eebc3 Set API version to 2.17 2022-03-15 16:22:51 +02:00
allegroai
941927dfcd Return fixed fileserver header 2022-03-15 16:21:52 +02:00
allegroai
02933a9c93 Support disabling response compression
Return fixed server header
2022-03-15 16:21:14 +02:00
allegroai
e537651f29 Better support for assets upload/download 2022-03-15 16:19:52 +02:00
allegroai
af09fba755 Add metadata dict support for models, queues
Add more info for projects
2022-03-15 16:18:57 +02:00
Reuben Morais
04ea9018a3 Add missing g++ dep to server build (#111) 2022-02-21 22:14:22 +02:00
allegroai
ff7e1be24f Updated docker-compose files for v1.2.0 2022-02-14 15:27:23 +02:00
allegroai
fc4fd9e61c Version bump to v1.2.0 2022-02-14 15:26:27 +02:00
allegroai
8908c7dcf9 Update driver requirements
Refactor ES initialization
2022-02-13 20:27:12 +02:00
allegroai
b9996e2c1a Protect against multiple connects to the update server from different processes
Code cleanup
2022-02-13 20:12:12 +02:00
allegroai
afdc56f37c Use task active duration for worker task running time 2022-02-13 20:01:47 +02:00
allegroai
a25cd5dae8 Fix version conflicts when deleting task events cause an error 2022-02-13 20:01:25 +02:00
allegroai
447adb9090 Add support for credentials label
Support no_scroll in events.get_task_plots
Support better project stats
Fix Redis required on mongodb initialization
Update tests
2022-02-13 19:59:58 +02:00
allegroai
92fd98d5ad Add support for lists and nested fields in URL args and form 2022-02-13 19:52:05 +02:00
allegroai
c4001b4037 Add Redis cluster support
Fix for lru_cache usage
2022-02-13 19:48:26 +02:00
allegroai
970a32287a Add Redis password support 2022-02-13 19:37:52 +02:00
allegroai
17cd48dada Add support for override cookie domains
Support for community invitation alarms
Remove duplicate property
Add query optimizations
2022-02-13 19:35:35 +02:00
allegroai
ea3b6e955f Optimize nested_get() 2022-02-13 19:32:22 +02:00
allegroai
843450bb9b Fix add_or_update_artifacts should always be allowed on in_progress tasks
Fix delete_artifacts should always be allowed on in_progress tasks
Fix query code
2022-02-13 19:31:54 +02:00
allegroai
e149af58b1 Support for additional mata data in api call response 2022-02-13 19:30:36 +02:00
allegroai
604a38035b Add organization.update_company_name
Fix unit-tests
2022-02-13 19:29:46 +02:00
allegroai
cae38a365b Fix base query building
Fix schema
Improve events.scalar_metrics_iter_raw implementation
2022-02-13 19:28:23 +02:00
allegroai
e334246b46 Add support for project stats with children flag 2022-02-13 19:26:47 +02:00
allegroai
36e013b40c Add support for events.scalar_metrics_iter_raw 2022-02-13 19:26:03 +02:00
allegroai
f20cd6536e Add scroll support to *.get_* 2022-02-13 19:23:29 +02:00
allegroai
446bd35006 Refactor debug images response, model ORM 2022-02-13 19:21:07 +02:00
allegroai
a377a7e315 Support status_message and status_reason in tasks.delete 2022-02-13 19:20:31 +02:00
allegroai
3d046ac282 Fix project should not be merged into itself 2022-02-13 19:18:08 +02:00
allegroai
a08fa9a0e1 Add missing API Errors 2022-02-13 19:16:58 +02:00
allegroai
5856ed2836 Update Model.last_update on changes to tags and system tags 2022-02-13 19:15:37 +02:00
allegroai
d295355d99 Better logger name if called from __init__.py 2022-02-13 19:15:10 +02:00
pollfly
77350f6119 Fix link (#104) 2022-01-27 12:15:55 +02:00
Niels ten Boom
bc2c2ebbfd Add connection string functionality for MongoDB access (#102) 2022-01-08 12:07:59 +02:00
allegroai
1502e02a1a Update ES version to 7.16.2 2021-12-22 13:53:34 +02:00
allegroai
d0e2313a24 Update README regarding CVE-2021-45046 2021-12-15 15:51:18 +02:00
allegroai
d8ba1a8ea7 Fix README 2021-12-14 15:52:53 +02:00
allegroai
ca7937fc4e Fix README 2021-12-14 15:50:30 +02:00
allegroai
df89bcceef Update README with a note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31 2021-12-14 15:48:54 +02:00
allegroai
cfccbe05c1 Add precautionary mitigation for Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31 2021-12-14 15:15:11 +02:00
Théo Mathieu
e352a6a1e7 Fix elasticsearch authentication when initializing (#98) 2021-12-05 09:55:06 +02:00
Théo Mathieu
8a3d992aaf Support MongoDB SRV endpoints (#96) 2021-12-02 10:07:33 +02:00
allegroai
c37f3d8d5b Fix set() not supported in ConfigTree()
Add user/pass config support
2021-11-15 18:33:49 +02:00
allegroai
a96870e092 Add admonition in case only username or password were provided 2021-11-15 15:19:07 +02:00
allegroai
6bf1032237 Rename back to docker-compose.yml 2021-11-15 15:13:09 +02:00
Weixiao Huang
3d816c747d Add ES http_auth credentials support (#93)
Also update ES and MongoDB versions and fix nginx configuration bug

Co-authored-by: huangweixiao <huangweixiao@megvii.com>
2021-11-15 15:01:27 +02:00
Jake Henning
3f2b96266b Merge pull request #91 from valeriano-manassero/fix-dockerfile-chmod
Fix chmod for file copy in Dockerfile
2021-10-19 11:35:00 +03:00
Valeriano Manassero
22b16d12eb fix chmod for file copy 2021-10-19 09:06:51 +02:00
allegroai
c55b6f30df Add Dockerfile 2021-10-18 16:52:17 +03:00
allegroai
b7045d3d28 Fix docker-compose escaping 2021-10-18 16:49:51 +03:00
Jake Henning
e31a404885 Remove README mentions of demo server (#90) 2021-10-10 16:32:51 +03:00
Revital
643588b71a edit README mention of demo server 2021-10-10 11:27:44 +03:00
Jake Henning
a64c4d264d Merge pull request #82 from IgorKasianenko/IgorKasianenko-patch-1
Fix typo TRAINS > CLEARML for env variables in README
2021-08-12 11:47:20 +03:00
Igor Kasianenko
567780e188 Fix typo TRAINS > CLEARML for env variables 2021-08-11 16:21:02 +03:00
allegroai
1bc8529d83 Version bump 2021-08-05 16:46:29 +03:00
allegroai
6b480d7e87 Fix file server GET response for gzipped data-files contains Content-Encoding: gz header, causing clients to automatically decompress the file 2021-08-05 16:46:25 +03:00
allegroai
083fd315e9 Fix server error when running with non-migrated v0.16 ElasticSearch data 2021-08-05 16:46:05 +03:00
Jake Henning
ef20e76174 Update README with artifact.io badge 2021-07-27 19:53:41 +03:00
Jake Henning
8c8910808e Merge pull request #80 from pollfly/master
Fix README links
2021-07-27 12:58:45 +03:00
Revital
f6ad379310 link to clear.ml docs in readme, add image 2021-07-27 12:54:41 +03:00
allegroai
c5d6ce3e65 Version bump 2021-07-25 14:40:57 +03:00
allegroai
694dbc31c4 Fix incorrect ES query (merge issue) 2021-07-25 14:40:49 +03:00
allegroai
6488dc54e6 Better handling of stack trace report on 500 error 2021-07-25 14:39:59 +03:00
allegroai
158da9b480 Allow setting status_message in tasks.update
Optimizations and refactoring
2021-07-25 14:35:36 +03:00
allegroai
ec2e071ab7 Fix mongoengine cannot handle field name with leading or trailing "_" when used in fields query within get_all endpoints 2021-07-25 14:34:04 +03:00
allegroai
465e270342 Fix queued task is not dequeued on tasks.stop 2021-07-25 14:32:09 +03:00
allegroai
6705aff56f Allow requesting plots and iter_histograms for all variants 2021-07-25 14:30:38 +03:00
allegroai
9069cfe1da Support querying task events per specific metrics and variants 2021-07-25 14:29:41 +03:00
allegroai
677bb3ba6d Add force parameter to tasks.enqueue 2021-07-25 14:27:46 +03:00
allegroai
cb253cff9e Don't use special characters in secrets 2021-07-25 14:26:49 +03:00
allegroai
39ceb5ac5c Fix pre-populate logic to avoid overriding existing users 2021-07-25 14:26:31 +03:00
allegroai
d4edeaaf1b Add projects.validate_delete 2021-07-25 14:17:29 +03:00
allegroai
56aea1ffb8 Fix filtering on hyperparams (https://github.com/allegroai/clearml/issues/385, https://clearml.slack.com/archives/CTK20V944/p1626600582284700) 2021-07-25 13:55:09 +03:00
allegroai
09ab2af34c Version bump 2021-05-27 17:13:19 +03:00
allegroai
8bb26a6b0b Fix fileserver depends on deprecated flask._compat.fspath and safe_join 2021-05-27 17:13:02 +03:00
allegroai
3f2304549d Move new migrations to 1_0_2 2021-05-27 16:56:47 +03:00
allegroai
ad72a435f1 Clean Task runtime on reset 2021-05-27 16:56:03 +03:00
allegroai
f34332344e Fix Task container raises validation error on null values 2021-05-27 16:55:32 +03:00
allegroai
d324b57dd7 Fix bad error message format 2021-05-27 16:55:00 +03:00
allegroai
2216bfe875 Version bump 2021-05-11 16:12:48 +03:00
allegroai
9beefa7473 Add missing login.logout endpoint 2021-05-11 16:12:27 +03:00
allegroai
8ebc334889 Fix broken config dir backwards compatibility (/opt/trains/config should still be supported) 2021-05-11 16:12:13 +03:00
allegroai
e662c850af Update config file in docs 2021-05-04 11:07:38 +03:00
allegroai
1e5163e530 Upgrade jinja2 version due to CVE-2020-28493 2021-05-03 23:23:06 +03:00
allegroai
1567774765 Version bump 2021-05-03 18:20:32 +03:00
allegroai
babfcbb707 Update migration script 2021-05-03 18:15:43 +03:00
allegroai
027edd86bb Fix actual file path reported in error/success message 2021-05-03 18:14:56 +03:00
allegroai
cc83aadae6 Fix file delete (bad merge) 2021-05-03 18:14:30 +03:00
allegroai
8c18660a82 Fix inconsistency in accessing files between download and delete 2021-05-03 18:14:08 +03:00
allegroai
4fe61ee25c Fix running migration scripts calling other files 2021-05-03 18:13:49 +03:00
allegroai
e18b21639c Fix regex query for fields containing "_" 2021-05-03 18:13:00 +03:00
allegroai
1cef03b8c2 Add check_contents flag for projects.get_all_ex 2021-05-03 18:12:44 +03:00
allegroai
d60d6dfe99 Move to clearml in docker-compose files 2021-05-03 18:12:21 +03:00
allegroai
27d086bca2 Fix schema for Task.runtime
Add infrastructure for API calls limits handling
2021-05-03 18:11:46 +03:00
allegroai
add3f011a0 Add runtime to tasks.edit 2021-05-03 18:10:48 +03:00
allegroai
ee90b0b024 Remove "Auto-generated while cloning" project description 2021-05-03 18:10:32 +03:00
allegroai
9bf107866f Fix crash in models publish_many without model task 2021-05-03 18:10:09 +03:00
allegroai
4d2f282950 Add Model.last_update to schema 2021-05-03 18:09:54 +03:00
allegroai
b55fad1b59 Remove "Auto-generated during move" project description 2021-05-03 18:09:31 +03:00
allegroai
ba77ff11e9 Fix missing custom metric values turn up first in sorting 2021-05-03 18:08:39 +03:00
allegroai
b67aa05d6f Return results per task iterations in debug images request 2021-05-03 18:08:14 +03:00
allegroai
6b0c45a861 Fix batch operations results 2021-05-03 18:07:37 +03:00
allegroai
dc9623e964 Fix docker_cmd projection in backwards compatibility
Fix support to clear input/output models and docker_cmd in backwards compatibility mode
Fix schema
2021-05-03 18:06:39 +03:00
allegroai
3d73d60826 Better handling of invalid iterations on add_batch 2021-05-03 18:05:24 +03:00
allegroai
9f0c9c3690 Fix open ranges 2021-05-03 18:05:03 +03:00
allegroai
1a3d3494ce Fix numeric locale 2021-05-03 18:04:45 +03:00
allegroai
b99f620073 Added unarchive APIs 2021-05-03 18:04:17 +03:00
allegroai
e2f265b4bc Unify batch operations 2021-05-03 18:03:54 +03:00
allegroai
251ee57ffd Fix rapidjson dumps does not support ensure_ascii, only Encoder initialization does
Add task enqueue status
2021-05-03 18:03:17 +03:00
allegroai
7e03104f1c Add Model last_update field 2021-05-03 18:02:25 +03:00
allegroai
f1a258208e Disable backwards compatibility for 2.13 clients 2021-05-03 18:01:59 +03:00
allegroai
66cc49313b Fix schema 2021-05-03 18:01:29 +03:00
allegroai
9ae2943f7d Fix crash in tasks.reset 2021-05-03 17:59:44 +03:00
allegroai
54326f707b Add JSON flags support to APICall 2021-05-03 17:58:57 +03:00
allegroai
3a3b57c15f Support mongodb authentication 2021-05-03 17:57:53 +03:00
allegroai
8ea8ad34e6 Remove collecting task output models from Models collection during migration 2021-05-03 17:57:27 +03:00
allegroai
179661a0d4 Rename default input and output models
Better handling of backwards compatibility in task models
Code cleanup
2021-05-03 17:56:50 +03:00
allegroai
3d22ca1888 Escape task.container and task.execution.model_labels fields in DB 2021-05-03 17:56:17 +03:00
allegroai
fdf6798d0c Don't unset Task's execution.queue on dequeue 2021-05-03 17:54:16 +03:00
allegroai
9d9a44b927 Add skip_empty parameter in get_configuration_names 2021-05-03 17:53:56 +03:00
allegroai
dad935e81d Remove webserver project 2021-05-03 17:53:24 +03:00
allegroai
a75534ec34 Add batch operations support 2021-05-03 17:52:54 +03:00
allegroai
eab33de97e Add bcrypt support to fixed user password 2021-05-03 17:52:25 +03:00
allegroai
29de110abb Add support for queue and model metadata 2021-05-03 17:50:25 +03:00
allegroai
2e7f418ee2 Fix Task.container backwards-compatibility
Fix sub-projects
2021-05-03 17:49:48 +03:00
allegroai
dadb996d22 Refactor es_factory to better support override host/port 2021-05-03 17:48:41 +03:00
allegroai
174f692edf Code cleanup 2021-05-03 17:48:24 +03:00
allegroai
f4d5168a20 Add Task.container support 2021-05-03 17:48:01 +03:00
allegroai
5a438e8435 Fix projects.move 2021-05-03 17:47:11 +03:00
allegroai
ce4814dc47 Add field override support in config (using "-" prefix) 2021-05-03 17:46:36 +03:00
allegroai
ef42d0265d Add multi-models support 2021-05-03 17:46:00 +03:00
allegroai
3c5195028e More sub-projects support and fixes 2021-05-03 17:44:54 +03:00
allegroai
0d5174c453 Support iterating over all task metrics in task debug images 2021-05-03 17:43:02 +03:00
allegroai
c034c1a986 Add sub-projects support 2021-05-03 17:42:10 +03:00
allegroai
1b49da8748 Revoke tests account in fixed mode, cleanup 2021-05-03 17:40:41 +03:00
allegroai
26bda01a28 Add missing errors 2021-05-03 17:39:49 +03:00
allegroai
f5008d80ad Optimize and improve tasks/models/projects.delete 2021-05-03 17:39:13 +03:00
allegroai
8b464e7ae6 Return file urls for tasks.delete/reset and models.delete 2021-05-03 17:38:09 +03:00
allegroai
78e4a58c91 Fix API enum fields and add last_iteration to range queries 2021-05-03 17:37:49 +03:00
allegroai
7a4a5eb03e Fix dropping index by name during the migration fails if the index does not exist 2021-05-03 17:36:49 +03:00
allegroai
d029d56508 Support active users in projects 2021-05-03 17:36:04 +03:00
allegroai
6411954002 Improve visibility for distributed lock hanging 2021-05-03 17:35:17 +03:00
allegroai
7f4ad0d1ca Support projects.get_hyperparam_values 2021-05-03 17:34:40 +03:00
allegroai
4cd4b2914d Add range queries
Switch from sematic_version to packaging.version in db migrations
2021-05-03 17:33:47 +03:00
allegroai
1d55710a0b Update max API version 2021-05-03 17:33:12 +03:00
allegroai
8f646043bb Allow enqueueing stopped tasks
More clearml stuff
2021-05-03 17:31:02 +03:00
allegroai
4b11a6efcd Move apiserver to clearml 2021-05-03 17:26:44 +03:00
allegroai
cb3a7c90a8 Move fileserver to clearml 2021-05-03 17:00:38 +03:00
allegroai
074842a122 Improve fileserver delete code 2021-05-03 16:58:11 +03:00
allegroai
749ff4a44f Fix Tasks.reset does not mark children's parent as deleted 2021-05-03 16:57:06 +03:00
allegroai
7d6918ecb0 Fix large plots comparison 2021-05-03 16:55:59 +03:00
allegroai
47184c2833 Fix querying by task parent 2021-05-03 16:55:03 +03:00
allegroai
6434f1028e Update docker-compose files 2021-01-14 12:37:25 +02:00
allegroai
daade08940 Update docker-compose-win10.yml
Remove deprecated docker-compose-unified.yml
2021-01-07 00:21:24 +02:00
Allegro AI
a1d289822f Update docker-compose-unified.yml
Reduce ES watermark
2021-01-06 17:46:09 +02:00
Allegro AI
1ce34f2c74 Update docker-compose-win10.yml
Reduce ES watermark
2021-01-06 17:45:27 +02:00
Allegro AI
c2dc73a71f Update docker-compose.yml
Reduce ES watermark
2021-01-06 17:44:45 +02:00
allegroai
07bb3b5df8 Update README 2021-01-06 00:32:52 +02:00
allegroai
067ef82576 Update README 2021-01-05 22:56:43 +02:00
allegroai
59fc98e0c4 Upgrade Jinja2 version (vulnerability found in older versions) 2021-01-05 20:18:09 +02:00
269 changed files with 30280 additions and 8848 deletions

1
.gitignore vendored
View File

@@ -12,7 +12,6 @@ test-reports
.pytest_cache
venv
*.noseids
build
*.egg-info
.cache
.mypy_cache

View File

@@ -1,7 +1,7 @@
Server Side Public License
VERSION 1, OCTOBER 16, 2018
Copyright © 2019 allegro.ai, Inc.
Copyright © 2024 ClearML Inc.
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.

View File

@@ -6,30 +6,45 @@
</br>Experiment Manager, ML-Ops and Data-Management**
[![GitHub license](https://img.shields.io/badge/license-SSPL-green.svg)](https://img.shields.io/badge/license-SSPL-green.svg)
[![Python versions](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
[![Python versions](https://img.shields.io/badge/python-3.9-blue.svg)](https://img.shields.io/badge/python-3.9-blue.svg)
[![GitHub version](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/allegroai)](https://artifacthub.io/packages/search?repo=allegroai)
</div>
---
<div align="center">
**v0.16 Upgrade Notice**
**Note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31**
</div>
In v0.16, the Elasticsearch subsystem of ClearML Server has been upgraded from version 5.6 to version 7.6. This change necessitates the migration of the database contents to accommodate the change in index structure across the different versions.
According to [ElasticSearch's latest report](https://discuss.elastic.co/t/apache-log4j2-remote-code-execution-rce-vulnerability-cve-2021-44228-esa-2021-31/291476),
supported versions of Elasticsearch (6.8.9+, 7.8+) used with recent versions of the JDK (JDK9+) **are not susceptible to either remote code execution or information leakage**
due to Elasticsearchs usage of the Java Security Manager.
Follow [this procedure](https://allegro.ai/docs/deploying_trains/trains_server_es7_migration/) to migrate existing data.
**As the latest version of ClearML Server uses Elasticsearch 7.10+ with JDK15, it is not affected by these vulnerabilities.**
As a precaution, we've upgraded the ES version to 7.16.2 and added the mitigation recommended by ElasticSearch to our latest [docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
While previous Elasticsearch versions (5.6.11+, 6.4.0+ and 7.0.0+) used by older ClearML Server versions are only susceptible to the information leakage vulnerability
(which in any case **does not permit access to data within the Elasticsearch cluster**),
we still recommend upgrading to the latest version of ClearML Server. Alternatively, you can apply the mitigation as implemented in our latest
[docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
**Update 15 December**: A further vulnerability (CVE-2021-45046) was disclosed on December 14th.
ElasticSearch's guidance for Elasticsearch remains unchanged by this new vulnerability, thus **not affecting ClearML Server**.
**Update 22 December**: To keep with ElasticSearch's recommendations, we've upgraded the ES version to the newly released 7.16.2
---
### ClearML Server
## ClearML Server
#### *Formerly known as Trains Server*
The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/allegroai/clearml).
It allows multiple users to collaborate and manage their experiments.
By default, **ClearML** is set up to work with the **ClearML** demo server, which is open to anyone and resets periodically.
**ClearML** offers a [free hosted service](https://app.clear.ml/), which is maintained by **ClearML** and open to anyone.
In order to host your own server, you will need to launch the **ClearML Server** and point **ClearML** to it.
The **ClearML Server** contains the following components:
@@ -45,7 +60,7 @@ You can quickly [deploy](#launching-the-clearml-server) your **ClearML Server**
## System design
![Alt Text](https://allegro.ai/clearml/docs/_images/ClearML_Server_Diagram.png)
![Alt Text](docs/ClearML_Server_Diagram.png)
The **ClearML Server** has two supported configurations:
- Single IP (domain) with the following open ports
@@ -78,20 +93,19 @@ For example, to see if port `8080` is in use:
Launch The **ClearML Server** in any of the following formats:
- Pre-built [AWS EC2 AMI](https://allegro.ai/docs/deploying_trains/trains_server_aws_ec2_ami/)
- Pre-built [GCP Custom Image](https://allegro.ai/docs/deploying_trains/trains_server_gcp/)
- Pre-built [AWS EC2 AMI](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_aws_ec2_ami)
- Pre-built [GCP Custom Image](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_gcp)
- Pre-built Docker Image
- [Linux](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
- [macOS](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
- [Windows 10](https://allegro.ai/docs/deploying_trains/trains_server_win/)
- [Linux](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
- [macOS](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
- [Windows 10](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_win)
- Kubernetes
- [Kubernetes Helm](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes_helm/)
- Manual [Kubernetes installation](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes/)
- [Kubernetes Helm](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes_helm)
- Manual [Kubernetes installation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes)
## Connecting ClearML to your ClearML Server
By default, the **ClearML** client is set up to work with the [**ClearML** demo server](https://demoapp.demo.clear.ml/).
To have the **ClearML** client use your **ClearML Server** instead:
In order to set up the **ClearML** client to work with your **ClearML Server**:
- Run the `clearml-init` command for an interactive setup.
- Or manually edit `~/clearml.conf` file, making sure the server settings (`api_server`, `web_server`, `file_server`) are configured correctly, for example:
@@ -138,8 +152,8 @@ Do not enqueue training / inference tasks into the `services` queue, as it will
The **ClearML Server** provides a few additional useful features, which can be manually enabled:
* [Web login authentication](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#web-login-authentication)
* [Non-responsive experiments watchdog](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#task_watchdog)
* [Web login authentication](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#web-login-authentication)
* [Non-responsive experiments watchdog](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#non-responsive-task-watchdog)
## Restarting ClearML Server
@@ -189,14 +203,14 @@ To upgrade your existing **ClearML Server** deployment:
```
1. Configure the ClearML-Agent Services (not supported on Windows installation).
If `TRAINS_HOST_IP` is not provided, ClearML-Agent Services will use the external
public address of the **ClearML Server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
If `CLEARML_HOST_IP` is not provided, ClearML-Agent Services will use the external
public address of the **ClearML Server**. If `CLEARML_AGENT_GIT_USER` / `CLEARML_AGENT_GIT_PASS` are not provided,
the ClearML-Agent Services will not be able to access any private repositories for running service tasks.
```bash
export TRAINS_HOST_IP=server_host_ip_here
export TRAINS_AGENT_GIT_USER=git_username_here
export TRAINS_AGENT_GIT_PASS=git_password_here
export CLEARML_HOST_IP=server_host_ip_here
export CLEARML_AGENT_GIT_USER=git_username_here
export CLEARML_AGENT_GIT_PASS=git_password_here
```
1. Spin up the docker containers, it will automatically pull the latest **ClearML Server** build
@@ -205,15 +219,15 @@ To upgrade your existing **ClearML Server** deployment:
docker-compose -f docker-compose.yml up
```
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/clearml/docs/docs/faq/faq.html).**
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://clear.ml/docs/latest/docs/faq/).**
## Community & Support
If you have any questions, look to the ClearML [FAQ](https://allegro.ai/clearml/docs/docs/faq/faq.html), or
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/trains) with '**trains**' tag.
If you have any questions, look to the ClearML [FAQ](https://clear.ml/docs/latest/docs/faq), or
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/clearml) with '**clearml**' tag.
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/trains-server/issues).
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/clearml-server/issues).
Additionally, you can always find us at *clearml@allegro.ai*

View File

@@ -1,7 +1,7 @@
Server Side Public License
VERSION 1, OCTOBER 16, 2018
Copyright © 2019 allegro.ai, Inc.
Copyright © 2024 ClearML Inc.
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.

View File

@@ -1,3 +1,8 @@
301 {
_: "moved_permanently"
1: ["not_supported", "this endpoint is no longer supported for the requested API version"]
}
400 {
_: "bad_request"
1: ["not_supported", "endpoint is not supported"]
@@ -21,6 +26,9 @@
23: ["invalid_domain_name", "malformed domain name"]
24: ["not_public_object", "object is not public"]
# Auth / Login
75: ["invalid_access_key", "access key not found"]
# Tasks
100: ["task_error", "general task error"]
101: ["invalid_task_id", "invalid task id"]
@@ -42,6 +50,12 @@
130: ["task_not_found", "task not found"]
131: ["events_not_added", "events not added"]
# Reports
150: ["operation_supported_on_reports_only", "passed task is not report"]
# Pipelines
160: ["cannot_remove_all_runs", "at least one pipeline run should be left"]
# Models
200: ["model_error", "general task error"]
201: ["invalid_model_id", "invalid model id"]
@@ -62,6 +76,15 @@
402: ["project_has_tasks", "project has associated tasks"]
403: ["project_not_found", "project not found"]
405: ["project_has_models", "project has associated models"]
406: ["project_has_datasets", "project has associated non-empty datasets"]
407: ["invalid_project_name", "invalid project name"]
408: ["cannot_update_project_location", "Cannot update project location. Use projects.move instead"]
409: ["project_path_exceeds_max", "Project path exceed the maximum allowed depth"]
410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"]
411: ["project_cannot_be_moved_under_itself", "Project can not be moved under itself in the projects hierarchy"]
412: ["project_cannot_be_merged_into_its_child", "Project can not be merged into its own child"]
413: ["project_has_pipelines", "project has associated pipelines with active controllers"]
414: ["public_project_exists", "Cannot create project. Public project with the same name already exists"]
# Queues
701: ["invalid_queue_id", "invalid queue id"]
@@ -75,7 +98,7 @@
# Database
800: ["data_validation_error", "data validation error"]
801: ["expected_unique_data", "value combination already exists"]
801: ["expected_unique_data", "value combination already exists (unique field already contains this value)"]
# Workers
1001: ["invalid_worker_id", "invalid worker id"]
@@ -84,6 +107,11 @@
1004: ["worker_not_registered", "worker is not registered"]
1005: ["worker_stats_not_found", "worker stats not found"]
# Serving
1050: ["invalid_container_id", "invalid container id"]
1051: ["container_not_registered", "container is not registered"]
1052: ["no_containers_for_url", "no container instances found for serice url"]
1104: ["invalid_scroll_id", "Invalid scroll id"]
}
@@ -108,6 +136,11 @@
21: ["no_write_permission", "forbidden (modification not allowed)"]
}
410: {
_: "gone"
1: ["not_supported", "thus endpoint is not supported any more"]
}
500 {
_: "server_error"
0: ["general_error", "general server error"]

View File

@@ -1,10 +1,11 @@
from enum import Enum
from typing import Union, Type, Iterable
from numbers import Number
from typing import Union, Type, Iterable, Mapping
import jsonmodels.errors
import six
from jsonmodels import fields
from jsonmodels.fields import _LazyType, NotSet
from jsonmodels.fields import _LazyType, NotSet, EmbeddedField
from jsonmodels.models import Base as ModelBase
from jsonmodels.validators import Enum as EnumValidator
from mongoengine.base import BaseDocument
@@ -40,6 +41,34 @@ def make_default(field_cls, default_value):
return _FieldWithDefault
class OneOfEmbeddedField(EmbeddedField):
def __init__(
self,
*args,
discriminator_property: str,
discriminator_mapping: Mapping[str, type],
**kwargs,
):
self.discriminator_property = discriminator_property
self.discriminator_mapping = discriminator_mapping
model_types = tuple(set(self.discriminator_mapping.values()))
super().__init__(model_types, *args, **kwargs)
def parse_value(self, value):
"""Parse value to proper model type."""
if not isinstance(value, dict) or self.discriminator_property not in value:
return super().parse_value(value)
property_value = value.get(self.discriminator_property)
embed_type = self.discriminator_mapping.get(property_value)
if not embed_type:
raise jsonmodels.errors.ValidationError(
f"Could not find type matching discriminator property value: {property_value}"
)
return embed_type(**value)
class ListField(fields.ListField):
def __init__(self, items_types=None, *args, default=NotSet, **kwargs):
if default is not NotSet and callable(default):
@@ -61,10 +90,20 @@ class ListField(fields.ListField):
item.validate()
# since there is no distinction between None and empty DictField
# this value can be used as sentinel in order to distinguish
# between not set and empty DictField
DictFieldNotSet = {}
class ScalarField(fields.BaseField):
"""String field."""
types = (str, int, float, bool)
class SafeStringField(fields.StringField):
"""String field that can also accept numbers as input"""
def parse_value(self, value):
if isinstance(value, Number):
value = str(value)
return super().parse_value(value)
class DictField(fields.BaseField):
@@ -114,9 +153,7 @@ class DictField(fields.BaseField):
if len(self.value_types) != 1:
tpl = 'Cannot decide which type to choose from "{types}".'
raise jsonmodels.errors.ValidationError(
tpl.format(
types=', '.join([t.__name__ for t in self.value_types])
)
tpl.format(types=", ".join([t.__name__ for t in self.value_types]))
)
return self.value_types[0](**value)
@@ -178,7 +215,7 @@ class EnumField(fields.StringField):
*args,
required=False,
default=None,
**kwargs
**kwargs,
):
choices = list(map(self.parse_value, values_or_type))
validator_cls = EnumValidator if required else NullableEnumValidator
@@ -201,7 +238,7 @@ class ActualEnumField(fields.StringField):
validators=None,
required=False,
default=None,
**kwargs
**kwargs,
):
self.__enum = enum_class
self.types = (enum_class,)
@@ -214,11 +251,11 @@ class ActualEnumField(fields.StringField):
*args,
required=required,
validators=validators,
**kwargs
**kwargs,
)
def parse_value(self, value):
if value is None and not self.required:
if value is NotSet and not self.required:
return self.get_default_value()
try:
# noinspection PyArgumentList

View File

@@ -75,11 +75,17 @@ class CreateUserResponse(Base):
class Credentials(Base):
access_key = StringField(required=True)
secret_key = StringField(required=True)
label = StringField()
class CredentialsResponse(Credentials):
secret_key = StringField()
last_used = DateTimeField(default=None)
last_used_from = StringField()
class CreateCredentialsRequest(Base):
label = StringField()
class CreateCredentialsResponse(Base):
@@ -90,6 +96,11 @@ class GetCredentialsResponse(Base):
credentials = ListField(CredentialsResponse)
class EditCredentialsRequest(Base):
access_key = StringField(required=True)
label = StringField()
class RevokeCredentialsRequest(Base):
access_key = StringField(required=True)

View File

@@ -0,0 +1,25 @@
from typing import Sequence
from jsonmodels.fields import StringField
from jsonmodels.models import Base
from jsonmodels.validators import Length
from apiserver.apimodels import ListField
from apiserver.apimodels.base import UpdateResponse
class BatchRequest(Base):
ids: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
class BatchResponse(Base):
succeeded: Sequence[dict] = ListField([dict])
failed: Sequence[dict] = ListField([dict])
class UpdateBatchItem(UpdateResponse):
id: str = StringField()
class UpdateBatchResponse(BatchResponse):
succeeded: Sequence[UpdateBatchItem] = ListField(UpdateBatchItem)

View File

@@ -2,7 +2,7 @@ from enum import auto
from typing import Sequence, Optional
from jsonmodels import validators
from jsonmodels.fields import StringField, BoolField
from jsonmodels.fields import StringField, BoolField, EmbeddedField
from jsonmodels.models import Base
from jsonmodels.validators import Length, Min, Max
@@ -13,13 +13,33 @@ from apiserver.config_repo import config
from apiserver.utilities.stringenum import StringEnum
class TaskRequest(Base):
task: str = StringField(required=True)
class ModelRequest(Base):
model: str = StringField(required=True)
class HistogramRequestBase(Base):
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
class MetricVariants(Base):
metric: str = StringField(required=True)
variants: Sequence[str] = ListField(items_types=str)
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
task: str = StringField(required=True)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
model_events: bool = BoolField(default=False)
class GetMetricsAndVariantsRequest(Base):
task: str = StringField(required=True)
model_events: bool = BoolField(default=False)
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
@@ -29,19 +49,28 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
Length(
minimum_value=1,
maximum_value=config.get(
"services.tasks.multi_task_histogram_limit", 10
"services.tasks.multi_task_histogram_limit", 100
),
)
],
)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
model_events: bool = BoolField(default=False)
class TaskMetric(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
metric: str = StringField(default=None)
variants: Sequence[str] = ListField(items_types=str)
class DebugImagesRequest(Base):
class LegacyMetricEventsRequest(TaskRequest):
iters: int = IntField(default=1, validators=validators.Min(1))
scroll_id: str = StringField()
model_events: bool = BoolField(default=False)
class MetricEventsRequest(Base):
metrics: Sequence[TaskMetric] = ListField(
items_types=TaskMetric, validators=[Length(minimum_value=1)]
)
@@ -49,24 +78,43 @@ class DebugImagesRequest(Base):
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
model_events: bool = BoolField(default=False)
class TaskMetricVariant(Base):
class VectorMetricsIterHistogramRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
variant: str = StringField(required=True)
model_events: bool = BoolField(default=False)
class GetDebugImageSampleRequest(TaskMetricVariant):
class GetVariantSampleRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
variant: str = StringField(required=True)
iteration: Optional[int] = IntField()
scroll_id: Optional[str] = StringField()
refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()
navigate_current_metric: bool = BoolField(default=True)
model_events: bool = BoolField(default=False)
class NextDebugImageSampleRequest(Base):
class GetMetricSamplesRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
iteration: Optional[int] = IntField()
refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()
navigate_current_metric: bool = BoolField(default=True)
model_events: bool = BoolField(default=False)
class NextHistorySampleRequest(Base):
task: str = StringField(required=True)
scroll_id: Optional[str] = StringField()
navigate_earlier: bool = BoolField(default=True)
next_iteration: bool = BoolField(default=False)
model_events: bool = BoolField(default=False)
class LogOrderEnum(StringEnum):
@@ -74,12 +122,40 @@ class LogOrderEnum(StringEnum):
desc = auto()
class LogEventsRequest(Base):
class TaskEventsRequestBase(Base):
task: str = StringField(required=True)
batch_size: int = IntField(default=500)
class TaskEventsRequest(TaskEventsRequestBase):
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
scroll_id: str = StringField()
count_total: bool = BoolField(default=True)
model_events: bool = BoolField(default=False)
class LegacyLogEventsRequest(TaskEventsRequestBase):
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.desc)
scroll_id: str = StringField()
class LogEventsRequest(TaskEventsRequestBase):
batch_size: int = IntField(default=5000)
navigate_earlier: bool = BoolField(default=True)
from_timestamp: Optional[int] = IntField()
order: Optional[str] = ActualEnumField(LogOrderEnum)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
batch_size: int = IntField()
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
count_total: bool = BoolField(default=False)
scroll_id: str = StringField()
model_events: bool = BoolField(default=False)
class IterationEvents(Base):
@@ -89,17 +165,70 @@ class IterationEvents(Base):
class MetricEvents(Base):
task: str = StringField()
metric: str = StringField()
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
class DebugImageResponse(Base):
class MetricEventsResponse(Base):
metrics: Sequence[MetricEvents] = ListField(items_types=MetricEvents)
scroll_id: str = StringField()
class TaskMetricsRequest(Base):
class MultiTasksRequestBase(Base):
tasks: Sequence[str] = ListField(
items_types=str, validators=[Length(minimum_value=1)]
)
model_events: bool = BoolField(default=False)
class SingleValueMetricsRequest(MultiTasksRequestBase):
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class TaskMetricsRequest(MultiTasksRequestBase):
event_type: EventType = ActualEnumField(EventType, required=True)
class MultiTaskMetricsRequest(MultiTasksRequestBase):
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
class LegacyMultiTaskEventsRequest(MultiTasksRequestBase):
iters: int = IntField(default=1, validators=validators.Min(1))
scroll_id: str = StringField()
class MultiTaskPlotsRequest(MultiTasksRequestBase):
iters: int = IntField(default=1)
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
last_iters_per_task_metric: bool = BoolField(default=True)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class TaskPlotsRequest(Base):
task: str = StringField(required=True)
iters: int = IntField(default=1)
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
model_events: bool = BoolField(default=False)
class GetScalarMetricDataRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
model_events: bool = BoolField(default=False)
class ClearScrollRequest(Base):
scroll_id: str = StringField()
class ClearTaskLogRequest(Base):
task: str = StringField(required=True)
threshold_sec = IntField()
allow_locked = BoolField(default=False)
exclude_metrics = ListField(items_types=[str])
include_metrics = ListField(items_types=[str])

View File

@@ -5,8 +5,9 @@ from apiserver.apimodels import DictField, callable_default
class GetSupportedModesRequest(Base):
state = StringField(help_text="ASCII base64 encoded application state")
callback_url_prefix = StringField()
pass
# state = StringField(help_text="ASCII base64 encoded application state")
# callback_url_prefix = StringField()
class BasicGuestMode(Base):
@@ -31,3 +32,4 @@ class GetSupportedModesResponse(Base):
server_errors = EmbeddedField(ServerErrors)
sso = DictField([str, type(None)])
sso_providers = ListField([dict])
authenticated = BoolField(default=False)

View File

@@ -0,0 +1,24 @@
from typing import Sequence
from jsonmodels import validators
from jsonmodels.fields import StringField, BoolField
from jsonmodels.models import Base
from apiserver.apimodels import ListField
class MetadataItem(Base):
key = StringField(required=True)
type = StringField(required=True)
value = StringField(required=True)
class DeleteMetadata(Base):
keys: Sequence[str] = ListField(str, validators=validators.Length(minimum_value=1))
class AddOrUpdateMetadata(Base):
metadata: Sequence[MetadataItem] = ListField(
[MetadataItem], validators=validators.Length(minimum_value=1)
)
replace_metadata = BoolField(default=False)

View File

@@ -3,7 +3,12 @@ from six import string_types
from apiserver.apimodels import ListField, DictField
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.tasks import PublishResponse as TaskPublishResponse
from apiserver.apimodels.batch import BatchRequest
from apiserver.apimodels.metadata import (
MetadataItem,
DeleteMetadata,
AddOrUpdateMetadata,
)
class GetFrameworksRequest(models.Base):
@@ -13,7 +18,7 @@ class GetFrameworksRequest(models.Base):
class CreateModelRequest(models.Base):
name = fields.StringField(required=True)
uri = fields.StringField(required=True)
labels = DictField(value_types=string_types+(int,))
labels = DictField(value_types=string_types + (int,))
tags = ListField(items_types=string_types)
system_tags = ListField(items_types=string_types)
comment = fields.StringField()
@@ -25,6 +30,7 @@ class CreateModelRequest(models.Base):
ready = fields.BoolField(default=True)
ui_cache = DictField()
task = fields.StringField()
metadata = DictField(value_types=[MetadataItem])
class CreateModelResponse(models.Base):
@@ -32,17 +38,62 @@ class CreateModelResponse(models.Base):
created = fields.BoolField(required=True)
class PublishModelRequest(models.Base):
class ModelRequest(models.Base):
model = fields.StringField(required=True)
class TaskRequest(models.Base):
task = fields.StringField(required=True)
class UpdateForTaskRequest(TaskRequest):
uri = fields.StringField()
iteration = fields.IntField()
override_model_id = fields.StringField()
class UpdateModelRequest(ModelRequest):
task = fields.StringField()
iteration = fields.IntField()
class DeleteModelRequest(ModelRequest):
force = fields.BoolField(default=False)
delete_external_artifacts = fields.BoolField(default=True)
class ModelsDeleteManyRequest(BatchRequest):
force = fields.BoolField(default=False)
delete_external_artifacts = fields.BoolField(default=True)
class PublishModelRequest(ModelRequest):
force_publish_task = fields.BoolField(default=False)
publish_task = fields.BoolField(default=True)
class ModelTaskPublishResponse(models.Base):
id = fields.StringField(required=True)
data = fields.EmbeddedField(TaskPublishResponse)
data = fields.EmbeddedField(UpdateResponse)
class PublishModelResponse(UpdateResponse):
published_task = fields.EmbeddedField(ModelTaskPublishResponse)
updated = fields.IntField()
class ModelsPublishManyRequest(BatchRequest):
force_publish_task = fields.BoolField(default=False)
publish_task = fields.BoolField(default=True)
class DeleteMetadataRequest(DeleteMetadata):
model = fields.StringField(required=True)
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
model = fields.StringField(required=True)
class ModelsGetRequest(models.Base):
include_stats = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)

View File

@@ -1,4 +1,11 @@
from enum import auto
from typing import Sequence
from jsonmodels import fields, models
from jsonmodels.validators import Length
from apiserver.apimodels import DictField, ActualEnumField, ScalarField
from apiserver.utilities.stringenum import StringEnum
class Filter(models.Base):
@@ -9,3 +16,47 @@ class Filter(models.Base):
class TagsRequest(models.Base):
include_system = fields.BoolField(default=False)
filter = fields.EmbeddedField(Filter)
class EntitiesCountRequest(models.Base):
projects = DictField()
tasks = DictField()
models = DictField()
pipelines = DictField()
datasets = DictField()
reports = DictField()
active_users = fields.ListField(str)
search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)
class EntityType(StringEnum):
task = auto()
model = auto()
class ValueMapping(models.Base):
key = ScalarField(nullable=True)
value = ScalarField(nullable=True)
class FieldMapping(models.Base):
field = fields.StringField(required=True)
name = fields.StringField()
values: Sequence[ValueMapping] = fields.ListField(items_types=[ValueMapping])
class PrepareDownloadForGetAllRequest(models.Base):
entity_type = ActualEnumField(EntityType)
allow_public = fields.BoolField(default=True)
search_hidden = fields.BoolField(default=False)
only_fields = fields.ListField(
items_types=[str], validators=[Length(1)], required=True
)
field_mappings: Sequence[FieldMapping] = fields.ListField(
items_types=[FieldMapping], validators=[Length(1)], required=True
)
class DownloadForGetAllRequest(models.Base):
prepare_id = fields.StringField(required=True)

View File

@@ -0,0 +1,21 @@
from jsonmodels import models, fields
from jsonmodels.validators import Length
from apiserver.apimodels import ListField
class Arg(models.Base):
name = fields.StringField(required=True)
value = fields.StringField(required=True)
class DeleteRunsRequest(models.Base):
project = fields.StringField(required=True)
ids = ListField([str], required=True, validators=[Length(1)])
class StartPipelineRequest(models.Base):
task = fields.StringField(required=True)
queue = fields.StringField(required=True)
args = ListField(Arg)
verify_watched_queue = fields.BoolField(default=False)

View File

@@ -1,15 +1,42 @@
from enum import Enum, auto
from jsonmodels import models, fields
from apiserver.apimodels import ListField, ActualEnumField
from apiserver.apimodels import ListField, ActualEnumField, DictField
from apiserver.apimodels.organization import TagsRequest
from apiserver.database.model import EntityVisibility
from apiserver.utilities.stringenum import StringEnum
class ProjectReq(models.Base):
class ProjectRequest(models.Base):
project = fields.StringField(required=True)
class MergeRequest(ProjectRequest):
destination_project = fields.StringField()
class MoveRequest(ProjectRequest):
new_location = fields.StringField()
class DeleteRequest(ProjectRequest):
force = fields.BoolField(default=False)
delete_contents = fields.BoolField(default=False)
delete_external_artifacts = fields.BoolField(default=True)
class ProjectOrNoneRequest(models.Base):
project = fields.StringField()
include_subprojects = fields.BoolField(default=True)
class GetHyperParamReq(ProjectReq):
class GetUniqueMetricsRequest(ProjectOrNoneRequest):
model_metrics = fields.BoolField(default=False)
ids = fields.ListField(str)
class GetParamsRequest(ProjectOrNoneRequest):
page = fields.IntField(default=0)
page_size = fields.IntField(default=500)
@@ -18,7 +45,59 @@ class ProjectTagsRequest(TagsRequest):
projects = ListField(str)
class ProjectTaskParentsRequest(ProjectReq):
projects = ListField(str)
tasks_state = ActualEnumField(EntityVisibility)
class MultiProjectRequest(models.Base):
projects = fields.ListField(items_types=[str, type(None)])
include_subprojects = fields.BoolField(default=True)
class ProjectTaskParentsRequest(MultiProjectRequest):
tasks_state = ActualEnumField(EntityVisibility)
task_name = fields.StringField()
class EntityTypeEnum(StringEnum):
task = auto()
model = auto()
class ProjectUserNamesRequest(MultiProjectRequest):
entity = ActualEnumField(EntityTypeEnum, default=EntityTypeEnum.task)
class MultiProjectPagedRequest(MultiProjectRequest):
allow_public = fields.BoolField(default=True)
page = fields.IntField(default=0)
page_size = fields.IntField(default=500)
class ProjectHyperparamValuesRequest(MultiProjectPagedRequest):
section = fields.StringField(required=True)
name = fields.StringField(required=True)
pattern = fields.StringField()
class ProjectModelMetadataValuesRequest(MultiProjectPagedRequest):
key = fields.StringField(required=True)
class ProjectChildrenType(Enum):
pipeline = "pipeline"
report = "report"
dataset = "dataset"
class ProjectsGetRequest(models.Base):
include_dataset_stats = fields.BoolField(default=False)
include_stats = fields.BoolField(default=False)
include_stats_filter = DictField()
stats_with_children = fields.BoolField(default=True)
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
non_public = fields.BoolField(default=False) # legacy, use allow_public instead
active_users = fields.ListField(str)
check_own_contents = fields.BoolField(default=False)
shallow_search = fields.BoolField(default=False)
search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)
children_type = ActualEnumField(ProjectChildrenType)
children_tags = fields.ListField(str)
children_tags_filter = DictField()

View File

@@ -2,7 +2,12 @@ from jsonmodels import validators
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
from jsonmodels.models import Base
from apiserver.apimodels import ListField
from apiserver.apimodels import ListField, DictField
from apiserver.apimodels.metadata import (
MetadataItem,
DeleteMetadata,
AddOrUpdateMetadata,
)
class GetDefaultResp(Base):
@@ -14,12 +19,28 @@ class CreateRequest(Base):
name = StringField(required=True)
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = DictField(value_types=[MetadataItem])
class QueueRequest(Base):
queue = StringField(required=True)
class GetByIdRequest(QueueRequest):
max_task_entries = IntField()
class GetAllRequest(Base):
max_task_entries = IntField()
search_hidden = BoolField(default=False)
class GetNextTaskRequest(QueueRequest):
queue = StringField(required=True)
get_task_info = BoolField(default=False)
task = StringField()
class DeleteRequest(QueueRequest):
force = BoolField(default=False)
@@ -28,12 +49,21 @@ class UpdateRequest(QueueRequest):
name = StringField()
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = DictField(value_types=[MetadataItem])
class TaskRequest(QueueRequest):
task = StringField(required=True)
class RemoveTaskRequest(TaskRequest):
update_task_status = BoolField(default=False)
class AddTaskRequest(TaskRequest):
update_execution_queue = BoolField(default=True)
class MoveTaskRequest(TaskRequest):
count = IntField(default=1)
@@ -47,6 +77,7 @@ class GetMetricsRequest(Base):
from_date = FloatField(required=True, validators=validators.Min(0))
to_date = FloatField(required=True, validators=validators.Min(0))
interval = IntField(required=True, validators=validators.Min(1))
refresh = BoolField(default=False)
class QueueMetrics(Base):
@@ -58,3 +89,11 @@ class QueueMetrics(Base):
class GetMetricsResponse(Base):
queues = ListField(QueueMetrics)
class DeleteMetadataRequest(DeleteMetadata):
queue = StringField(required=True)
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
queue = StringField(required=True)

View File

@@ -0,0 +1,84 @@
from typing import Sequence
from jsonmodels import validators
from jsonmodels.fields import StringField, ListField, BoolField, EmbeddedField, IntField
from jsonmodels.models import Base
from jsonmodels.validators import Length
from apiserver.apimodels.events import MetricVariants, HistogramRequestBase
class UpdateReportRequest(Base):
task = StringField(required=True)
name = StringField(nullable=True, validators=Length(minimum_value=3))
tags = ListField(items_types=[str])
comment = StringField()
report = StringField()
report_assets = ListField(items_types=[str])
class CreateReportRequest(Base):
name = StringField(required=True, validators=Length(minimum_value=3))
tags = ListField(items_types=[str])
comment = StringField()
report = StringField()
project = StringField()
report_assets = ListField(items_types=[str])
class PublishReportRequest(Base):
task = StringField(required=True)
message = StringField(default="")
class ArchiveReportRequest(Base):
task = StringField(required=True)
message = StringField(default="")
class ShareReportRequest(Base):
task = StringField(required=True)
share = BoolField(default=True)
class DeleteReportRequest(Base):
task = StringField(required=True)
force = BoolField(default=False)
class MoveReportRequest(Base):
task = StringField(required=True)
project = StringField()
project_name = StringField()
class EventsRequest(Base):
iters = IntField(default=1, validators=validators.Min(1))
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class PlotEventsRequest(EventsRequest):
last_iters_per_task_metric: bool = BoolField(default=True)
class ScalarMetricsIterHistogram(HistogramRequestBase):
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class SingleValueMetrics(Base):
pass
class GetTasksDataRequest(Base):
debug_images: EventsRequest = EmbeddedField(EventsRequest)
plots: PlotEventsRequest = EmbeddedField(PlotEventsRequest)
scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(
ScalarMetricsIterHistogram
)
single_value_metrics: SingleValueMetrics = EmbeddedField(SingleValueMetrics)
allow_public = BoolField(default=True)
model_events: bool = BoolField(default=False)
class GetAllRequest(Base):
allow_public = BoolField(default=True)

View File

@@ -6,6 +6,10 @@ class ReportStatsOptionRequest(Base):
enabled = BoolField(default=None, nullable=True)
class GetConfigRequest(Base):
path = StringField()
class ReportStatsOptionResponse(Base):
supported = BoolField(default=True)
enabled = BoolField()

View File

@@ -0,0 +1,104 @@
from enum import Enum
from typing import Sequence
from jsonmodels.models import Base
from jsonmodels.fields import (
StringField,
EmbeddedField,
DateTimeField,
IntField,
FloatField,
BoolField,
)
from jsonmodels import validators
from jsonmodels.validators import Min
from apiserver.apimodels import ListField, JsonSerializableMixin, SafeStringField
from apiserver.apimodels import ActualEnumField
from apiserver.config_repo import config
from .workers import MachineStats
class ReferenceItem(Base):
type = StringField(
required=True,
validators=validators.Enum("app_id", "app_instance", "model", "task", "url"),
)
value = StringField(required=True)
class ServingModel(Base):
container_id = StringField(required=True)
endpoint_name = StringField(required=True)
endpoint_url = StringField() # can be not existing yet at registration time
model_name = StringField(required=True)
model_source = StringField()
model_version = StringField()
preprocess_artifact = StringField()
input_type = StringField()
input_size = SafeStringField()
tags = ListField(str)
system_tags = ListField(str)
reference: Sequence[ReferenceItem] = ListField(ReferenceItem)
class RegisterRequest(ServingModel):
timeout = IntField(
default=int(
config.get("services.serving.default_container_timeout_sec", 10 * 60)
),
validators=[Min(1)],
)
""" registration timeout in seconds (default is 10min) """
class UnregisterRequest(Base):
container_id = StringField(required=True)
class StatusReportRequest(ServingModel):
uptime_sec = IntField()
requests_num = IntField()
requests_min = FloatField()
latency_ms = IntField()
machine_stats: MachineStats = EmbeddedField(MachineStats)
class ServingContainerEntry(StatusReportRequest, JsonSerializableMixin):
key = StringField(required=True)
company_id = StringField(required=True)
ip = StringField()
register_time = DateTimeField(required=True)
register_timeout = IntField(required=True)
last_activity_time = DateTimeField(required=True)
class GetEndpointDetailsRequest(Base):
endpoint_url = StringField(required=True)
class MetricType(Enum):
requests = "requests"
requests_min = "requests_min"
latency_ms = "latency_ms"
cpu_count = "cpu_count"
gpu_count = "gpu_count"
cpu_util = "cpu_util"
gpu_util = "gpu_util"
ram_total = "ram_total"
ram_used = "ram_used"
ram_free = "ram_free"
gpu_ram_total = "gpu_ram_total"
gpu_ram_used = "gpu_ram_used"
gpu_ram_free = "gpu_ram_free"
network_rx = "network_rx"
network_tx = "network_tx"
class GetEndpointMetricsHistoryRequest(Base):
from_date = FloatField(required=True, validators=Min(0))
to_date = FloatField(required=True, validators=Min(0))
interval = IntField(required=True, validators=Min(1))
endpoint_url = StringField(required=True)
metric_type = ActualEnumField(MetricType, default=MetricType.requests)
instance_charts = BoolField(default=True)

View File

@@ -0,0 +1,60 @@
from jsonmodels.fields import StringField, BoolField, ListField, EmbeddedField
from jsonmodels.models import Base
from jsonmodels.validators import Enum
class AWSBucketSettings(Base):
bucket = StringField()
subdir = StringField()
host = StringField()
key = StringField()
secret = StringField()
token = StringField()
multipart = BoolField(default=True)
acl = StringField()
secure = BoolField(default=True)
region = StringField()
verify = BoolField(default=True)
use_credentials_chain = BoolField(default=False)
class AWSSettings(Base):
key = StringField()
secret = StringField()
region = StringField()
token = StringField()
use_credentials_chain = BoolField(default=False)
buckets = ListField(items_types=[AWSBucketSettings])
class GoogleBucketSettings(Base):
bucket = StringField()
subdir = StringField()
project = StringField()
credentials_json = StringField()
class GoogleSettings(Base):
project = StringField()
credentials_json = StringField()
buckets = ListField(items_types=[GoogleBucketSettings])
class AzureContainerSettings(Base):
account_name = StringField()
account_key = StringField()
container_name = StringField()
class AzureSettings(Base):
containers = ListField(items_types=[AzureContainerSettings])
class SetSettingsRequest(Base):
aws = EmbeddedField(AWSSettings)
google = EmbeddedField(GoogleSettings)
azure = EmbeddedField(AzureSettings)
class ResetSettingsRequest(Base):
keys = ListField([str], item_validators=[Enum("aws", "google", "azure")])

View File

@@ -1,16 +1,17 @@
from typing import Sequence
import six
from jsonmodels import models
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
from jsonmodels.validators import Enum, Length
from apiserver.apimodels import DictField, ListField
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.batch import BatchRequest, UpdateBatchItem, BatchResponse
from apiserver.database.model.task.task import (
TaskType,
ArtifactModes,
DEFAULT_ARTIFACT_MODE,
TaskModelTypes,
)
from apiserver.database.utils import get_options
@@ -41,50 +42,100 @@ class StartedResponse(UpdateResponse):
class EnqueueResponse(UpdateResponse):
queued = IntField()
queue_watched = BoolField()
class EnqueueBatchItem(UpdateBatchItem):
queued: bool = BoolField()
class EnqueueManyResponse(BatchResponse):
succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem)
queue_watched = BoolField()
class DequeueResponse(UpdateResponse):
dequeued = IntField()
class DequeueBatchItem(UpdateBatchItem):
dequeued: bool = BoolField()
class DequeueManyResponse(BatchResponse):
succeeded: Sequence[DequeueBatchItem] = ListField(DequeueBatchItem)
class ResetResponse(UpdateResponse):
deleted_indices = ListField(items_types=six.string_types)
dequeued = DictField()
frames = DictField()
events = DictField()
model_deleted = IntField()
deleted_models = IntField()
urls = DictField()
class ResetBatchItem(UpdateBatchItem):
dequeued: bool = BoolField()
deleted_models = IntField()
urls = DictField()
class ResetManyResponse(BatchResponse):
succeeded: Sequence[ResetBatchItem] = ListField(ResetBatchItem)
class TaskRequest(models.Base):
task = StringField(required=True)
class UpdateRequest(TaskRequest):
class TaskUpdateRequest(TaskRequest):
force = BoolField(default=False)
class UpdateRequest(TaskUpdateRequest):
status_reason = StringField(default="")
status_message = StringField(default="")
force = BoolField(default=False)
class DequeueRequest(UpdateRequest):
remove_from_all_queues = BoolField(default=False)
new_status = StringField()
class StopRequest(UpdateRequest):
include_pipeline_steps = BoolField(default=False)
class EnqueueRequest(UpdateRequest):
queue = StringField()
queue_name = StringField()
verify_watched_queue = BoolField(default=False)
update_execution_queue = BoolField(default=True)
class DeleteRequest(UpdateRequest):
move_to_trash = BoolField(default=True)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
delete_external_artifacts = BoolField(default=True)
include_pipeline_steps = BoolField(default=False)
class SetRequirementsRequest(TaskRequest):
requirements = DictField(required=True)
class CompletedRequest(UpdateRequest):
publish = BoolField(default=False)
class CompletedResponse(UpdateResponse):
published = IntField(default=0)
class PublishRequest(UpdateRequest):
publish_model = BoolField(default=True)
class PublishResponse(UpdateResponse):
pass
class TaskData(models.Base):
"""
This is a partial description of task can be updated incrementally
@@ -104,6 +155,11 @@ class GetTypesRequest(models.Base):
projects = ListField(items_types=[str])
class TaskInputModel(models.Base):
name = StringField()
model = StringField()
class CloneRequest(TaskRequest):
new_task_name = StringField()
new_task_comment = StringField()
@@ -113,14 +169,15 @@ class CloneRequest(TaskRequest):
new_task_project = StringField()
new_task_hyperparams = DictField()
new_task_configuration = DictField()
new_task_container = DictField()
new_task_input_models = ListField([TaskInputModel])
execution_overrides = DictField()
validate_references = BoolField(default=False)
new_project_name = StringField()
class AddOrUpdateArtifactsRequest(TaskRequest):
class AddOrUpdateArtifactsRequest(TaskUpdateRequest):
artifacts = ListField([Artifact], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ArtifactId(models.Base):
@@ -130,13 +187,15 @@ class ArtifactId(models.Base):
)
class DeleteArtifactsRequest(TaskRequest):
class DeleteArtifactsRequest(TaskUpdateRequest):
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ResetRequest(UpdateRequest):
clear_all = BoolField(default=False)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
delete_external_artifacts = BoolField(default=True)
class MultiTaskRequest(models.Base):
@@ -161,7 +220,7 @@ class ReplaceHyperparams(object):
all = "all"
class EditHyperParamsRequest(TaskRequest):
class EditHyperParamsRequest(TaskUpdateRequest):
hyperparams: Sequence[HyperParamItem] = ListField(
[HyperParamItem], validators=Length(minimum_value=1)
)
@@ -169,7 +228,6 @@ class EditHyperParamsRequest(TaskRequest):
validators=Enum(*get_options(ReplaceHyperparams)),
default=ReplaceHyperparams.none,
)
force = BoolField(default=False)
class HyperParamKey(models.Base):
@@ -177,11 +235,10 @@ class HyperParamKey(models.Base):
name = StringField(nullable=True)
class DeleteHyperParamsRequest(TaskRequest):
class DeleteHyperParamsRequest(TaskUpdateRequest):
hyperparams: Sequence[HyperParamKey] = ListField(
[HyperParamKey], validators=Length(minimum_value=1)
)
force = BoolField(default=False)
class GetConfigurationsRequest(MultiTaskRequest):
@@ -189,7 +246,7 @@ class GetConfigurationsRequest(MultiTaskRequest):
class GetConfigurationNamesRequest(MultiTaskRequest):
pass
skip_empty = BoolField(default=True)
class Configuration(models.Base):
@@ -199,23 +256,102 @@ class Configuration(models.Base):
description = StringField()
class EditConfigurationRequest(TaskRequest):
class EditConfigurationRequest(TaskUpdateRequest):
configuration: Sequence[Configuration] = ListField(
[Configuration], validators=Length(minimum_value=1)
)
replace_configuration = BoolField(default=False)
force = BoolField(default=False)
class DeleteConfigurationRequest(TaskRequest):
class DeleteConfigurationRequest(TaskUpdateRequest):
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ArchiveRequest(MultiTaskRequest):
status_reason = StringField(default="")
status_message = StringField(default="")
include_pipeline_steps = BoolField(default=False)
class ArchiveResponse(models.Base):
archived = IntField()
class TaskBatchRequest(BatchRequest):
status_reason = StringField(default="")
status_message = StringField(default="")
class ArchiveManyRequest(TaskBatchRequest):
include_pipeline_steps = BoolField(default=False)
class UnarchiveManyRequest(TaskBatchRequest):
include_pipeline_steps = BoolField(default=False)
class StopManyRequest(TaskBatchRequest):
force = BoolField(default=False)
include_pipeline_steps = BoolField(default=False)
class DequeueManyRequest(TaskBatchRequest):
remove_from_all_queues = BoolField(default=False)
new_status = StringField()
class EnqueueManyRequest(TaskBatchRequest):
queue = StringField()
queue_name = StringField()
validate_tasks = BoolField(default=False)
verify_watched_queue = BoolField(default=False)
class DeleteManyRequest(TaskBatchRequest):
move_to_trash = BoolField(default=True)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
force = BoolField(default=False)
delete_external_artifacts = BoolField(default=True)
include_pipeline_steps = BoolField(default=False)
class ResetManyRequest(TaskBatchRequest):
clear_all = BoolField(default=False)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
force = BoolField(default=False)
delete_external_artifacts = BoolField(default=True)
class PublishManyRequest(TaskBatchRequest):
publish_model = BoolField(default=True)
force = BoolField(default=False)
class AddUpdateModelRequest(TaskRequest):
name = StringField(required=True)
model = StringField(required=True)
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
iteration = IntField()
class ModelItemKey(models.Base):
name = StringField(required=True)
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
class DeleteModelsRequest(TaskRequest):
models: Sequence[ModelItemKey] = ListField(
[ModelItemKey], validators=Length(minimum_value=1)
)
class GetAllReq(models.Base):
allow_public = BoolField(default=True)
search_hidden = BoolField(default=False)
class UpdateTagsRequest(BatchRequest):
add_tags = ListField([str])
remove_tags = ListField([str])

View File

@@ -4,6 +4,10 @@ from jsonmodels.models import Base
from apiserver.apimodels import DictField
class UserRequest(Base):
user = StringField(required=True)
class CreateRequest(Base):
id = StringField(required=True)
name = StringField(required=True)

View File

@@ -12,20 +12,21 @@ from jsonmodels.fields import (
)
from jsonmodels.models import Base
from apiserver.apimodels import make_default, ListField, EnumField, JsonSerializableMixin
DEFAULT_TIMEOUT = 10 * 60
from apiserver.apimodels import ListField, EnumField, JsonSerializableMixin
from apiserver.config_repo import config
class WorkerRequest(Base):
worker = StringField(required=True)
tags = ListField(str)
system_tags = ListField(str)
class RegisterRequest(WorkerRequest):
timeout = make_default(
IntField, DEFAULT_TIMEOUT
)() # registration timeout in seconds (default is 10min)
timeout = IntField(
default=int(config.get("services.workers.default_worker_timeout_sec", 10 * 60))
)
""" registration timeout in seconds (default is 10min) """
queues = ListField(six.string_types) # list of queues this worker listens to
@@ -76,6 +77,7 @@ class WorkerEntry(Base, JsonSerializableMixin):
last_activity_time = DateTimeField(required=True)
last_report_time = DateTimeField()
tags = ListField(str)
system_tags = ListField(str)
class CurrentTaskEntry(IdNameEntry):
@@ -96,12 +98,19 @@ class WorkerResponseEntry(WorkerEntry):
class GetAllRequest(Base):
last_seen = IntField(default=3600)
tags = ListField(str)
system_tags = ListField(str)
worker_pattern = StringField()
class GetAllResponse(Base):
workers = ListField(WorkerResponseEntry)
class GetCountRequest(GetAllRequest):
last_seen = IntField(default=0)
class StatsBase(Base):
worker_ids = ListField(str)

View File

@@ -2,7 +2,11 @@ from datetime import datetime
from apiserver import database
from apiserver.apierrors import errors
from apiserver.apimodels.auth import GetTokenResponse, CreateUserRequest, Credentials as CredModel
from apiserver.apimodels.auth import (
GetTokenResponse,
CreateUserRequest,
Credentials as CredModel,
)
from apiserver.apimodels.users import CreateRequest as Users_CreateRequest
from apiserver.bll.user import UserBLL
from apiserver.config_repo import config
@@ -57,9 +61,10 @@ class AuthBLL:
api_version=str(ServiceRepo.max_endpoint_version()),
server_version=str(get_version()),
server_build=str(get_build_number()),
feature_set="basic",
)
return GetTokenResponse(token=token.decode("ascii"))
return GetTokenResponse(token=token)
@staticmethod
def create_user(request: CreateUserRequest, call: APICall = None) -> str:
@@ -144,7 +149,7 @@ class AuthBLL:
@classmethod
def create_credentials(
cls, user_id: str, company_id: str, role: str = None
cls, user_id: str, company_id: str, role: str = None, label: str = None,
) -> CredModel:
with translate_errors_context():
@@ -153,9 +158,11 @@ class AuthBLL:
if not user:
raise errors.bad_request.InvalidUserId(**query)
cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key())
cred = CredModel(
access_key=get_client_id(), secret_key=get_secret_key(), label=label
)
user.credentials.append(
Credentials(key=cred.access_key, secret=cred.secret_key)
Credentials(key=cred.access_key, secret=cred.secret_key, label=label)
)
user.save()

View File

@@ -1,476 +0,0 @@
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter
from typing import Sequence, Tuple, Optional, Mapping
import attr
import dpath
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
EventSettings,
check_empty_data,
search_company_events,
EventType,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
class VariantScrollState(Base):
name: str = StringField(required=True)
recycle_url_marker: str = StringField()
last_invalid_iteration: int = IntField()
class MetricScrollState(Base):
task: str = StringField(required=True)
name: str = StringField(required=True)
last_min_iter: Optional[int] = IntField()
last_max_iter: Optional[int] = IntField()
timestamp: int = IntField(default=0)
variants: Sequence[VariantScrollState] = ListField([VariantScrollState])
def reset(self):
"""Reset the scrolling state for the metric"""
self.last_min_iter = self.last_max_iter = None
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
warning: str = StringField()
@attr.s(auto_attribs=True)
class DebugImagesResult(object):
metric_events: Sequence[tuple] = []
next_scroll_id: str = None
class DebugImagesIterator:
EVENT_TYPE = EventType.metrics_image
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugImageEventsScrollState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def get_task_events(
self,
company_id: str,
metrics: Sequence[Tuple[str, str]],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
) -> DebugImagesResult:
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
return DebugImagesResult()
def init_state(state_: DebugImageEventsScrollState):
unique_metrics = set(metrics)
state_.metrics = self._init_metric_states(company_id, list(unique_metrics))
def validate_state(state_: DebugImageEventsScrollState):
"""
Validate that the metrics stored in the state are the same
as requested in the current call.
Refresh the state if requested
"""
state_metrics = set((m.task, m.name) for m in state_.metrics)
if state_metrics != set(metrics):
raise errors.bad_request.InvalidScrollId(
"Task metrics stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reinit_outdated_metric_states(company_id, state_)
for metric_state in state_.metrics:
metric_state.reset()
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
) as state:
res = DebugImagesResult(next_scroll_id=state.id)
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
res.metric_events = list(
pool.map(
partial(
self._get_task_metric_events,
company_id=company_id,
iter_count=iter_count,
navigate_earlier=navigate_earlier,
),
state.metrics,
)
)
return res
def _reinit_outdated_metric_states(
self, company_id, state: DebugImageEventsScrollState
):
"""
Determines the metrics for which new debug image events were added
since their states were initialized and reinits these states
"""
task_ids = set(metric.task for metric in state.metrics)
tasks = Task.objects(id__in=list(task_ids), company=company_id).only(
"id", "metric_stats"
)
def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
"""For metrics that reported debug image events get tuples of task_id/metric_name and last update times"""
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
if not metric_stats:
return []
return [
(
(task.id, stats.metric),
stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
)
for stats in metric_stats.values()
if self.EVENT_TYPE.value in stats.event_stats_by_type
]
update_times = dict(
chain.from_iterable(
get_last_update_times_for_task_metrics(task) for task in tasks
)
)
outdated_metrics = [
metric
for metric in state.metrics
if (metric.task, metric.name) in update_times
and update_times[metric.task, metric.name] > metric.timestamp
]
state.metrics = [
*(metric for metric in state.metrics if metric not in outdated_metrics),
*(
self._init_metric_states(
company_id,
[(metric.task, metric.name) for metric in outdated_metrics],
)
),
]
def _init_metric_states(
self, company_id: str, metrics: Sequence[Tuple[str, str]]
) -> Sequence[MetricScrollState]:
"""
Returned initialized metric scroll stated for the requested task metrics
"""
tasks = defaultdict(list)
for (task, metric) in metrics:
tasks[task].append(metric)
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
return list(
chain.from_iterable(
pool.map(
partial(
self._init_metric_states_for_task, company_id=company_id
),
tasks.items(),
)
)
)
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, Sequence[str]], company_id: str
) -> Sequence[MetricScrollState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any debug images
"""
task, metrics = task_metrics
es_req: dict = {
"size": 0,
"query": {
"bool": {
"must": [
{"term": {"task": task}},
{"terms": {"metric": metrics}},
{"exists": {"field": "url"}},
]
}
},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"order": {"_key": "asc"},
},
"aggs": {
"last_event_timestamp": {"max": {"field": "timestamp"}},
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": {
"urls": {
"terms": {
"field": "url",
"order": {"max_iter": "desc"},
"size": 1, # we need only one url from the most recent iteration
},
"aggs": {
"max_iter": {"max": {"field": "iter"}},
"iters": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": 2, # need two last iterations so that we can take
# the second one as invalid
"_source": "iter",
}
},
},
}
},
},
},
}
},
}
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
if "aggregations" not in es_res:
return []
def init_variant_scroll_state(variant: dict):
"""
Return new variant scroll state for the passed variant bucket
If the image urls get recycled then fill the last_invalid_iteration field
"""
state = VariantScrollState(name=variant["key"])
top_iter_url = dpath.get(variant, "urls/buckets")[0]
iters = dpath.get(top_iter_url, "iters/hits/hits")
if len(iters) > 1:
state.last_invalid_iteration = dpath.get(iters[1], "_source/iter")
return state
return [
MetricScrollState(
task=task,
name=metric["key"],
variants=[
init_variant_scroll_state(variant)
for variant in dpath.get(metric, "variants/buckets")
],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
)
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
]
def _get_task_metric_events(
self,
metric: MetricScrollState,
company_id: str,
iter_count: int,
navigate_earlier: bool,
) -> Tuple:
"""
Return task metric events grouped by iterations
Update metric scroll state
"""
if metric.last_max_iter is None:
# the first fetch is always from the latest iteration to the earlier ones
navigate_earlier = True
must_conditions = [
{"term": {"task": metric.task}},
{"term": {"metric": metric.name}},
{"exists": {"field": "url"}},
]
must_not_conditions = []
range_condition = None
if navigate_earlier and metric.last_min_iter is not None:
range_condition = {"lt": metric.last_min_iter}
elif not navigate_earlier and metric.last_max_iter is not None:
range_condition = {"gt": metric.last_max_iter}
if range_condition:
must_conditions.append({"range": {"iter": range_condition}})
if navigate_earlier:
"""
When navigating to earlier iterations consider only
variants whose invalid iterations border is lower than
our starting iteration. For these variants make sure
that only events from the valid iterations are returned
"""
if not metric.last_min_iter:
variants = metric.variants
else:
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is None
or v.last_invalid_iteration < metric.last_min_iter
)
if not variants:
return metric.task, metric.name, []
must_conditions.append(
{"terms": {"variant": list(v.name for v in variants)}}
)
else:
"""
When navigating to later iterations all variants may be relevant.
For the variants whose invalid border is higher than our starting
iteration make sure that only events from valid iterations are returned
"""
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is not None
and v.last_invalid_iteration > metric.last_max_iter
)
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"lte": v.last_invalid_iteration}}},
]
}
}
for v in variants
if v.last_invalid_iteration is not None
]
if variants_conditions:
must_not_conditions.append({"bool": {"should": variants_conditions}})
es_req = {
"size": 0,
"query": {
"bool": {"must": must_conditions, "must_not": must_not_conditions}
},
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iter_count,
"order": {"_key": "desc" if navigate_earlier else "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": {
"events": {
"top_hits": {"sort": {"url": {"order": "desc"}}}
}
},
}
},
}
},
}
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
if "aggregations" not in es_res:
return metric.task, metric.name, []
def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
return [
ev["_source"]
for v in variant_buckets
for ev in dpath.get(v, "events/hits/hits")
]
iterations = [
{
"iter": it["key"],
"events": get_iteration_events(dpath.get(it, "variants/buckets")),
}
for it in dpath.get(es_res, "aggregations/iters/buckets")
]
if not navigate_earlier:
iterations.sort(key=itemgetter("iter"), reverse=True)
if iterations:
metric.last_max_iter = iterations[0]["iter"]
metric.last_min_iter = iterations[-1]["iter"]
# Commented for now since the last invalid iteration is calculated in the beginning
# if navigate_earlier and any(
# variant.last_invalid_iteration is None for variant in variants
# ):
# """
# Variants validation flags due to recycling can
# be set only on navigation to earlier frames
# """
# iterations = self._update_variants_invalid_iterations(variants, iterations)
return metric.task, metric.name, iterations
@staticmethod
def _update_variants_invalid_iterations(
variants: Sequence[VariantScrollState], iterations: Sequence[dict]
) -> Sequence[dict]:
"""
This code is currently not in used since the invalid iterations
are calculated during MetricState initialization
For variants that do not have recycle url marker set it from the
first event
For variants that do not have last_invalid_iteration set check if the
recycle marker was reached on a certain iteration and set it to the
corresponding iteration
For variants that have a newly set last_invalid_iteration remove
events from the invalid iterations
Return the updated iterations list
"""
variants_lookup = bucketize(variants, attrgetter("name"))
for it in iterations:
iteration = it["iter"]
events_to_remove = []
for event in it["events"]:
variant = variants_lookup[event["variant"]][0]
if (
variant.last_invalid_iteration
and variant.last_invalid_iteration >= iteration
):
events_to_remove.append(event)
continue
event_url = event.get("url")
if not variant.recycle_url_marker:
variant.recycle_url_marker = event_url
elif variant.recycle_url_marker == event_url:
variant.last_invalid_iteration = iteration
events_to_remove.append(event)
if events_to_remove:
it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
return [it for it in iterations if it["events"]]

View File

@@ -1,375 +0,0 @@
import operator
from typing import Sequence, Tuple, Optional
import attr
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField, BoolField
from jsonmodels.models import Base
from redis import StrictRedis
from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
EventSettings,
EventType,
check_empty_data,
search_company_events,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
class VariantState(Base):
name: str = StringField(required=True)
min_iteration: int = IntField()
max_iteration: int = IntField()
class DebugSampleHistoryState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
iteration: int = IntField()
variant: str = StringField()
task: str = StringField()
metric: str = StringField()
reached_first: bool = BoolField()
reached_last: bool = BoolField()
variant_states: Sequence[VariantState] = ListField([VariantState])
warning: str = StringField()
@attr.s(auto_attribs=True)
class DebugSampleHistoryResult(object):
scroll_id: str = None
event: dict = None
min_iteration: int = None
max_iteration: int = None
class DebugSampleHistory:
EVENT_TYPE = EventType.metrics_image
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugSampleHistoryState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def get_next_debug_image(
self, company_id: str, task: str, state_id: str, navigate_earlier: bool
) -> DebugSampleHistoryResult:
"""
Get the debug image for next/prev variant on the current iteration
If does not exist then try getting image for the first/last variant from next/prev iteration
"""
res = DebugSampleHistoryResult(scroll_id=state_id)
state = self.cache_manager.get_state(state_id)
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
return res
image = self._get_next_for_current_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
) or self._get_next_for_another_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
)
if not image:
return res
self._fill_res_and_update_state(image=image, res=res, state=state)
self.cache_manager.set_state(state=state)
return res
def _fill_res_and_update_state(
self, image: dict, res: DebugSampleHistoryResult, state: DebugSampleHistoryState
):
state.variant = image["variant"]
state.iteration = image["iter"]
res.event = image
var_state = first(s for s in state.variant_states if s.name == state.variant)
if var_state:
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
def _get_next_for_current_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
) -> Optional[dict]:
"""
Get the image for next (if navigated earlier is False) or previous variant sorted by name for the same iteration
Only variants for which the iteration falls into their valid range are considered
Return None if no such variant or image is found
"""
cmp = operator.lt if navigate_earlier else operator.gt
variants = [
var_state
for var_state in state.variant_states
if cmp(var_state.name, state.variant)
and var_state.min_iteration <= state.iteration
]
if not variants:
return
must_conditions = [
{"term": {"task": state.task}},
{"term": {"metric": state.metric}},
{"terms": {"variant": [v.name for v in variants]}},
{"term": {"iter": state.iteration}},
{"exists": {"field": "url"}},
]
es_req = {
"size": 1,
"sort": {"variant": "desc" if navigate_earlier else "asc"},
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_next_for_current_iteration"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def _get_next_for_another_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
) -> Optional[dict]:
"""
Get the image for the first variant for the next iteration (if navigate_earlier is set to False)
or from the last variant for the previous iteration (otherwise)
The variants for which the image falls in invalid range are discarded
If no suitable image is found then None is returned
"""
must_conditions = [
{"term": {"task": state.task}},
{"term": {"metric": state.metric}},
{"exists": {"field": "url"}},
]
if navigate_earlier:
range_operator = "lt"
order = "desc"
variants = [
var_state
for var_state in state.variant_states
if var_state.min_iteration < state.iteration
]
else:
range_operator = "gt"
order = "asc"
variants = state.variant_states
if not variants:
return
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"gte": v.min_iteration}}},
]
}
}
for v in variants
]
must_conditions.append({"bool": {"should": variants_conditions}})
must_conditions.append({"range": {"iter": {range_operator: state.iteration}}},)
es_req = {
"size": 1,
"sort": [{"iter": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_next_for_another_iteration"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def get_debug_image_for_variant(
self,
company_id: str,
task: str,
metric: str,
variant: str,
iteration: Optional[int] = None,
refresh: bool = False,
state_id: str = None,
) -> DebugSampleHistoryResult:
"""
Get the debug image for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
res = DebugSampleHistoryResult()
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
return res
def init_state(state_: DebugSampleHistoryState):
state_.task = task
state_.metric = metric
self._reset_variant_states(company_id=company_id, state=state_)
def validate_state(state_: DebugSampleHistoryState):
if state_.task != task or state_.metric != metric:
raise errors.bad_request.InvalidScrollId(
"Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reset_variant_states(company_id=company_id, state=state_)
state: DebugSampleHistoryState
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res.scroll_id = state.id
var_state = first(s for s in state.variant_states if s.name == variant)
if not var_state:
return res
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
must_conditions = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"exists": {"field": "url"}},
]
if iteration is not None:
must_conditions.append(
{
"range": {
"iter": {"lte": iteration, "gte": var_state.min_iteration}
}
}
)
else:
must_conditions.append(
{"range": {"iter": {"gte": var_state.min_iteration}}}
)
es_req = {
"size": 1,
"sort": {"iter": "desc"},
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_debug_image_for_variant"
):
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return res
self._fill_res_and_update_state(
image=hits[0]["_source"], res=res, state=state
)
return res
def _reset_variant_states(self, company_id: str, state: DebugSampleHistoryState):
variant_iterations = self._get_variant_iterations(
company_id=company_id, task=state.task, metric=state.metric
)
state.variant_states = [
VariantState(name=var_name, min_iteration=min_iter, max_iteration=max_iter)
for var_name, min_iter, max_iter in variant_iterations
]
def _get_variant_iterations(
self,
company_id: str,
task: str,
metric: str,
variants: Optional[Sequence[str]] = None,
) -> Sequence[Tuple[str, int, int]]:
"""
Return valid min and max iterations that the task reported images
The min iteration is the lowest iteration that contains non-recycled image url
"""
must = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"exists": {"field": "url"}},
]
if variants:
must.append({"terms": {"variant": variants}})
es_req: dict = {
"size": 0,
"query": {"bool": {"must": must}},
"aggs": {
"variants": {
# all variants that sent debug images
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": {
"last_iter": {"max": {"field": "iter"}},
"urls": {
# group by urls and choose the minimal iteration
# from all the maximal iterations per url
"terms": {
"field": "url",
"order": {"max_iter": "asc"},
"size": 1,
},
"aggs": {
# find max iteration for each url
"max_iter": {"max": {"field": "iter"}}
},
},
},
}
},
}
with translate_errors_context(), TimingContext(
"es", "get_debug_image_iterations"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
)
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
variant = variant_bucket["key"]
urls = nested_get(variant_bucket, ("urls", "buckets"))
min_iter = int(urls[0]["max_iter"]["value"])
max_iter = int(variant_bucket["last_iter"]["value"])
return variant, min_iter, max_iter
return [
get_variant_data(variant_bucket)
for variant_bucket in nested_get(
es_res, ("aggregations", "variants", "buckets")
)
]

File diff suppressed because it is too large Load Diff

View File

@@ -1,10 +1,15 @@
import base64
import zlib
from enum import Enum
from typing import Union, Sequence
from typing import Union, Sequence, Mapping, Tuple
from boltons.typeutils import classproperty
from elasticsearch import Elasticsearch
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task
from apiserver.utilities.dicts import nested_get
class EventType(Enum):
@@ -16,7 +21,14 @@ class EventType(Enum):
all = "*"
SINGLE_SCALAR_ITERATION = -(2 ** 31)
MetricVariants = Mapping[str, Sequence[str]]
TaskCompanies = Mapping[str, Sequence[Task]]
class EventSettings:
_max_es_allowed_aggregation_buckets = 10000
@classproperty
def max_workers(self):
return config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
@@ -28,22 +40,31 @@ class EventSettings:
)
@classproperty
def max_metrics_count(self):
return config.get("services.events.events_retrieval.max_metrics_count", 100)
@classproperty
def max_variants_count(self):
return config.get("services.events.events_retrieval.max_variants_count", 100)
def max_es_buckets(self):
percentage = (
min(
100,
config.get(
"services.events.events_retrieval.dynamic_metrics_count_threshold",
80,
),
)
/ 100
)
return int(self._max_es_allowed_aggregation_buckets * percentage)
def get_index_name(company_id: str, event_type: str):
def get_index_name(company_id: Union[str, Sequence[str]], event_type: str):
event_type = event_type.lower().replace(" ", "_")
return f"events-{event_type}-{company_id}"
if isinstance(company_id, str):
company_id = [company_id]
return ",".join(f"events-{event_type}-{(c_id or '').lower()}" for c_id in company_id)
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
es_index = get_index_name(company_id, event_type.value)
if not es.indices.exists(es_index):
if not es.indices.exists(index=es_index):
return True
return False
@@ -63,4 +84,83 @@ def delete_company_events(
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
) -> dict:
es_index = get_index_name(company_id, event_type.value)
return es.delete_by_query(index=es_index, body=body, **kwargs)
return es.delete_by_query(index=es_index, body=body, conflicts="proceed", **kwargs)
def count_company_events(
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
) -> dict:
es_index = get_index_name(company_id, event_type.value)
return es.count(index=es_index, body=body, **kwargs)
def get_max_metric_and_variant_counts(
es: Elasticsearch,
company_id: Union[str, Sequence[str]],
event_type: EventType,
query: dict,
**kwargs,
) -> Tuple[int, int]:
dynamic = config.get(
"services.events.events_retrieval.dynamic_metrics_count", False
)
max_metrics_count = config.get(
"services.events.events_retrieval.max_metrics_count", 100
)
max_variants_count = config.get(
"services.events.events_retrieval.max_variants_count", 100
)
if not dynamic:
return max_metrics_count, max_variants_count
es_req: dict = {
"size": 0,
"query": query,
"aggs": {"metrics_count": {"cardinality": {"field": "metric"}}},
}
with translate_errors_context():
es_res = search_company_events(
es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
)
metrics_count = nested_get(
es_res, ("aggregations", "metrics_count", "value"), max_metrics_count
)
if not metrics_count:
return max_metrics_count, max_variants_count
return metrics_count, int(EventSettings.max_es_buckets / metrics_count)
def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
conditions = [
{
"bool": {
"must": [
{"term": {"metric": metric}},
{"terms": {"variant": variants}},
]
}
}
if variants
else {"term": {"metric": metric}}
for metric, variants in metric_variants.items()
]
return {"bool": {"should": conditions}}
class PlotFields:
valid_plot = "valid_plot"
plot_len = "plot_len"
plot_str = "plot_str"
plot_data = "plot_data"
source_urls = "source_urls"
def uncompress_plot(event: dict):
plot_data = event.pop(PlotFields.plot_data, None)
if plot_data and event.get(PlotFields.plot_str) is None:
event[PlotFields.plot_str] = zlib.decompress(
base64.b64decode(plot_data)
).decode()

View File

@@ -4,24 +4,27 @@ from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple
from typing import Sequence, Tuple, Mapping
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.bll.event.event_common import (
EventType,
EventSettings,
search_company_events,
check_empty_data,
MetricVariants,
get_metric_variants_condition,
get_max_metric_and_variant_counts,
SINGLE_SCALAR_ITERATION,
TaskCompanies,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
log = config.logger(__file__)
@@ -34,7 +37,12 @@ class EventMetrics:
self.es = es
def get_scalar_metrics_average_per_iter(
self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
self,
company_id: str,
task_id: str,
samples: int,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
) -> dict:
"""
Get scalar metric histogram per metric and variant
@@ -46,7 +54,12 @@ class EventMetrics:
return {}
return self._get_scalar_average_per_iter_core(
task_id, company_id, event_type, samples, ScalarKey.resolve(key)
task_id=task_id,
company_id=company_id,
event_type=event_type,
samples=samples,
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
)
def _get_scalar_average_per_iter_core(
@@ -57,6 +70,7 @@ class EventMetrics:
samples: int,
key: ScalarKey,
run_parallel: bool = True,
metric_variants: MetricVariants = None,
) -> dict:
intervals = self._get_task_metric_intervals(
company_id=company_id,
@@ -64,6 +78,7 @@ class EventMetrics:
task_id=task_id,
samples=samples,
field=key.field,
metric_variants=metric_variants,
)
if not intervals:
return {}
@@ -94,57 +109,51 @@ class EventMetrics:
def compare_scalar_metrics_average_per_iter(
self,
company_id,
task_ids: Sequence[str],
companies: TaskCompanies,
samples,
key: ScalarKeyEnum,
allow_public=True,
metric_variants: MetricVariants = None,
):
"""
Compare scalar metrics for different tasks per metric and variant
The amount of points in each histogram should not exceed the requested samples
"""
task_name_by_id = {}
with translate_errors_context():
task_objs = Task.get_many(
company=company_id,
query=Q(id__in=task_ids),
allow_public=allow_public,
override_projection=("id", "name", "company", "company_origin"),
return_dicts=False,
)
if len(task_objs) < len(task_ids):
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
task_name_by_id = {t.id: t.name for t in task_objs}
companies = {t.get_index_company() for t in task_objs}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
event_type = EventType.metrics_scalar
company_id = next(iter(companies))
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
companies = {
company_id: tasks
for company_id, tasks in companies.items()
if not check_empty_data(
self.es, company_id=company_id, event_type=event_type
)
}
if not companies:
return {}
get_scalar_average_per_iter = partial(
self._get_scalar_average_per_iter_core,
company_id=company_id,
event_type=event_type,
samples=samples,
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
run_parallel=False,
)
task_ids, company_ids = zip(
*(
(t.id, t.company)
for t in itertools.chain.from_iterable(companies.values())
)
)
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
task_metrics = zip(
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
task_ids, pool.map(get_scalar_average_per_iter, task_ids, company_ids)
)
task_names = {
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
}
res = defaultdict(lambda: defaultdict(dict))
for task_id, task_data in task_metrics:
task_name = task_name_by_id[task_id]
task_name = task_names[task_id]
for metric_key, metric_data in task_data.items():
for variant_key, variant_data in metric_data.items():
variant_data["name"] = task_name
@@ -152,6 +161,75 @@ class EventMetrics:
return res
def get_task_single_value_metrics(
self,
companies: TaskCompanies,
metric_variants: MetricVariants = None,
) -> Mapping[str, dict]:
"""
For the requested tasks return all the events delivered for the single iteration (-2**31)
"""
companies = {
company_id: [t.id for t in tasks]
for company_id, tasks in companies.items()
if not check_empty_data(
self.es, company_id=company_id, event_type=EventType.metrics_scalar
)
}
if not companies:
return {}
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
task_events = list(
itertools.chain.from_iterable(
pool.map(
partial(
self._get_task_single_value_metrics,
metric_variants=metric_variants,
),
companies.items(),
)
),
)
def _get_value(event: dict):
return {
field: event.get(field)
for field in ("metric", "variant", "value", "timestamp")
}
return {
task: [_get_value(e) for e in events]
for task, events in bucketize(task_events, itemgetter("task")).items()
}
def _get_task_single_value_metrics(
self, tasks: Tuple[str, Sequence[str]], metric_variants: MetricVariants = None
) -> Sequence[dict]:
company_id, task_ids = tasks
must = [
{"terms": {"task": task_ids}},
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
es_req = {
"size": 10000,
"query": {"bool": {"must": must}},
}
with translate_errors_context():
es_res = search_company_events(
body=es_req,
es=self.es,
company_id=company_id,
event_type=EventType.metrics_scalar,
)
if not es_res["hits"]["total"]["value"]:
return []
return [hit["_source"] for hit in es_res["hits"]["hits"]]
MetricInterval = Tuple[str, str, int, int]
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
@@ -197,6 +275,7 @@ class EventMetrics:
task_id: str,
samples: int,
field: str = "iter",
metric_variants: MetricVariants = None,
) -> Sequence[MetricInterval]:
"""
Calculate interval per task metric variant so that the resulting
@@ -204,21 +283,31 @@ class EventMetrics:
Return the list og metric variant intervals as the following tuple:
(metric, variant, interval, samples)
"""
must = self._task_conditions(task_id)
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query,
**search_args,
)
max_variants = int(max_variants // 2)
es_req = {
"size": 0,
"query": {"term": {"task": task_id}},
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
@@ -232,10 +321,7 @@ class EventMetrics:
},
}
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
es_res = search_company_events(body=es_req, **search_args)
aggs_result = es_res.get("aggregations")
if not aggs_result:
@@ -256,12 +342,12 @@ class EventMetrics:
total amount of intervals does not exceeds the samples
Return the interval and resulting amount of intervals
"""
count = safe_get(data, "count/value", default=0)
count = nested_get(data, ("count", "value"), default=0)
if count < samples:
return metric, variant, 1, count
min_index = safe_get(data, "min_index/value", default=0)
max_index = safe_get(data, "max_index/value", default=min_index)
min_index = nested_get(data, ("min_index", "value"), default=0)
max_index = nested_get(data, ("max_index", "value"), default=min_index)
index_range = max_index - min_index + 1
interval = max(1, math.ceil(float(index_range) / samples))
max_samples = math.ceil(float(index_range) / interval)
@@ -287,33 +373,41 @@ class EventMetrics:
"""
interval, metrics = metrics_interval
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
aggs = {
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": aggregation,
}
},
}
}
aggs_result = self._query_aggregation_for_task_metrics(
company_id=company_id,
event_type=event_type,
aggs=aggs,
task_id=task_id,
metrics=metrics,
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query,
**search_args,
)
max_variants = int(max_variants // 2)
es_req = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": aggregation,
}
},
}
},
}
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
aggs_result = es_res.get("aggregations")
if not aggs_result:
return {}
@@ -340,19 +434,20 @@ class EventMetrics:
for key, value in aggregation.items()
}
def _query_aggregation_for_task_metrics(
self,
company_id: str,
event_type: EventType,
aggs: dict,
@staticmethod
def _task_conditions(task_id: str) -> list:
return [
{"term": {"task": task_id}},
{"range": {"iter": {"gt": SINGLE_SCALAR_ITERATION}}},
]
@classmethod
def _get_task_metrics_query(
cls,
task_id: str,
metrics: Sequence[Tuple[str, str]],
) -> dict:
"""
Return the result of elastic search query for the given aggregation filtered
by the given task_ids and metrics
"""
must = [{"term": {"task": task_id}}]
):
must = cls._task_conditions(task_id)
if metrics:
should = [
{
@@ -367,25 +462,98 @@ class EventMetrics:
]
must.append({"bool": {"should": should}})
es_req = {
"size": 0,
"query": {"bool": {"must": must}},
"aggs": aggs,
}
return {"bool": {"must": must}}
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
def get_multi_task_metrics(self, companies: TaskCompanies, event_type: EventType) -> Mapping[str, list]:
"""
For the requested tasks return reported metrics and variants
"""
tasks_ids = {
company: [t.id for t in tasks]
for company, tasks in companies.items()
}
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
companies_res: Sequence = list(
pool.map(
partial(
self._get_multi_task_metrics,
event_type=event_type,
),
tasks_ids.items(),
)
)
return es_res.get("aggregations")
if len(companies_res) == 1:
return companies_res[0]
def get_tasks_metrics(
res = defaultdict(set)
for c_res in companies_res:
for m, vars_ in c_res.items():
res[m].update(vars_)
return {
k: list(v)
for k, v in res.items()
}
def _get_multi_task_metrics(
self, company_tasks: Tuple[str, Sequence[str]], event_type: EventType
) -> Mapping[str, list]:
company_id, task_ids = company_tasks
if check_empty_data(self.es, company_id, event_type):
return {}
search_args = dict(
es=self.es,
company_id=company_id,
event_type=event_type,
)
query = QueryBuilder.terms("task", task_ids)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query,
**search_args,
)
es_req = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": max_variants,
"order": {"_key": "asc"},
},
}
}
}
},
}
es_res = search_company_events(
body=es_req,
**search_args,
)
aggs_result = es_res.get("aggregations")
if not aggs_result:
return {}
return {
mb["key"]: [vb["key"] for vb in mb["variants"]["buckets"]]
for mb in aggs_result["metrics"]["buckets"]
}
def get_task_metrics(
self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence:
"""
For the requested tasks return all the metrics that
reported events of the requested types
For the requested tasks return reported metrics per task
"""
if check_empty_data(self.es, company_id, event_type):
return {}
@@ -406,24 +574,23 @@ class EventMetrics:
) -> Sequence:
es_req = {
"size": 0,
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": {"bool": {"must": self._task_conditions(task_id)}},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": EventSettings.max_es_buckets,
"order": {"_key": "asc"},
}
}
},
}
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
return [
metric["key"]
for metric in safe_get(es_res, "aggregations/metrics/buckets", default=[])
for metric in nested_get(es_res, ("aggregations", "metrics", "buckets"), default=[])
]

View File

@@ -0,0 +1,197 @@
from typing import Optional, Tuple, Sequence, Any
import attr
import jsonmodels.models
import jwt
from elasticsearch import Elasticsearch
from jwt.algorithms import get_default_algorithms
from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
MetricVariants,
get_metric_variants_condition,
count_company_events,
)
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
next_scroll_id: str = None
events: list = attr.Factory(list)
class EventsIterator:
def __init__(self, es: Elasticsearch):
self.es = es
def get_task_events(
self,
event_type: EventType,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
from_key_value: Optional[Any] = None,
metric_variants: MetricVariants = None,
key: ScalarKeyEnum = ScalarKeyEnum.timestamp,
**kwargs,
) -> TaskEventsResult:
if check_empty_data(self.es, company_id, event_type):
return TaskEventsResult()
from_key_value = kwargs.pop("from_timestamp", from_key_value)
res = TaskEventsResult()
res.events, res.total_events = self._get_events(
event_type=event_type,
company_id=company_id,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
from_key_value=from_key_value,
metric_variants=metric_variants,
key=ScalarKey.resolve(key),
)
return res
def count_task_events(
self,
event_type: EventType,
company_id: str,
task_ids: Sequence[str],
metric_variants: MetricVariants = None,
) -> int:
if check_empty_data(self.es, company_id, event_type):
return 0
query, _ = self._get_initial_query_and_must(task_ids, metric_variants)
es_req = {
"query": query,
}
with translate_errors_context():
es_result = count_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
return es_result["count"]
def _get_events(
self,
event_type: EventType,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool,
key: ScalarKey,
from_key_value: Optional[Any],
metric_variants: MetricVariants = None,
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous key-field value (timestamp or iter) either in the
direction of earlier events (navigate_earlier=True) or in the direction of later events.
If from_key_field is not set then start either from latest or earliest.
For the last key-field value all the events are brought (even if the resulting size exceeds batch_size)
so that events with this value will not be lost between the calls.
"""
query, must = self._get_initial_query_and_must([task_id], metric_variants)
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": query,
"sort": {key.field: "desc" if navigate_earlier else "asc"},
}
if from_key_value:
es_req["search_after"] = [from_key_value]
with translate_errors_context():
es_result = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
es_req = {
"size": 10000,
"query": {
"bool": {
"must": must + [{"term": {key.field: events[-1][key.field]}}]
}
},
}
es_result = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:
# if only one element is returned for the last timestamp
# then it is already present in the events
return events, hits_total
already_present_ids = set(hit["_id"] for hit in hits)
last_second_events = [
hit["_source"]
for hit in last_second_hits
if hit["_id"] not in already_present_ids
]
# return the list merged from original query results +
# leftovers from the last timestamp
return (
[*events, *last_second_events],
hits_total,
)
@staticmethod
def _get_initial_query_and_must(
task_ids: Sequence[str], metric_variants: MetricVariants = None
) -> Tuple[dict, list]:
if not metric_variants:
query = {"terms": {"task": task_ids}}
must = [query]
else:
must = [
{"terms": {"task": task_ids}},
get_metric_variants_condition(metric_variants),
]
query = {"bool": {"must": must}}
return query, must
class Scroll(jsonmodels.models.Base):
def get_scroll_id(self) -> str:
return jwt.encode(
self.to_struct(),
key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890"
),
)
@classmethod
def from_scroll_id(cls, scroll_id: str):
try:
return cls(
**jwt.decode(
scroll_id,
key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890"
),
algorithms=get_default_algorithms(),
)
)
except jwt.PyJWTError:
raise ValueError("Invalid Scroll ID")

View File

@@ -0,0 +1,455 @@
import operator
from operator import attrgetter
from typing import Sequence, Tuple, Optional, Mapping
import attr
from boltons.iterutils import first, bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, IntField, BoolField, ListField
from jsonmodels.models import Base
from redis.client import StrictRedis
from apiserver.utilities.dicts import nested_get
from .event_common import (
EventType,
EventSettings,
check_empty_data,
search_company_events,
get_max_metric_and_variant_counts,
)
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.apierrors import errors
class VariantState(Base):
name: str = StringField(required=True)
metric: str = StringField(default=None)
min_iteration: int = IntField()
max_iteration: int = IntField()
class DebugImageSampleState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
iteration: int = IntField()
variant: str = StringField()
task: str = StringField()
metric: str = StringField()
variant_states: Sequence[VariantState] = ListField([VariantState])
warning: str = StringField()
navigate_current_metric = BoolField(default=True)
@attr.s(auto_attribs=True)
class VariantSampleResult(object):
scroll_id: str = None
event: dict = None
min_iteration: int = None
max_iteration: int = None
class HistoryDebugImageIterator:
event_type = EventType.metrics_image
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugImageSampleState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def get_next_sample(
self,
company_id: str,
task: str,
state_id: str,
navigate_earlier: bool,
next_iteration: bool,
) -> VariantSampleResult:
"""
Get the sample for next/prev variant on the current iteration
If does not exist then try getting sample for the first/last variant from next/prev iteration
"""
res = VariantSampleResult(scroll_id=state_id)
state = self.cache_manager.get_state(state_id)
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
if next_iteration:
event = self._get_next_for_another_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
)
else:
# noinspection PyArgumentList
event = first(
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
for f in (
self._get_next_for_current_iteration,
self._get_next_for_another_iteration,
)
)
if not event:
return res
self._fill_res_and_update_state(event=event, res=res, state=state)
self.cache_manager.set_state(state=state)
return res
@staticmethod
def _fill_res_and_update_state(
event: dict, res: VariantSampleResult, state: DebugImageSampleState
):
state.variant = event["variant"]
state.metric = event["metric"]
state.iteration = event["iter"]
res.event = event
var_state = first(
vs
for vs in state.variant_states
if vs.name == state.variant and vs.metric == state.metric
)
if var_state:
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
@staticmethod
def _get_metric_conditions(variants: Sequence[VariantState]) -> dict:
metrics = bucketize(variants, key=attrgetter("metric"))
def _get_variants_conditions(metric_variants: Sequence[VariantState]) -> dict:
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"gte": v.min_iteration}}},
]
}
}
for v in metric_variants
]
return {"bool": {"should": variants_conditions}}
metrics_conditions = [
{
"bool": {
"must": [
{"term": {"metric": metric}},
_get_variants_conditions(metric_variants),
]
}
}
for metric, metric_variants in metrics.items()
]
return {"bool": {"should": metrics_conditions}}
def _get_next_for_current_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
) -> Optional[dict]:
"""
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
Only variants for which the iteration falls into their valid range are considered
Return None if no such variant or sample is found
"""
if state.navigate_current_metric:
variants = [
var_state
for var_state in state.variant_states
if var_state.metric == state.metric
]
else:
variants = state.variant_states
cmp = operator.lt if navigate_earlier else operator.gt
variants = [
var_state
for var_state in variants
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
and var_state.min_iteration <= state.iteration
]
if not variants:
return
must_conditions = [
{"term": {"task": state.task}},
{"term": {"iter": state.iteration}},
self._get_metric_conditions(variants),
{"exists": {"field": "url"}},
]
order = "desc" if navigate_earlier else "asc"
es_req = {
"size": 1,
"sort": [{"metric": order}, {"variant": order}, {"url": "desc"}],
"query": {"bool": {"must": must_conditions}},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def _get_next_for_another_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
) -> Optional[dict]:
"""
Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
or from the last variant for the previous iteration (otherwise)
The variants for which the sample falls in invalid range are discarded
If no suitable sample is found then None is returned
"""
if state.navigate_current_metric:
variants = [
var_state
for var_state in state.variant_states
if var_state.metric == state.metric
]
else:
variants = state.variant_states
if navigate_earlier:
range_operator = "lt"
order = "desc"
variants = [
var_state
for var_state in variants
if var_state.min_iteration < state.iteration
]
else:
range_operator = "gt"
order = "asc"
variants = variants
if not variants:
return
must_conditions = [
{"term": {"task": state.task}},
self._get_metric_conditions(variants),
{"range": {"iter": {range_operator: state.iteration}}},
{"exists": {"field": "url"}},
]
es_req = {
"size": 1,
"sort": [{"iter": order}, {"metric": order}, {"variant": order}, {"url": "desc"}],
"query": {"bool": {"must": must_conditions}},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def get_sample_for_variant(
self,
company_id: str,
task: str,
metric: str,
variant: str,
iteration: Optional[int] = None,
refresh: bool = False,
state_id: str = None,
navigate_current_metric: bool = True,
) -> VariantSampleResult:
"""
Get the sample for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
res = VariantSampleResult()
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
def init_state(state_: DebugImageSampleState):
state_.task = task
state_.metric = metric
state_.navigate_current_metric = navigate_current_metric
self._reset_variant_states(company_id=company_id, state=state_)
def validate_state(state_: DebugImageSampleState):
if (
state_.task != task
or state_.navigate_current_metric != navigate_current_metric
or (state_.navigate_current_metric and state_.metric != metric)
):
raise errors.bad_request.InvalidScrollId(
"Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id,
)
# fix old variant states:
for vs in state_.variant_states:
if vs.metric is None:
vs.metric = metric
if refresh:
self._reset_variant_states(company_id=company_id, state=state_)
state: DebugImageSampleState
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res.scroll_id = state.id
var_state = first(
vs
for vs in state.variant_states
if vs.name == variant and vs.metric == metric
)
if not var_state:
return res
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
must_conditions = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"exists": {"field": "url"}},
]
if iteration is not None:
must_conditions.append(
{
"range": {
"iter": {"lte": iteration, "gte": var_state.min_iteration}
}
}
)
else:
must_conditions.append(
{"range": {"iter": {"gte": var_state.min_iteration}}}
)
es_req = {
"size": 1,
"sort": [{"iter": "desc"}, {"url": "desc"}],
"query": {"bool": {"must": must_conditions}},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return res
self._fill_res_and_update_state(
event=hits[0]["_source"], res=res, state=state
)
return res
def _reset_variant_states(self, company_id: str, state: DebugImageSampleState):
metrics = self._get_metric_variant_iterations(
company_id=company_id,
task=state.task,
metric=state.metric if state.navigate_current_metric else None,
)
state.variant_states = [
VariantState(
metric=metric,
name=var_name,
min_iteration=min_iter,
max_iteration=max_iter,
)
for metric, variants in metrics.items()
for var_name, min_iter, max_iter in variants
]
def _get_metric_variant_iterations(
self, company_id: str, task: str, metric: str,
) -> Mapping[str, Sequence[Tuple[str, int, int]]]:
"""
Return valid min and max iterations that the task reported events of the required type
"""
must = [
{"term": {"task": task}},
{"exists": {"field": "url"}},
]
if metric is not None:
must.append({"term": {"metric": metric}})
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=self.event_type,
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args
)
max_variants = int(max_variants // 2)
es_req: dict = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
"last_iter": {"max": {"field": "iter"}},
"urls": {
# group by urls and choose the minimal iteration
# from all the maximal iterations per url
"terms": {
"field": "url",
"order": {"max_iter": "asc"},
"size": 1,
},
"aggs": {
# find max iteration for each url
"max_iter": {"max": {"field": "iter"}}
},
},
},
}
},
}
},
}
es_res = search_company_events(body=es_req, **search_args)
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
variant = variant_bucket["key"]
urls = nested_get(variant_bucket, ("urls", "buckets"))
min_iter = int(urls[0]["max_iter"]["value"])
max_iter = int(variant_bucket["last_iter"]["value"])
return variant, min_iter, max_iter
return {
metric_bucket["key"]: [
get_variant_data(variant_bucket)
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
]
for metric_bucket in nested_get(
es_res, ("aggregations", "metrics", "buckets")
)
}

View File

@@ -0,0 +1,316 @@
from typing import Sequence, Tuple, Optional, Mapping
import attr
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, IntField, ListField, BoolField
from jsonmodels.models import Base
from redis.client import StrictRedis
from .event_common import (
EventType,
uncompress_plot,
EventSettings,
check_empty_data,
search_company_events,
)
from apiserver.apimodels import JsonSerializableMixin
from apiserver.utilities.dicts import nested_get
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.apierrors import errors
class MetricState(Base):
name: str = StringField(default=None)
min_iteration: int = IntField()
max_iteration: int = IntField()
class PlotsSampleState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
iteration: int = IntField()
task: str = StringField()
metric: str = StringField()
metric_states: Sequence[MetricState] = ListField([MetricState])
warning: str = StringField()
navigate_current_metric = BoolField(default=True)
@attr.s(auto_attribs=True)
class MetricSamplesResult(object):
scroll_id: str = None
events: list = []
min_iteration: int = None
max_iteration: int = None
class HistoryPlotsIterator:
event_type = EventType.metrics_plot
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=PlotsSampleState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def get_next_sample(
self,
company_id: str,
task: str,
state_id: str,
navigate_earlier: bool,
next_iteration: bool,
) -> MetricSamplesResult:
"""
Get the samples for next/prev metric on the current iteration
If does not exist then try getting sample for the first/last metric from next/prev iteration
"""
res = MetricSamplesResult(scroll_id=state_id)
state = self.cache_manager.get_state(state_id)
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
if navigate_earlier:
range_operator = "lt"
order = "desc"
else:
range_operator = "gt"
order = "asc"
must_conditions = [
{"term": {"task": state.task}},
]
if state.navigate_current_metric:
must_conditions.append({"term": {"metric": state.metric}})
next_iteration_condition = {
"range": {"iter": {range_operator: state.iteration}}
}
if next_iteration or state.navigate_current_metric:
must_conditions.append(next_iteration_condition)
else:
next_metric_condition = {
"bool": {
"must": [
{"term": {"iter": state.iteration}},
{"range": {"metric": {range_operator: state.metric}}},
]
}
}
must_conditions.append(
{"bool": {"should": [next_metric_condition, next_iteration_condition]}}
)
events = self._get_metric_events_for_condition(
company_id=company_id,
task=state.task,
order=order,
must_conditions=must_conditions,
)
if not events:
return res
self._fill_res_and_update_state(events=events, res=res, state=state)
self.cache_manager.set_state(state=state)
return res
def get_samples_for_metric(
self,
company_id: str,
task: str,
metric: str,
iteration: Optional[int] = None,
refresh: bool = False,
state_id: str = None,
navigate_current_metric: bool = True,
) -> MetricSamplesResult:
"""
Get the sample for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
res = MetricSamplesResult()
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
def init_state(state_: PlotsSampleState):
state_.task = task
state_.metric = metric
state_.navigate_current_metric = navigate_current_metric
self._reset_metric_states(company_id=company_id, state=state_)
def validate_state(state_: PlotsSampleState):
if (
state_.task != task
or state_.navigate_current_metric != navigate_current_metric
or (state_.navigate_current_metric and state_.metric != metric)
):
raise errors.bad_request.InvalidScrollId(
"Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reset_metric_states(company_id=company_id, state=state_)
state: PlotsSampleState
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res.scroll_id = state.id
metric_state = first(ms for ms in state.metric_states if ms.name == metric)
if not metric_state:
return res
res.min_iteration = metric_state.min_iteration
res.max_iteration = metric_state.max_iteration
must_conditions = [
{"term": {"task": task}},
{"term": {"metric": metric}},
]
if iteration is not None:
must_conditions.append({"range": {"iter": {"lte": iteration}}})
events = self._get_metric_events_for_condition(
company_id=company_id,
task=state.task,
order="desc",
must_conditions=must_conditions,
)
if not events:
return res
self._fill_res_and_update_state(events=events, res=res, state=state)
return res
def _reset_metric_states(self, company_id: str, state: PlotsSampleState):
metrics = self._get_metric_iterations(
company_id=company_id,
task=state.task,
metric=state.metric if state.navigate_current_metric else None,
)
state.metric_states = [
MetricState(name=metric, min_iteration=min_iter, max_iteration=max_iter)
for metric, (min_iter, max_iter) in metrics.items()
]
def _get_metric_iterations(
self, company_id: str, task: str, metric: str,
) -> Mapping[str, Tuple[int, int]]:
"""
Return valid min and max iterations that the task reported events of the required type
"""
must = [
{"term": {"task": task}},
]
if metric is not None:
must.append({"term": {"metric": metric}})
query = {"bool": {"must": must}}
es_req: dict = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": 5000,
"order": {"_key": "asc"},
},
"aggs": {
"last_iter": {"max": {"field": "iter"}},
"first_iter": {"min": {"field": "iter"}},
},
}
},
}
es_res = search_company_events(
body=es_req,
es=self.es,
company_id=company_id,
event_type=self.event_type,
)
return {
metric_bucket["key"]: (
int(metric_bucket["first_iter"]["value"]),
int(metric_bucket["last_iter"]["value"]),
)
for metric_bucket in nested_get(
es_res, ("aggregations", "metrics", "buckets")
)
}
@staticmethod
def _fill_res_and_update_state(
events: Sequence[dict], res: MetricSamplesResult, state: PlotsSampleState
):
for event in events:
uncompress_plot(event)
state.metric = events[0]["metric"]
state.iteration = events[0]["iter"]
res.events = events
metric_state = first(
ms for ms in state.metric_states if ms.name == state.metric
)
if metric_state:
res.min_iteration = metric_state.min_iteration
res.max_iteration = metric_state.max_iteration
def _get_metric_events_for_condition(
self, company_id: str, task: str, order: str, must_conditions: Sequence
) -> Sequence:
es_req = {
"size": 0,
"query": {"bool": {"must": must_conditions}},
"aggs": {
"iters": {
"terms": {"field": "iter", "size": 1, "order": {"_key": order}},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": 1,
"order": {"_key": order},
},
"aggs": {
"events": {
"top_hits": {
"sort": {"variant": {"order": "asc"}},
"size": 100,
}
}
},
},
},
}
},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
aggs_result = es_res.get("aggregations")
if not aggs_result:
return []
for level in ("iters", "metrics"):
level_data = aggs_result[level]["buckets"]
if not level_data:
return []
aggs_result = level_data[0]
return [
hit["_source"]
for hit in nested_get(aggs_result, ("events", "hits", "hits"))
]

View File

@@ -1,127 +0,0 @@
from typing import Optional, Tuple, Sequence
import attr
from elasticsearch import Elasticsearch
from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
)
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
next_scroll_id: str = None
events: list = attr.Factory(list)
class LogEventsIterator:
EVENT_TYPE = EventType.task_log
def __init__(self, es: Elasticsearch):
self.es = es
def get_task_events(
self,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
from_timestamp: Optional[int] = None,
) -> TaskEventsResult:
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
return TaskEventsResult()
res = TaskEventsResult()
res.events, res.total_events = self._get_events(
company_id=company_id,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
from_timestamp=from_timestamp,
)
return res
def _get_events(
self,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool,
from_timestamp: Optional[int],
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous timestamp either in the
direction of earlier events (navigate_earlier=True) or in the direction of later events.
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
For the last timestamp all the events are brought (even if the resulting size
exceeds batch_size) so that this timestamp events will not be lost between the calls.
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
"""
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": {"term": {"task": task_id}},
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
}
if from_timestamp:
es_req["search_after"] = [from_timestamp]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"timestamp": events[-1]["timestamp"]}},
]
}
},
}
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:
# if only one element is returned for the last timestamp
# then it is already present in the events
return events, hits_total
already_present_ids = set(hit["_id"] for hit in hits)
last_second_events = [
hit["_source"]
for hit in last_second_hits
if hit["_id"] not in already_present_ids
]
# return the list merged from original query results +
# leftovers from the last timestamp
return (
[*events, *last_second_events],
hits_total,
)

View File

@@ -0,0 +1,53 @@
from typing import Sequence, Tuple, Callable
from elasticsearch import Elasticsearch
from redis.client import StrictRedis
from apiserver.utilities.dicts import nested_get
from .event_common import EventType
from .metric_events_iterator import MetricEventsIterator, VariantState
class MetricDebugImagesIterator(MetricEventsIterator):
def __init__(self, redis: StrictRedis, es: Elasticsearch):
super().__init__(redis, es, EventType.metrics_image)
def _get_extra_conditions(self) -> Sequence[dict]:
return [{"exists": {"field": "url"}}]
def _get_variant_state_aggs(self) -> Tuple[dict, Callable[[dict, VariantState], None]]:
aggs = {
"urls": {
"terms": {
"field": "url",
"order": {"max_iter": "desc"},
"size": 1, # we need only one url from the most recent iteration
},
"aggs": {
"max_iter": {"max": {"field": "iter"}},
"iters": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": 2, # need two last iterations so that we can take
# the second one as invalid
"_source": "iter",
}
},
},
}
}
def fill_variant_state_data(variant_bucket: dict, state: VariantState):
"""If the image urls get recycled then fill the last_invalid_iteration field"""
top_iter_url = nested_get(variant_bucket, ("urls", "buckets"))[0]
iters = nested_get(top_iter_url, ("iters", "hits", "hits"))
if len(iters) > 1:
state.last_invalid_iteration = nested_get(iters[1], ("_source", "iter"))
return aggs, fill_variant_state_data
def _process_event(self, event: dict) -> dict:
return event
def _get_same_variant_events_order(self) -> dict:
return {"url": {"order": "desc"}}

View File

@@ -0,0 +1,452 @@
import abc
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping, Callable
import attr
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
EventSettings,
check_empty_data,
search_company_events,
EventType,
get_metric_variants_condition,
get_max_metric_and_variant_counts,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task
from apiserver.utilities.dicts import nested_get
class VariantState(Base):
variant: str = StringField(required=True)
last_invalid_iteration: int = IntField()
class MetricState(Base):
metric: str = StringField(required=True)
variants: Sequence[VariantState] = ListField([VariantState], required=True)
timestamp: int = IntField(default=0)
class TaskScrollState(Base):
task: str = StringField(required=True)
metrics: Sequence[MetricState] = ListField([MetricState], required=True)
last_min_iter: Optional[int] = IntField()
last_max_iter: Optional[int] = IntField()
def reset(self):
"""Reset the scrolling state for the metric"""
self.last_min_iter = self.last_max_iter = None
class MetricEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
warning: str = StringField()
@attr.s(auto_attribs=True)
class MetricEventsResult(object):
metric_events: Sequence[tuple] = []
next_scroll_id: str = None
class MetricEventsIterator:
def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
self.es = es
self.event_type = event_type
self.cache_manager = RedisCacheManager(
state_class=MetricEventsScrollState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def get_task_events(
self,
companies: Mapping[str, str],
task_metrics: Mapping[str, dict],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
) -> MetricEventsResult:
companies = {
task_id: company_id
for task_id, company_id in companies.items()
if not check_empty_data(
self.es, company_id=company_id, event_type=self.event_type
)
}
if not companies:
return MetricEventsResult()
def init_state(state_: MetricEventsScrollState):
state_.tasks = self._init_task_states(companies, task_metrics)
def validate_state(state_: MetricEventsScrollState):
"""
Validate that the metrics stored in the state are the same
as requested in the current call.
Refresh the state if requested
"""
if refresh:
self._reinit_outdated_task_states(companies, state_, task_metrics)
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
) as state:
res = MetricEventsResult(next_scroll_id=state.id)
specific_variants_requested = any(
variants
for t, metrics in task_metrics.items()
if metrics
for m, variants in metrics.items()
)
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
res.metric_events = list(
pool.map(
partial(
self._get_task_metric_events,
companies=companies,
iter_count=iter_count,
navigate_earlier=navigate_earlier,
specific_variants_requested=specific_variants_requested,
),
state.tasks,
)
)
return res
def _reinit_outdated_task_states(
self,
companies: Mapping[str, str],
state: MetricEventsScrollState,
task_metrics: Mapping[str, dict],
):
"""
Determine the metrics for which new event_type events were added
since their states were initialized and re-init these states
"""
tasks = Task.objects(id__in=list(task_metrics)).only("id", "metric_stats")
def get_last_update_times_for_task_metrics(
task: Task,
) -> Mapping[str, datetime]:
"""For metrics that reported event_type events get mapping of the metric name to the last update times"""
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
if not metric_stats:
return {}
requested_metrics = task_metrics[task.id]
return {
stats.metric: stats.event_stats_by_type[
self.event_type.value
].last_update
for stats in metric_stats.values()
if self.event_type.value in stats.event_stats_by_type
and (not requested_metrics or stats.metric in requested_metrics)
}
update_times = {
task.id: get_last_update_times_for_task_metrics(task) for task in tasks
}
task_metric_states = {
task_state.task: {
metric_state.metric: metric_state for metric_state in task_state.metrics
}
for task_state in state.tasks
}
task_metrics_to_recalc = {}
for task, metrics_times in update_times.items():
old_metric_states = task_metric_states[task]
metrics_to_recalc = {
m: task_metrics[task].get(m)
for m, t in metrics_times.items()
if m not in old_metric_states or old_metric_states[m].timestamp < t
}
if metrics_to_recalc:
task_metrics_to_recalc[task] = metrics_to_recalc
updated_task_states = self._init_task_states(companies, task_metrics_to_recalc)
def merge_with_updated_task_states(
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
) -> TaskScrollState:
task = old_state.task
updated_state = first(uts for uts in updates if uts.task == task)
if not updated_state:
old_state.reset()
return old_state
updated_metrics = [m.metric for m in updated_state.metrics]
return TaskScrollState(
task=task,
metrics=[
*updated_state.metrics,
*(
old_metric
for old_metric in old_state.metrics
if old_metric.metric not in updated_metrics
),
],
)
state.tasks = [
merge_with_updated_task_states(task_state, updated_task_states)
for task_state in state.tasks
]
def _init_task_states(
self, companies: Mapping[str, str], task_metrics: Mapping[str, dict]
) -> Sequence[TaskScrollState]:
"""
Returned initialized metric scroll stated for the requested task metrics
"""
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
task_metric_states = pool.map(
partial(self._init_metric_states_for_task, companies=companies),
task_metrics.items(),
)
return [
TaskScrollState(task=task, metrics=metric_states,)
for task, metric_states in zip(task_metrics, task_metric_states)
]
@abc.abstractmethod
def _get_extra_conditions(self) -> Sequence[dict]:
pass
@abc.abstractmethod
def _get_variant_state_aggs(
self,
) -> Tuple[dict, Callable[[dict, VariantState], None]]:
pass
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, dict], companies: Mapping[str, str]
) -> Sequence[MetricState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any event_type events
"""
task, metrics = task_metrics
company_id = companies[task]
must = [{"term": {"task": task}}, *self._get_extra_conditions()]
if metrics:
must.append(get_metric_variants_condition(metrics))
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=self.event_type
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args
)
max_variants = int(max_variants // 2)
variant_state_aggs, fill_variant_state_data = self._get_variant_state_aggs()
es_req: dict = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"last_event_timestamp": {"max": {"field": "timestamp"}},
"variants": {
"terms": {
"field": "variant",
"size": max_variants,
"order": {"_key": "asc"},
},
**(
{"aggs": variant_state_aggs}
if variant_state_aggs
else {}
),
},
},
}
},
}
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
if "aggregations" not in es_res:
return []
def init_variant_state(variant: dict):
"""
Return new variant state for the passed variant bucket
"""
state = VariantState(variant=variant["key"])
if fill_variant_state_data:
fill_variant_state_data(variant, state)
return state
return [
MetricState(
metric=metric["key"],
timestamp=nested_get(metric, ("last_event_timestamp", "value")),
variants=[
init_variant_state(variant)
for variant in nested_get(metric, ("variants", "buckets"))
],
)
for metric in nested_get(es_res, ("aggregations", "metrics", "buckets"))
]
@abc.abstractmethod
def _process_event(self, event: dict) -> dict:
pass
@abc.abstractmethod
def _get_same_variant_events_order(self) -> dict:
pass
def _get_task_metric_events(
self,
task_state: TaskScrollState,
companies: Mapping[str, str],
iter_count: int,
navigate_earlier: bool,
specific_variants_requested: bool,
) -> Tuple:
"""
Return task metric events grouped by iterations
Update task scroll state
"""
if not task_state.metrics:
return task_state.task, []
if task_state.last_max_iter is None:
# the first fetch is always from the latest iteration to the earlier ones
navigate_earlier = True
must_conditions = [
{"term": {"task": task_state.task}},
{"terms": {"metric": [m.metric for m in task_state.metrics]}},
*self._get_extra_conditions(),
]
range_condition = None
if navigate_earlier and task_state.last_min_iter is not None:
range_condition = {"lt": task_state.last_min_iter}
elif not navigate_earlier and task_state.last_max_iter is not None:
range_condition = {"gt": task_state.last_max_iter}
if range_condition:
must_conditions.append({"range": {"iter": range_condition}})
metrics_count = len(task_state.metrics)
max_variants = int(EventSettings.max_es_buckets / (metrics_count * iter_count))
es_req = {
"size": 0,
"query": {"bool": {"must": must_conditions}},
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iter_count,
"order": {"_key": "desc" if navigate_earlier else "asc"},
},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": metrics_count,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
"events": {
"top_hits": {
"sort": self._get_same_variant_events_order(),
"size": 1,
}
}
},
}
},
}
},
}
},
}
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=companies[task_state.task],
event_type=self.event_type,
body=es_req,
)
if "aggregations" not in es_res:
return task_state.task, []
invalid_iterations = {
(m.metric, v.variant): v.last_invalid_iteration
for m in task_state.metrics
for v in m.variants
}
allow_uninitialized = (
False
if specific_variants_requested
else config.get(
"services.events.events_retrieval.debug_images.allow_uninitialized_variants",
False,
)
)
def is_valid_event(event: dict) -> bool:
key = event.get("metric"), event.get("variant")
if key not in invalid_iterations:
return allow_uninitialized
max_invalid = invalid_iterations[key]
return max_invalid is None or event.get("iter") > max_invalid
def get_iteration_events(it_: dict) -> Sequence:
return [
self._process_event(ev["_source"])
for m in nested_get(it_, ("metrics", "buckets"))
for v in nested_get(m, ("variants", "buckets"))
for ev in nested_get(v, ("events", "hits", "hits"))
if is_valid_event(ev["_source"])
]
iterations = []
for it in nested_get(es_res, ("aggregations", "iters", "buckets")):
events = get_iteration_events(it)
if events:
iterations.append({"iter": it["key"], "events": events})
if not navigate_earlier:
iterations.sort(key=itemgetter("iter"), reverse=True)
if iterations:
task_state.last_max_iter = iterations[0]["iter"]
task_state.last_min_iter = iterations[-1]["iter"]
return task_state.task, iterations

View File

@@ -0,0 +1,25 @@
from typing import Sequence
from elasticsearch import Elasticsearch
from redis.client import StrictRedis
from .event_common import EventType, uncompress_plot
from .metric_events_iterator import MetricEventsIterator
class MetricPlotsIterator(MetricEventsIterator):
def __init__(self, redis: StrictRedis, es: Elasticsearch):
super().__init__(redis, es, EventType.metrics_plot)
def _get_extra_conditions(self) -> Sequence[dict]:
return []
def _get_variant_state_aggs(self):
return None, None
def _process_event(self, event: dict) -> dict:
uncompress_plot(event)
return event
def _get_same_variant_events_order(self) -> dict:
return {"timestamp": {"order": "desc"}}

View File

@@ -4,8 +4,10 @@ Module for polymorphism over different types of X axes in scalar aggregations
from abc import ABC, abstractmethod
from enum import auto
from typing import Any
from apiserver.utilities import extract_properties_to_lists
from apiserver.utilities.stringenum import StringEnum
from apiserver.bll.util import extract_properties_to_lists
from apiserver.config_repo import config
log = config.logger(__file__)
@@ -96,6 +98,10 @@ class ScalarKey(ABC):
"""
return int(iter_data[self.bucket_key_key]), iter_data["avg_val"]["value"]
def cast_value(self, value: Any) -> Any:
"""Cast value to appropriate type"""
return value
class TimestampKey(ScalarKey):
"""
@@ -117,6 +123,9 @@ class TimestampKey(ScalarKey):
}
}
def cast_value(self, value: Any) -> int:
return int(value)
class IterKey(ScalarKey):
"""
@@ -134,6 +143,9 @@ class IterKey(ScalarKey):
}
}
def cast_value(self, value: Any) -> int:
return int(value)
class ISOTimeKey(ScalarKey):
"""

View File

@@ -1,18 +1,237 @@
from typing import Optional, Sequence
from datetime import datetime
from typing import Callable, Tuple, Sequence, Dict, Optional
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse
from apiserver.bll.task.utils import deleted_prefix, get_last_metric_updates
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.utils import get_company_or_none_constraint
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.service_repo.auth import Identity
from .metadata import Metadata
class ModelBLL:
def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
"""
Return the list of unique frameworks used by company and public models
If project ids passed then only models from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
query &= Q(project__in=project_ids)
return Model.objects(query).distinct(field="framework")
@classmethod
def get_company_model_by_id(
cls, company_id: str, model_id: str, only_fields=None
) -> Model:
query = dict(company=company_id, id=model_id)
qs = Model.objects(**query)
if only_fields:
qs = qs.only(*only_fields)
model = qs.first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
return model
@staticmethod
def assert_exists(
company_id, model_ids, only=None, allow_public=False, return_models=True,
) -> Optional[Sequence[Model]]:
model_ids = [model_ids] if isinstance(model_ids, str) else model_ids
ids = set(model_ids)
query = Q(id__in=ids)
q = Model.get_many(
company=company_id,
query=query,
allow_public=allow_public,
return_dicts=False,
)
if only:
q = q.only(*only)
if q.count() != len(ids):
raise errors.bad_request.InvalidModelId(ids=model_ids)
if return_models:
return list(q)
@classmethod
def publish_model(
cls,
model_id: str,
company_id: str,
identity: Identity,
force_publish_task: bool = False,
publish_task_func: Callable[[str, str, Identity, bool], dict] = None,
) -> Tuple[int, ModelTaskPublishResponse]:
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
if model.ready:
raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
user_id = identity.user
published_task = None
if model.task and publish_task_func:
task = (
Task.objects(id=model.task, company=company_id)
.only("id", "status")
.first()
)
if task and task.status != TaskStatus.published:
task_publish_res = publish_task_func(
model.task, company_id, identity, force_publish_task
)
published_task = ModelTaskPublishResponse(
id=model.task, data=task_publish_res
)
now = datetime.utcnow()
updated = model.update(
upsert=False,
ready=True,
last_update=now,
last_change=now,
last_changed_by=user_id,
)
return updated, published_task
@classmethod
def delete_model(
cls, model_id: str, company_id: str, user_id: str, force: bool
) -> Tuple[int, Model]:
model = cls.get_company_model_by_id(
company_id=company_id,
model_id=model_id,
only_fields=("id", "task", "project", "uri"),
)
deleted_model_id = f"{deleted_prefix}{model_id}"
using_tasks = Task.objects(models__input__model=model_id).only("id")
if using_tasks:
if not force:
raise errors.bad_request.ModelInUse(
"as execution model, use force=True to delete",
num_tasks=len(using_tasks),
)
# update deleted model id in using tasks
Task._get_collection().update_many(
filter={"_id": {"$in": [t.id for t in using_tasks]}},
update={"$set": {"models.input.$[elem].model": deleted_model_id}},
array_filters=[{"elem.model": model_id}],
upsert=False,
)
if model.task:
task = Task.objects(id=model.task).first()
if task:
now = datetime.utcnow()
if task.status == TaskStatus.published:
if not force:
raise errors.bad_request.ModelCreatingTaskExists(
"and published, use force=True to delete", task=model.task
)
Task._get_collection().update_one(
filter={"_id": model.task, "models.output.model": model_id},
update={
"$set": {
"models.output.$[elem].model": deleted_model_id,
"output.error": f"model deleted on {now.isoformat()}",
"last_change": now,
"last_changed_by": user_id,
},
},
array_filters=[{"elem.model": model_id}],
upsert=False,
)
else:
task.update(
pull__models__output__model=model_id,
set__last_change=now,
set__last_changed_by=user_id,
)
del_count = Model.objects(id=model_id, company=company_id).delete()
return del_count, model
@classmethod
def archive_model(cls, model_id: str, company_id: str, user_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
now = datetime.utcnow()
archived = Model.objects(company=company_id, id=model_id).update(
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=now,
last_changed_by=user_id,
)
return archived
@classmethod
def unarchive_model(cls, model_id: str, company_id: str, user_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
now = datetime.utcnow()
unarchived = Model.objects(company=company_id, id=model_id).update(
pull__system_tags=EntityVisibility.archived.value,
last_change=now,
last_changed_by=user_id,
)
return unarchived
@classmethod
def get_model_stats(
cls, company: str, model_ids: Sequence[str],
) -> Dict[str, dict]:
if not model_ids:
return {}
result = Model.aggregate(
[
{
"$match": {
"company": {"$in": ["", company]},
"_id": {"$in": model_ids},
}
},
{
"$addFields": {
"labels_count": {"$size": {"$objectToArray": "$labels"}}
}
},
{"$project": {"labels_count": 1}},
]
)
return {r.pop("_id"): r for r in result}
@staticmethod
def update_statistics(
company_id: str,
user_id: str,
model_id: str,
last_update: datetime = None,
last_iteration_max: int = None,
last_scalar_events: Dict[str, Dict[str, dict]] = None,
):
last_update = last_update or datetime.utcnow()
updates = {
"last_update": datetime.utcnow(),
"last_change": last_update,
"last_changed_by": user_id,
}
if last_iteration_max is not None:
updates.update(max__last_iteration=last_iteration_max)
raw_updates = {}
if last_scalar_events is not None:
raw_updates = {}
if last_scalar_events is not None:
get_last_metric_updates(
task_id=model_id,
last_scalar_events=last_scalar_events,
raw_updates=raw_updates,
extra_updates=updates,
model_events=True,
)
ret = Model.objects(id=model_id).update_one(**updates)
if ret and raw_updates:
Model.objects(id=model_id).update_one(__raw__=[{"$set": raw_updates}])
return ret

View File

@@ -0,0 +1,107 @@
from typing import Sequence, Union, Mapping
from mongoengine import Document
from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem
from apiserver.database.model.base import GetMixin
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
)
from apiserver.config_repo import config
log = config.logger(__file__)
class Metadata:
@staticmethod
def metadata_from_api(
api_data: Union[Mapping[str, MetadataItem], Sequence[MetadataItem]]
) -> dict:
if not api_data:
return {}
if isinstance(api_data, dict):
return {
ParameterKeyEscaper.escape(k): v.to_struct()
for k, v in api_data.items()
}
return {
ParameterKeyEscaper.escape(item.key): item.to_struct() for item in api_data
}
@classmethod
def edit_metadata(
cls,
obj: Document,
items: Sequence[MetadataItem],
replace_metadata: bool,
**more_updates,
) -> int:
update_cmds = dict()
metadata = cls.metadata_from_api(items)
if replace_metadata:
update_cmds["set__metadata"] = metadata
else:
for key, value in metadata.items():
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
return obj.update(**update_cmds, **more_updates)
@classmethod
def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
return obj.update(
**{
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
for key in set(keys)
},
**more_updates,
)
@staticmethod
def _process_path(path: str):
"""
Frontend does a partial escaping on the path so the all '.' in key names are escaped
Need to unescape and apply a full mongo escaping
"""
parts = path.split(".")
if len(parts) < 2 or len(parts) > 3:
raise errors.bad_request.ValidationError("invalid field", path=path)
return ".".join(
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
)
@classmethod
def escape_paths(cls, paths: Sequence[str]) -> Sequence[str]:
for prefix in (
"metadata.",
"-metadata.",
):
paths = [
cls._process_path(path) if path.startswith(prefix) else path
for path in paths
]
return paths
@classmethod
def escape_query_parameters(cls, call_data: dict) -> dict:
if not call_data:
return call_data
keys = list(call_data)
call_data = {
safe_key: call_data[key]
for key, safe_key in zip(keys, Metadata.escape_paths(keys))
}
projection = GetMixin.get_projection(call_data)
if projection:
GetMixin.set_projection(call_data, Metadata.escape_paths(projection))
ordering = GetMixin.get_ordering(call_data)
if ordering:
GetMixin.set_ordering(call_data, Metadata.escape_paths(ordering))
return call_data

View File

@@ -1,12 +1,12 @@
from collections import defaultdict
from datetime import datetime
from enum import Enum
from operator import itemgetter
from typing import Sequence, Dict, Optional
from mongoengine import Q
from typing import Sequence, Dict, Type
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import AttributedDocument
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
@@ -26,6 +26,56 @@ class OrgBLL:
self._task_tags = _TagsCache(Task, self.redis)
self._model_tags = _TagsCache(Model, self.redis)
def edit_entity_tags(
self,
company_id,
user_id: str,
entity_cls: Type[AttributedDocument],
entity_ids: Sequence[str],
add_tags: Sequence[str],
remove_tags: Sequence[str],
) -> int:
if entity_cls not in (Task, Model):
raise errors.bad_request.ValidationError(
"Tags editing can be called on tasks or models only"
)
if not entity_ids:
raise errors.bad_request.ValidationError(
"No entity ids provided for editing tags"
)
if not (add_tags or remove_tags):
raise errors.bad_request.ValidationError(
"Either add tags or remove tags should be provided"
)
updated = 0
last_changed = {
"set__last_change": datetime.utcnow(),
"set__last_changed_by": user_id,
}
if add_tags:
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
add_to_set__tags=add_tags, **last_changed,
)
if remove_tags:
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
pull_all__tags=remove_tags, **last_changed,
)
if not updated:
return 0
projects = entity_cls.objects(company=company_id, id__in=entity_ids).distinct(
"project"
)
update_project_time(project_ids=projects)
self.update_tags(
company_id,
entity=Tags.Task if entity_cls is Task else Tags.Model,
projects=projects,
tags=add_tags or remove_tags
)
return updated
def get_tags(
self,
company_id: str,
@@ -54,10 +104,10 @@ class OrgBLL:
return ret
def update_tags(
self, company_id: str, entity: Tags, project: str, tags=None, system_tags=None,
self, company_id: str, entity: Tags, projects: Sequence[str], tags=None, system_tags=None,
):
tags_cache = self._get_tags_cache_for_entity(entity)
tags_cache.update_tags(company_id, project, tags, system_tags)
tags_cache.update_tags(company_id, projects, tags, system_tags)
def reset_tags(self, company_id: str, entity: Tags, projects: Sequence[str]):
tags_cache = self._get_tags_cache_for_entity(entity)
@@ -65,34 +115,3 @@ class OrgBLL:
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
return self._task_tags if entity == Tags.Task else self._model_tags
@classmethod
def get_parent_tasks(
cls,
company_id: str,
projects: Sequence[str],
state: Optional[EntityVisibility] = None,
) -> Sequence[dict]:
"""
Get list of unique parent tasks sorted by task name for the passed company projects
If projects is None or empty then get parents for all the company tasks
"""
query = Q(company=company_id)
if projects:
query &= Q(project__in=projects)
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
elif state == EntityVisibility.active:
query &= Q(system_tags__nin=[EntityVisibility.archived.value])
parent_ids = set(Task.objects(query).distinct("parent"))
if not parent_ids:
return []
parents = Task.get_many_with_join(
company_id,
query=Q(id__in=parent_ids),
allow_public=True,
override_projection=("id", "name", "project.name"),
)
return sorted(parents, key=itemgetter("name"))

View File

@@ -5,6 +5,7 @@ from mongoengine import Q
from redis import Redis
from apiserver.config_repo import config
from apiserver.bll.project import project_ids_with_children
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
@@ -40,7 +41,9 @@ class _TagsCache:
if vals:
query &= GetMixin.get_list_field_query(name, vals)
if project:
query &= Q(project=project)
query &= Q(project__in=project_ids_with_children([project]))
# else:
# query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
return self.db_cls.objects(query).distinct(field)
@@ -103,7 +106,7 @@ class _TagsCache:
return ret
def update_tags(self, company_id: str, project: str, tags=None, system_tags=None):
def update_tags(self, company_id: str, projects: Sequence[str], tags=None, system_tags=None):
"""
Updates tags. If reset is set then both tags and system_tags
are recalculated. Otherwise only those that are not 'None'
@@ -119,7 +122,7 @@ class _TagsCache:
if not fields:
return
self._delete_redis_keys(company_id, projects=[project], fields=fields)
self._delete_redis_keys(company_id, projects=projects, fields=fields)
def reset_tags(self, company_id: str, projects: Sequence[str]):
self._delete_redis_keys(

View File

@@ -1 +1,3 @@
from .project_bll import ProjectBLL
from .project_queries import ProjectQueries
from .sub_projects import _ids_with_children as project_ids_with_children

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,324 @@
from collections import defaultdict
from datetime import datetime
from typing import Tuple, Set, Sequence
import attr
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.bll.event import EventBLL
from apiserver.bll.task.task_cleanup import (
TaskUrls,
schedule_for_delete,
delete_task_events_and_collect_urls,
)
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType, TaskStatus
from .project_bll import (
ProjectBLL,
pipeline_tag,
pipelines_project_name,
dataset_tag,
datasets_project_name,
reports_tag,
)
from .sub_projects import _ids_with_children
log = config.logger(__file__)
event_bll = EventBLL()
@attr.s(auto_attribs=True)
class DeleteProjectResult:
deleted: int = 0
disassociated_tasks: int = 0
deleted_models: int = 0
deleted_tasks: int = 0
urls: TaskUrls = None
def _get_child_project_ids(
project_id: str,
) -> Tuple[Sequence[str], Sequence[str], Sequence[str]]:
project_ids = _ids_with_children([project_id])
pipeline_ids = list(
Project.objects(
id__in=project_ids,
system_tags__in=[pipeline_tag],
basename__ne=pipelines_project_name,
).scalar("id")
)
dataset_ids = list(
Project.objects(
id__in=project_ids,
system_tags__in=[dataset_tag],
basename__ne=datasets_project_name,
).scalar("id")
)
return project_ids, pipeline_ids, dataset_ids
def validate_project_delete(company: str, project_id: str):
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path", "system_tags")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
project_ids, pipeline_ids, dataset_ids = _get_child_project_ids(project_id)
ret = {}
if pipeline_ids:
pipelines_with_active_controllers = Task.objects(
project__in=pipeline_ids,
type=TaskType.controller,
system_tags__nin=[EntityVisibility.archived.value],
).distinct("project")
ret["pipelines"] = len(pipelines_with_active_controllers)
else:
ret["pipelines"] = 0
if dataset_ids:
datasets_with_data = Task.objects(
project__in=dataset_ids,
system_tags__nin=[EntityVisibility.archived.value],
).distinct("project")
ret["datasets"] = len(datasets_with_data)
else:
ret["datasets"] = 0
project_ids = list(set(project_ids) - set(pipeline_ids) - set(dataset_ids))
if project_ids:
in_project_query = Q(project__in=project_ids)
for cls in (Task, Model):
query = (
in_project_query & Q(system_tags__nin=[reports_tag])
if cls is Task
else in_project_query
)
ret[f"{cls.__name__.lower()}s"] = cls.objects(query).count()
ret[f"non_archived_{cls.__name__.lower()}s"] = cls.objects(
query & Q(system_tags__nin=[EntityVisibility.archived.value])
).count()
ret["reports"] = Task.objects(
in_project_query & Q(system_tags__in=[reports_tag])
).count()
ret["non_archived_reports"] = Task.objects(
in_project_query
& Q(
system_tags__in=[reports_tag],
system_tags__nin=[EntityVisibility.archived.value],
)
).count()
else:
for cls in (Task, Model):
ret[f"{cls.__name__.lower()}s"] = 0
ret[f"non_archived_{cls.__name__.lower()}s"] = 0
ret["reports"] = 0
ret["non_archived_reports"] = 0
return ret
def delete_project(
company: str,
user: str,
project_id: str,
force: bool,
delete_contents: bool,
delete_external_artifacts: bool,
) -> Tuple[DeleteProjectResult, Set[str]]:
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path", "system_tags")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
delete_external_artifacts = delete_external_artifacts and config.get(
"services.async_urls_delete.enabled", True
)
project_ids, pipeline_ids, dataset_ids = _get_child_project_ids(project_id)
if not force:
if pipeline_ids:
active_controllers = Task.objects(
project__in=pipeline_ids,
type=TaskType.controller,
system_tags__nin=[EntityVisibility.archived.value],
).only("id")
if active_controllers:
raise errors.bad_request.ProjectHasPipelines(
"please archive all the controllers or use force=true",
id=project_id,
)
if dataset_ids:
datasets_with_data = Task.objects(
project__in=dataset_ids,
system_tags__nin=[EntityVisibility.archived.value],
).only("id")
if datasets_with_data:
raise errors.bad_request.ProjectHasDatasets(
"please delete all the dataset versions or use force=true",
id=project_id,
)
regular_projects = list(set(project_ids) - set(pipeline_ids) - set(dataset_ids))
if regular_projects:
for cls, error in (
(Task, errors.bad_request.ProjectHasTasks),
(Model, errors.bad_request.ProjectHasModels),
):
non_archived = cls.objects(
project__in=regular_projects,
system_tags__nin=[EntityVisibility.archived.value],
).only("id")
if non_archived:
raise error("use force=true", id=project_id)
if not delete_contents:
disassociated = defaultdict(int)
for cls in ProjectBLL.child_classes:
disassociated[cls] = cls.objects(project__in=project_ids).update(
project=None
)
res = DeleteProjectResult(disassociated_tasks=disassociated[Task])
else:
deleted_models, model_event_urls, model_urls = _delete_models(
company=company, user=user, projects=project_ids
)
deleted_tasks, task_event_urls, artifact_urls = _delete_tasks(
company=company, user=user, projects=project_ids
)
event_urls = task_event_urls | model_event_urls
if delete_external_artifacts:
scheduled = schedule_for_delete(
task_id=project_id,
company=company,
user=user,
urls=event_urls | model_urls | artifact_urls,
can_delete_folders=True,
)
for urls in (event_urls, model_urls, artifact_urls):
urls.difference_update(scheduled)
res = DeleteProjectResult(
deleted_tasks=deleted_tasks,
deleted_models=deleted_models,
urls=TaskUrls(
model_urls=list(model_urls),
artifact_urls=list(artifact_urls),
),
)
affected = {*project_ids, *(project.path or [])}
res.deleted = Project.objects(id__in=project_ids).delete()
return res, affected
def _delete_tasks(
company: str, user: str, projects: Sequence[str]
) -> Tuple[int, Set, Set]:
"""
Delete only the task themselves and their non published version.
Child models under the same project are deleted separately.
Children tasks should be deleted in the same api call.
If any child entities are left in another projects then updated their parent task to None
"""
tasks = Task.objects(project__in=projects).only("id", "execution__artifacts")
if not tasks:
return 0, set(), set()
task_ids = list({t.id for t in tasks})
now = datetime.utcnow()
Task.objects(parent__in=task_ids, project__nin=projects).update(
parent=None,
last_change=now,
last_changed_by=user,
)
Model.objects(task__in=task_ids, project__nin=projects).update(
task=None,
last_change=now,
last_changed_by=user,
)
artifact_urls = set()
for task in tasks:
if task.execution and task.execution.artifacts:
artifact_urls.update(
{
a.uri
for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri
}
)
event_urls = delete_task_events_and_collect_urls(
company=company, task_ids=task_ids, wait_for_delete=False
)
deleted = tasks.delete()
return deleted, event_urls, artifact_urls
def _delete_models(
company: str, user: str, projects: Sequence[str]
) -> Tuple[int, Set[str], Set[str]]:
"""
Delete project models and update the tasks from other projects
that reference them to reference None.
"""
models = Model.objects(project__in=projects).only("task", "id", "uri")
if not models:
return 0, set(), set()
model_ids = list({m.id for m in models})
deleted = "__DELETED__"
Task._get_collection().update_many(
filter={
"project": {"$nin": projects},
"models.input.model": {"$in": model_ids},
},
update={"$set": {"models.input.$[elem].model": deleted}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
model_tasks = list({m.task for m in models if m.task})
if model_tasks:
now = datetime.utcnow()
# update published tasks
Task._get_collection().update_many(
filter={
"_id": {"$in": model_tasks},
"project": {"$nin": projects},
"models.output.model": {"$in": model_ids},
"status": TaskStatus.published,
},
update={
"$set": {
"models.output.$[elem].model": deleted,
"last_change": now,
"last_changed_by": user,
}
},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
# update unpublished tasks
Task.objects(
id__in=model_tasks,
project__nin=projects,
status__ne=TaskStatus.published,
).update(
pull__models__output__model__in=model_ids,
set__last_change=now,
set__last_changed_by=user,
)
model_urls = {m.uri for m in models if m.uri}
event_urls = delete_task_events_and_collect_urls(
company=company, task_ids=model_ids, model=True, wait_for_delete=False
)
deleted = models.delete()
return deleted, event_urls, model_urls

View File

@@ -0,0 +1,407 @@
import json
from collections import OrderedDict
from datetime import datetime
from typing import (
Sequence,
Optional,
Tuple,
)
from redis import StrictRedis
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .sub_projects import _ids_with_children
log = config.logger(__file__)
class ProjectQueries:
def __init__(self, redis=None):
self.redis: StrictRedis = redis or redman.connection("apiserver")
@staticmethod
def _get_project_constraint(
project_ids: Sequence[str], include_subprojects: bool
) -> dict:
"""
If passed projects is None means top level projects
If passed projects is empty means no project filtering
"""
if include_subprojects:
if not project_ids:
return {}
project_ids = _ids_with_children(project_ids)
if project_ids is None:
project_ids = [None]
if not project_ids:
return {}
return {"project": {"$in": project_ids}}
@staticmethod
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
if allow_public:
return {"company": {"$in": ["", company_id]}}
return {"company": company_id}
@classmethod
def get_aggregated_project_parameters(
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
"hyperparams": {"$exists": True, "$gt": {}},
}
},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
{
"section": ParameterKeyEscaper.unescape(
nested_get(r, ("_id", "section"))
),
"name": ParameterKeyEscaper.unescape(
nested_get(r, ("_id", "name"))
),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
ParamValues = Tuple[int, Sequence[str]]
def _get_cached_param_values(
self, key: str, last_update: datetime, allowed_delta_sec=0
) -> Optional[ParamValues]:
try:
cached = self.redis.get(key)
if not cached:
return
data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update).total_seconds() <= allowed_delta_sec:
return data["total"], data["values"]
except Exception as ex:
log.error(f"Error retrieving params cached values: {str(ex)}")
def get_task_hyperparam_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
section: str,
name: str,
include_subprojects: bool,
allow_public: bool = True,
pattern: str = None,
page: int = 0,
page_size: int = 500,
) -> ParamValues:
page = max(0, page)
page_size = max(1, page_size)
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
)
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
last_updated_task = (
Task.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_task:
return 0, []
redis_key = "_".join(
str(part)
for part in (
"hyperparam_values",
company_id,
"_".join(project_ids),
section,
name,
allow_public,
pattern,
page,
page_size,
)
)
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_param_values(
key=redis_key,
last_update=last_update,
allowed_delta_sec=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
),
)
if cached_res:
return cached_res
match_condition = {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
if pattern:
match_condition["$expr"] = {
"$regexMatch": {
"input": f"${key_path}.value",
"regex": pattern,
"options": "i",
}
}
pipeline = [
{"$match": match_condition},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values
@classmethod
def get_unique_metric_variants(
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
ids: Sequence[str],
model_metrics: bool = False,
):
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
**({"_id": {"$in": ids}} if ids else {}),
}
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
{"$unwind": "$metrics"},
{
"$project": {
"metric": "$metrics.k",
"variants": {"$objectToArray": "$metrics.v"},
}
},
{"$unwind": "$variants"},
{
"$group": {
"_id": {
"metric": "$variants.v.metric",
"variant": "$variants.v.variant",
},
"metrics": {
"$addToSet": {
"metric": "$variants.v.metric",
"metric_hash": "$metric",
"variant": "$variants.v.variant",
"variant_hash": "$variants.k",
}
},
}
},
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
entity_cls = Model if model_metrics else Task
result = entity_cls.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@classmethod
def get_model_metadata_keys(
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
"metadata": {"$exists": True, "$gt": {}},
}
},
{"$project": {"metadata": {"$objectToArray": "$metadata"}}},
{"$unwind": "$metadata"},
{"$group": {"_id": "$metadata.k"}},
{"$sort": {"_id": 1}},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Model.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
ParameterKeyEscaper.unescape(r.get("_id"))
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
def get_model_metadata_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
key: str,
include_subprojects: bool,
allow_public: bool = True,
page: int = 0,
page_size: int = 500,
) -> ParamValues:
page = max(0, page)
page_size = max(1, page_size)
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
)
key_path = f"metadata.{ParameterKeyEscaper.escape(key)}"
last_updated_model = (
Model.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_model:
return 0, []
redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}_{page}_{page_size}"
last_update = last_updated_model.last_update or datetime.utcnow()
cached_res = self._get_cached_param_values(
key=redis_key, last_update=last_update
)
if cached_res:
return cached_res
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Model.aggregate(pipeline, collation=Model._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.models.metadata_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values

View File

@@ -0,0 +1,211 @@
import itertools
from datetime import datetime
from typing import Tuple, Optional, Sequence, Mapping
from boltons.iterutils import first
from apiserver import database
from apiserver.apierrors import errors
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
name_separator = "/"
def _get_project_depth(project_name: str) -> int:
return len(list(filter(None, project_name.split(name_separator))))
def _validate_project_name(project_name: str, raise_if_empty=True) -> Tuple[str, str]:
"""
Remove redundant '/' characters. Ensure that the project name is not empty
Return the cleaned up project name and location
"""
name_parts = [p.strip() for p in project_name.split(name_separator) if p]
if not name_parts:
if raise_if_empty:
raise errors.bad_request.InvalidProjectName(name=project_name)
return "", ""
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
def _ensure_project(
company: str, user: str, name: str, creation_params: dict = None
) -> Optional[Project]:
"""
Makes sure that the project with the given name exists
If needed auto-create the project and all the missing projects in the path to it
Return the project
"""
name, location = _validate_project_name(name, raise_if_empty=False)
if not name:
return None
project = _get_writable_project_from_name(company, name)
if project:
return project
now = datetime.utcnow()
project = Project(
id=database.utils.id(),
user=user,
company=company,
created=now,
last_update=now,
name=name,
basename=name.split("/")[-1],
**(creation_params or dict(description="")),
)
parent = _ensure_project(company, user, location, creation_params=creation_params)
_save_under_parent(project=project, parent=parent)
if parent:
parent.update(last_update=now)
return project
def _save_under_parent(project: Project, parent: Optional[Project]):
"""
Save the project under the given parent project or top level (parent=None)
Check that the project location matches the parent name
"""
location, _, _ = project.name.rpartition(name_separator)
if not parent:
if location:
raise ValueError(
f"Project location {location} does not match empty parent name"
)
project.parent = None
project.path = []
project.save()
return
if location != parent.name:
raise ValueError(
f"Project location {location} does not match parent name {parent.name}"
)
project.parent = parent.id
project.path = [*(parent.path or []), parent.id]
project.save()
def _get_writable_project_from_name(
company,
name,
_only: Optional[Sequence[str]] = ("id", "name", "path", "company", "parent"),
) -> Optional[Project]:
"""
Return a project from name. If the project not found then return None
"""
qs = Project.objects(company__in=[company, ""], name=name)
if _only:
if "company" not in _only:
_only = ["company", *_only]
qs = qs.only(*_only)
projects = list(qs)
if not projects:
return
project = first(p for p in projects if p.company == company)
if not project:
raise errors.bad_request.PublicProjectExists(name=name)
return project
ProjectsChildren = Mapping[str, Sequence[Project]]
def _get_sub_projects(
project_ids: Sequence[str],
_only: Sequence[str] = ("id", "path"),
search_hidden=True,
allowed_ids: Sequence[str] = None,
) -> ProjectsChildren:
"""
Return the list of child projects of all the levels for the parent project ids
"""
query = dict(path__in=project_ids)
if not search_hidden:
query["system_tags__nin"] = [EntityVisibility.hidden.value]
if allowed_ids:
query["id__in"] = allowed_ids
qs = Project.objects(**query)
if _only:
_only = set(_only) | {"path"}
qs = qs.only(*_only)
subprojects = list(qs)
return {
pid: [s for s in subprojects if pid in (s.path or [])] for pid in project_ids
}
def _ids_with_parents(project_ids: Sequence[str]) -> Sequence[str]:
"""
Return project ids with all the parent projects
"""
projects = Project.objects(id__in=project_ids).only("id", "path")
parent_ids = set(itertools.chain.from_iterable(p.path for p in projects if p.path))
return list({*(p.id for p in projects), *parent_ids})
def _ids_with_children(project_ids: Sequence[str]) -> Sequence[str]:
"""
Return project ids with the ids of all the subprojects
"""
children_ids = Project.objects(path__in=project_ids).scalar("id")
return list({*project_ids, *children_ids})
def _update_subproject_names(
project: Project,
children: Sequence[Project],
old_name: str,
update_path: bool = False,
old_path: Sequence[str] = None,
) -> int:
"""
Update sub project names when the base project name changes
Optionally update the paths
"""
updated = 0
now = datetime.utcnow()
for child in children:
child_suffix = name_separator.join(
child.name.split(name_separator)[len(old_name.split(name_separator)):]
)
updates = {
"name": name_separator.join((project.name, child_suffix)),
"last_update": now,
}
if update_path:
updates["path"] = project.path + child.path[len(old_path):]
updated += child.update(upsert=False, **updates)
return updated
def _reposition_project_with_children(
project: Project, children: Sequence[Project], parent: Project
) -> int:
new_location = parent.name if parent else None
old_name = project.name
old_path = project.path
project.name = name_separator.join(
filter(None, (new_location, project.name.split(name_separator)[-1]))
)
project.last_update = datetime.utcnow()
_save_under_parent(project, parent=parent)
moved = 1 + _update_subproject_names(
project=project,
children=children,
old_name=old_name,
update_path=True,
old_path=old_path,
)
return moved

View File

@@ -9,20 +9,35 @@ RANGE_IGNORE_VALUE = -1
class Builder:
@staticmethod
def dates_range(from_date: Union[int, float], to_date: Union[int, float]) -> dict:
def dates_range(
from_date: Optional[Union[int, float]] = None,
to_date: Optional[Union[int, float]] = None,
) -> dict:
assert (
from_date or to_date
), "range condition requires that at least one of from_date or to_date specified"
conditions = {}
if from_date:
conditions["gte"] = int(from_date)
if to_date:
conditions["lte"] = int(to_date)
return {
"range": {
"timestamp": {
"gte": int(from_date),
"lte": int(to_date),
**conditions,
"format": "epoch_second",
}
}
}
@staticmethod
def terms(field: str, values: Iterable[str]) -> dict:
def terms(field: str, values: Iterable) -> dict:
if isinstance(values, str):
assert not isinstance(values, str), "apparently 'term' should be used here"
return {"terms": {field: list(values)}}
@staticmethod
def term(field: str, value) -> dict:
return {"term": {field: value}}
@staticmethod
def normalize_range(

View File

@@ -1,10 +1,12 @@
from collections import defaultdict
from datetime import datetime
from typing import Callable, Sequence, Optional, Tuple
from typing import Sequence, Optional, Tuple, Union, Iterable
from elasticsearch import Elasticsearch
from mongoengine import Q
from apiserver import database
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.bll.queue.queue_metrics import QueueMetrics
@@ -14,6 +16,8 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.queue import Queue, Entry
log = config.logger(__file__)
MOVE_FIRST = "first"
MOVE_LAST = "last"
class QueueBLL(object):
@@ -32,6 +36,7 @@ class QueueBLL(object):
name: str,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
metadata: Optional[dict] = None,
) -> Queue:
"""Creates a queue"""
with translate_errors_context():
@@ -43,13 +48,31 @@ class QueueBLL(object):
name=name,
tags=tags or [],
system_tags=system_tags or [],
metadata=metadata,
last_update=now,
)
queue.save()
return queue
def get_by_name(
self, company_id: str, queue_name: str, only: Optional[Sequence[str]] = None,
) -> Queue:
qs = Queue.objects(name=queue_name, company=company_id)
if only:
qs = qs.only(*only)
return qs.first()
@staticmethod
def _get_task_entries_projection(max_task_entries: int) -> dict:
return dict(slice__entries=max_task_entries)
def get_by_id(
self, company_id: str, queue_id: str, only: Optional[Sequence[str]] = None
self,
company_id: str,
queue_id: str,
only: Optional[Sequence[str]] = None,
max_task_entries: int = None,
) -> Queue:
"""
Get queue by id
@@ -60,6 +83,8 @@ class QueueBLL(object):
qs = Queue.objects(**query)
if only:
qs = qs.only(*only)
if max_task_entries:
qs = qs.fields(**self._get_task_entries_projection(max_task_entries))
queue = qs.first()
if not queue:
raise errors.bad_request.InvalidQueueId(**query)
@@ -110,28 +135,115 @@ class QueueBLL(object):
self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return Queue.safe_update(company_id, queue_id, update_fields)
def delete(self, company_id: str, queue_id: str, force: bool) -> None:
def _update_task_status_on_removal_from_queue(
self,
company_id: str,
user_id: str,
task_ids: Iterable[str],
queue_id: str,
reason: str
) -> Sequence[str]:
from apiserver.bll.task import ChangeStatusRequest
tasks = []
for task_id in task_ids:
try:
task = Task.get(
company=company_id,
id=task_id,
execution__queue=queue_id,
_only=[
"id",
"company",
"status",
"enqueue_status",
"project",
],
)
if not task:
continue
tasks.append(task.id)
ChangeStatusRequest(
task=task,
new_status=task.enqueue_status or TaskStatus.created,
status_reason=reason,
status_message="",
user_id=user_id,
force=True,
).execute(
enqueue_status=None,
unset__execution__queue=1,
)
except Exception as ex:
log.error(
f"Failed updating task {task_id} status on removal from queue: {queue_id}, {str(ex)}"
)
return tasks
def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> Sequence[str]:
"""
Delete the queue
:raise errors.bad_request.InvalidQueueId: if the queue is not found
:raise errors.bad_request.QueueNotEmpty: if the queue is not empty and 'force' not set
"""
with translate_errors_context():
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
if queue.entries and not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
if not queue.entries:
queue.delete()
return []
def get_all(self, company_id: str, query_dict: dict) -> Sequence[dict]:
if not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
tasks = self._update_task_status_on_removal_from_queue(
company_id=company_id,
user_id=user_id,
task_ids={item.task for item in queue.entries},
queue_id=queue_id,
reason=f"Queue {queue_id} was deleted",
)
queue.delete()
return tasks
def get_all(
self,
company_id: str,
query_dict: dict,
query: Q = None,
max_task_entries: int = None,
ret_params: dict = None,
) -> Sequence[dict]:
"""Get all the queues according to the query"""
with translate_errors_context():
return Queue.get_many(
company=company_id, parameters=query_dict, query_dict=query_dict
company=company_id,
parameters=query_dict,
query_dict=query_dict,
query=query,
projection_fields=self._get_task_entries_projection(max_task_entries)
if max_task_entries
else None,
ret_params=ret_params,
)
def get_queue_infos(self, company_id: str, query_dict: dict) -> Sequence[dict]:
def check_for_workers(self, company_id: str, queue_id: str) -> bool:
for worker in self.worker_bll.get_all(company_id):
if queue_id in worker.queues:
return True
return False
def get_queue_infos(
self,
company_id: str,
query_dict: dict,
query: Q = None,
max_task_entries: int = None,
ret_params: dict = None,
) -> Sequence[dict]:
"""
Get infos on all the company queues, including queue tasks and workers
"""
@@ -140,7 +252,12 @@ class QueueBLL(object):
res = Queue.get_many_with_join(
company=company_id,
query_dict=query_dict,
query=query,
override_projection=projection,
projection_fields=self._get_task_entries_projection(max_task_entries)
if max_task_entries
else None,
ret_params=ret_params,
)
queue_workers = defaultdict(list)
@@ -153,6 +270,7 @@ class QueueBLL(object):
{
"name": w.id,
"ip": w.ip,
"key": w.key,
"task": w.task.to_struct() if w.task else None,
}
for w in queue_workers.get(item["id"], [])
@@ -171,13 +289,15 @@ class QueueBLL(object):
if any(e.task == task_id for e in queue.entries):
raise errors.bad_request.TaskAlreadyQueued(task=task_id)
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
entry = Entry(added=datetime.utcnow(), task=task_id)
query = dict(id=queue_id, company=company_id)
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
push__entries=entry, last_update=datetime.utcnow(), upsert=False
)
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
if not res:
raise errors.bad_request.InvalidQueueOrTaskNotQueued(
task=task_id, **query
@@ -185,16 +305,22 @@ class QueueBLL(object):
return res
def get_next_task(self, company_id: str, queue_id: str) -> Optional[Entry]:
def get_next_task(
self, company_id: str, queue_id: str, task_id: str = None
) -> Optional[Entry]:
"""
Atomically pop and return the first task from the queue (or None)
:raise errors.bad_request.InvalidQueueId: if the queue does not exist
"""
with translate_errors_context():
query = dict(id=queue_id, company=company_id)
queue = Queue.objects(**query).modify(pop__entries=-1, upsert=False)
queue = Queue.objects(
**query, **({"entries__0__task": task_id} if task_id else {})
).modify(pop__entries=-1, upsert=False)
if not queue:
raise errors.bad_request.InvalidQueueId(**query)
if not task_id or not Queue.objects(**query).first():
raise errors.bad_request.InvalidQueueId(**query)
return
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
@@ -208,7 +334,36 @@ class QueueBLL(object):
return queue.entries[0]
def remove_task(self, company_id: str, queue_id: str, task_id: str) -> int:
def clear_queue(
self,
company_id: str,
user_id: str,
queue_id: str,
):
queue = Queue.objects(company=company_id, id=queue_id).first()
if not queue:
raise errors.bad_request.InvalidQueueId(
queue=queue_id
)
if not queue.entries:
return []
tasks = self._update_task_status_on_removal_from_queue(
company_id=company_id,
user_id=user_id,
task_ids={item.task for item in queue.entries},
queue_id=queue_id,
reason=f"Queue {queue_id} was cleared",
)
queue.update(entries=[])
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
return tasks
def remove_task(self, company_id: str, user_id: str, queue_id: str, task_id: str, update_task_status: bool = False) -> int:
"""
Removes the task from the queue and returns the number of removed items
:raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not found in the queue
@@ -217,54 +372,168 @@ class QueueBLL(object):
queue = self.get_queue_with_task(
company_id=company_id, queue_id=queue_id, task_id=task_id
)
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
entries_to_remove = [e for e in queue.entries if e.task == task_id]
query = dict(id=queue_id, company=company_id)
res = Queue.objects(entries__task=task_id, **query).update_one(
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
)
if res and update_task_status:
self._update_task_status_on_removal_from_queue(
company_id=company_id,
user_id=user_id,
task_ids=[task_id],
queue_id=queue_id,
reason=f"Task was removed from the queue {queue_id}",
)
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
return len(entries_to_remove) if res else 0
def reposition_task(
self,
company_id: str,
queue_id: str,
task_id: str,
pos_func: Callable[[int], int],
self, company_id: str, queue_id: str, task_id: str, move_count: Union[int, str],
) -> int:
"""
Moves the task in the queue to the position calculated by pos_func
Returns the updated task position in the queue
"""
with translate_errors_context():
queue = self.get_queue_with_task(
def get_queue_and_task_position():
q = self.get_queue_with_task(
company_id=company_id, queue_id=queue_id, task_id=task_id
)
return q, next(i for i, e in enumerate(q.entries) if e.task == task_id)
position = next(i for i, e in enumerate(queue.entries) if e.task == task_id)
new_position = pos_func(position)
with translate_errors_context():
queue, position = get_queue_and_task_position()
if move_count == MOVE_FIRST:
new_position = 0
elif move_count == MOVE_LAST:
new_position = len(queue.entries) - 1
else:
new_position = position + move_count
if new_position == position:
return new_position
if new_position != position:
entry = queue.entries[position]
query = dict(id=queue_id, company=company_id)
updated = Queue.objects(entries__task=task_id, **query).update_one(
pull__entries=entry, last_update=datetime.utcnow()
)
if not updated:
raise errors.bad_request.RemovedDuringReposition(
task=task_id, **query
)
inst = {"$push": {"entries": {"$each": [entry.to_proper_dict()]}}}
if new_position >= 0:
inst["$push"]["entries"]["$position"] = new_position
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
__raw__=inst
)
if not res:
raise errors.bad_request.FailedAddingDuringReposition(
task=task_id, **query
)
without_entry = {
"$filter": {
"input": "$entries",
"as": "entry",
"cond": {"$ne": ["$$entry.task", task_id]},
}
}
task_entry = {
"$filter": {
"input": "$entries",
"as": "entry",
"cond": {"$eq": ["$$entry.task", task_id]},
}
}
if move_count == MOVE_FIRST:
operations = [
{
"$set": {
"entries": {"$concatArrays": [task_entry, without_entry]}
}
}
]
elif move_count == MOVE_LAST:
operations = [
{
"$set": {
"entries": {"$concatArrays": [without_entry, task_entry]}
}
}
]
else:
operations = [
{
"$set": {
"new_pos": {
"$add": [
{"$indexOfArray": ["$entries.task", task_id]},
move_count,
]
},
"without_entry": without_entry,
"task_entry": task_entry,
}
},
{
"$set": {
"entries": {
"$switch": {
"branches": [
{
"case": {"$lte": ["$new_pos", 0]},
"then": {
"$concatArrays": [
"$task_entry",
"$without_entry",
]
},
},
{
"case": {
"$gte": [
"$new_pos",
{"$size": "$without_entry"},
]
},
"then": {
"$concatArrays": [
"$without_entry",
"$task_entry",
]
},
},
],
"default": {
"$concatArrays": [
{"$slice": ["$without_entry", "$new_pos"]},
"$task_entry",
{
"$slice": [
"$without_entry",
"$new_pos",
{"$size": "$without_entry"},
]
},
]
},
}
}
}
},
{"$unset": ["new_pos", "without_entry", "task_entry"]},
]
return new_position
updated = Queue.objects(
id=queue_id, company=company_id, entries__task=task_id
).update_one(__raw__=operations)
if not updated:
raise errors.bad_request.FailedAddingDuringReposition(task=task_id)
return get_queue_and_task_position()[1]
def count_entries(self, company: str, queue_id: str) -> Optional[int]:
res = next(
Queue.aggregate(
[
{
"$match": {
"company": {"$in": ["", company]},
"_id": queue_id,
}
},
{"$project": {"count": {"$size": "$entries"}}},
]
),
None,
)
if res is None:
raise errors.bad_request.InvalidQueueId(queue_id=queue_id)
return int(res.get("count"))

View File

@@ -1,8 +1,10 @@
import json
from collections import defaultdict
from datetime import datetime
from time import sleep
from typing import Sequence
import elasticsearch.helpers
from boltons.typeutils import classproperty
from elasticsearch import Elasticsearch
from apiserver.es_factory import es_factory
@@ -11,25 +13,30 @@ from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.queue import Queue, Entry
from apiserver.timing_context import TimingContext
from apiserver.redis_manager import redman
from apiserver.utilities.threads_manager import ThreadsManager
log = config.logger(__file__)
_conf = config.get("services.queues")
_queue_metrics_key_pattern = "queue_metrics_{queue}"
redis = redman.connection("apiserver")
class EsKeys:
WAITING_TIME_FIELD = "average_waiting_time"
QUEUE_LENGTH_FIELD = "queue_length"
TIMESTAMP_FIELD = "timestamp"
QUEUE_FIELD = "queue"
class QueueMetrics:
class EsKeys:
WAITING_TIME_FIELD = "average_waiting_time"
QUEUE_LENGTH_FIELD = "queue_length"
TIMESTAMP_FIELD = "timestamp"
QUEUE_FIELD = "queue"
def __init__(self, es: Elasticsearch):
self.es = es
@staticmethod
def _queue_metrics_prefix_for_company(company_id: str) -> str:
"""Returns the es index prefix for the company"""
return f"queue_metrics_{company_id}_"
return f"queue_metrics_{company_id.lower()}_"
@staticmethod
def _get_es_index_suffix():
@@ -49,7 +56,7 @@ class QueueMetrics:
total_waiting_in_secs = sum((now - e.added).total_seconds() for e in entries)
return total_waiting_in_secs / len(entries)
def log_queue_metrics_to_es(self, company_id: str, queues: Sequence[Queue]) -> bool:
def log_queue_metrics_to_es(self, company_id: str, queues: Sequence[Queue]) -> int:
"""
Calculate and write queue statistics (avg waiting time and queue length) to Elastic
:return: True if the write to es was successful, false otherwise
@@ -63,23 +70,22 @@ class QueueMetrics:
def make_doc(queue: Queue) -> dict:
entries = [e for e in queue.entries if e.added]
return dict(
_index=es_index,
_source={
self.EsKeys.TIMESTAMP_FIELD: timestamp,
self.EsKeys.QUEUE_FIELD: queue.id,
self.EsKeys.WAITING_TIME_FIELD: self._calc_avg_waiting_time(
entries
),
self.EsKeys.QUEUE_LENGTH_FIELD: len(entries),
},
)
return {
EsKeys.TIMESTAMP_FIELD: timestamp,
EsKeys.QUEUE_FIELD: queue.id,
EsKeys.WAITING_TIME_FIELD: self._calc_avg_waiting_time(entries),
EsKeys.QUEUE_LENGTH_FIELD: len(entries),
}
actions = list(map(make_doc, queues))
logged = 0
for q in queues:
queue_doc = make_doc(q)
self.es.index(index=es_index, document=queue_doc)
redis_key = _queue_metrics_key_pattern.format(queue=q.id)
redis.set(redis_key, json.dumps(queue_doc))
logged += 1
es_res = elasticsearch.helpers.bulk(self.es, actions)
added, errors = es_res[:2]
return (added == len(actions)) and not errors
return logged
def _log_current_metrics(self, company_id: str, queue_ids=Sequence[str]):
query = dict(company=company_id)
@@ -90,8 +96,7 @@ class QueueMetrics:
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
return self.es.search(
index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
body=es_req,
index=f"{self._queue_metrics_prefix_for_company(company_id)}*", body=es_req,
)
@classmethod
@@ -105,13 +110,13 @@ class QueueMetrics:
return {
"dates": {
"date_histogram": {
"field": cls.EsKeys.TIMESTAMP_FIELD,
"field": EsKeys.TIMESTAMP_FIELD,
"fixed_interval": f"{interval}s",
"min_doc_count": 1,
},
"aggs": {
"queues": {
"terms": {"field": cls.EsKeys.QUEUE_FIELD},
"terms": {"field": EsKeys.QUEUE_FIELD},
"aggs": cls._get_top_waiting_agg(),
}
},
@@ -128,13 +133,13 @@ class QueueMetrics:
"top_avg_waiting": {
"top_hits": {
"sort": [
{cls.EsKeys.WAITING_TIME_FIELD: {"order": "desc"}},
{cls.EsKeys.QUEUE_LENGTH_FIELD: {"order": "desc"}},
{EsKeys.WAITING_TIME_FIELD: {"order": "desc"}},
{EsKeys.QUEUE_LENGTH_FIELD: {"order": "desc"}},
],
"_source": {
"includes": [
cls.EsKeys.WAITING_TIME_FIELD,
cls.EsKeys.QUEUE_LENGTH_FIELD,
EsKeys.WAITING_TIME_FIELD,
EsKeys.QUEUE_LENGTH_FIELD,
]
},
"size": 1,
@@ -149,6 +154,7 @@ class QueueMetrics:
to_date: float,
interval: int,
queue_ids: Sequence[str],
refresh: bool = False,
) -> dict:
"""
Get the company queue metrics in the specified time range.
@@ -158,7 +164,8 @@ class QueueMetrics:
In case no queue ids are specified the avg across all the
company queues is calculated for each metric
"""
# self._log_current_metrics(company, queue_ids=queue_ids)
if refresh:
self._log_current_metrics(company_id, queue_ids=queue_ids)
if from_date >= to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
@@ -174,7 +181,7 @@ class QueueMetrics:
"aggs": self._get_dates_agg(interval),
}
with translate_errors_context(), TimingContext("es", "get_queue_metrics"):
with translate_errors_context():
res = self._search_company_metrics(company_id, es_req)
if "aggregations" not in res:
@@ -256,7 +263,52 @@ class QueueMetrics:
continue
res = queue_data["top_avg_waiting"]["hits"]["hits"][0]["_source"]
queue_metrics[queue_data["key"]] = {
"queue_length": res[cls.EsKeys.QUEUE_LENGTH_FIELD],
"avg_waiting_time": res[cls.EsKeys.WAITING_TIME_FIELD],
"queue_length": res[EsKeys.QUEUE_LENGTH_FIELD],
"avg_waiting_time": res[EsKeys.WAITING_TIME_FIELD],
}
return queue_metrics
class MetricsRefresher:
threads = ThreadsManager()
@classproperty
def watch_interval_sec(self):
return _conf.get("metrics_refresh_interval_sec", 300)
@classmethod
@threads.register("queue_metrics_refresh_watchdog", daemon=True)
def start(cls, queue_metrics: QueueMetrics = None):
if not cls.watch_interval_sec:
return
if not queue_metrics:
from .queue_bll import QueueBLL
queue_metrics = QueueBLL().metrics
sleep(10)
while True:
try:
for queue in Queue.objects():
timestamp = es_factory.get_timestamp_millis()
doc_time = 0
try:
redis_key = _queue_metrics_key_pattern.format(queue=queue.id)
data = redis.get(redis_key)
if data:
queue_doc = json.loads(data)
doc_time = int(queue_doc.get(EsKeys.TIMESTAMP_FIELD))
except Exception as ex:
log.exception(
f"Error reading queue metrics data for queue {queue.id}: {str(ex)}"
)
if (
not doc_time
or (timestamp - doc_time) > cls.watch_interval_sec * 1000
):
queue_metrics.log_queue_metrics_to_es(queue.company, [queue])
except Exception as ex:
log.exception(f"Failed collecting queue metrics: {str(ex)}")
sleep(60)

View File

@@ -4,7 +4,6 @@ from typing import Optional, TypeVar, Generic, Type, Callable
from redis import StrictRedis
from apiserver import database
from apiserver.timing_context import TimingContext
T = TypeVar("T")
@@ -31,24 +30,36 @@ class RedisCacheManager(Generic[T]):
def set_state(self, state: T) -> None:
redis_key = self._get_redis_key(state.id)
with TimingContext("redis", "cache_set_state"):
self.redis.set(redis_key, state.to_json())
self.redis.expire(redis_key, self.expiration_interval)
self.redis.set(redis_key, state.to_json())
self.redis.expire(redis_key, self.expiration_interval)
def get_state(self, state_id) -> Optional[T]:
redis_key = self._get_redis_key(state_id)
with TimingContext("redis", "cache_get_state"):
response = self.redis.get(redis_key)
response = self.redis.get(redis_key)
if response:
return self.state_class.from_json(response)
def delete_state(self, state_id) -> None:
with TimingContext("redis", "cache_delete_state"):
self.redis.delete(self._get_redis_key(state_id))
self.redis.delete(self._get_redis_key(state_id))
def _get_redis_key(self, state_id):
return f"{self.state_class}/{state_id}"
def get_or_create_state_core(
self,
state_id=None,
init_state: Callable[[T], None] = _do_nothing,
validate_state: Callable[[T], None] = _do_nothing,
) -> T:
state = self.get_state(state_id) if state_id else None
if state:
validate_state(state)
else:
state = self.state_class(id=database.utils.id())
init_state(state)
return state
@contextmanager
def get_or_create_state(
self,
@@ -66,12 +77,9 @@ class RedisCacheManager(Generic[T]):
:param validate_state: user callback to validate the state if retrieved from cache
Should throw an exception if the state is not valid. If not passed then no validation is done
"""
state = self.get_state(state_id) if state_id else None
if state:
validate_state(state)
else:
state = self.state_class(id=database.utils.id())
init_state(state)
state = self.get_or_create_state_core(
state_id=state_id, init_state=init_state, validate_state=validate_state
)
try:
yield state

View File

@@ -0,0 +1,376 @@
from datetime import datetime, timedelta, timezone
from enum import Enum, auto
from operator import attrgetter
from time import time
from typing import Optional, Sequence, Union
import attr
from boltons.iterutils import chunked_iter, bucketize
from pyhocon import ConfigTree
from apiserver.apimodels.serving import (
ServingContainerEntry,
RegisterRequest,
StatusReportRequest,
)
from apiserver.apimodels.workers import MachineStats
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.redis_manager import redman
from .stats import ServingStats
log = config.logger(__file__)
class ServingBLL:
def __init__(self, redis=None):
self.conf = config.get("services.serving", ConfigTree())
self.redis = redis or redman.connection("workers")
@staticmethod
def _get_url_key(company: str, url: str):
return f"serving_url_{company}_{url}"
@staticmethod
def _get_container_key(company: str, container_id: str) -> str:
"""Build redis key from company and container_id"""
return f"serving_container_{company}_{container_id}"
def _save_serving_container_entry(self, entry: ServingContainerEntry):
self.redis.setex(
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
)
url_key = self._get_url_key(entry.company_id, entry.endpoint_url)
expiration = int(time()) + entry.register_timeout
container_item = {entry.key: expiration}
self.redis.zadd(url_key, container_item)
# make sure that url set will not get stuck in redis
# indefinitely in case no more containers report to it
self.redis.expire(url_key, max(3600, entry.register_timeout))
def _get_serving_container_entry(
self, company_id: str, container_id: str
) -> Optional[ServingContainerEntry]:
"""
Get a container entry for the provided container ID.
"""
key = self._get_container_key(company_id, container_id)
data = self.redis.get(key)
if not data:
return
try:
entry = ServingContainerEntry.from_json(data)
return entry
except Exception as e:
msg = "Failed parsing container entry"
log.exception(f"{msg}: {str(e)}")
def register_serving_container(
self,
company_id: str,
request: RegisterRequest,
ip: str = "",
) -> ServingContainerEntry:
"""
Register a serving container
"""
now = datetime.now(timezone.utc)
key = self._get_container_key(company_id, request.container_id)
entry = ServingContainerEntry(
**request.to_struct(),
key=key,
company_id=company_id,
ip=ip,
register_time=now,
register_timeout=request.timeout,
last_activity_time=now,
)
self._save_serving_container_entry(entry)
return entry
def unregister_serving_container(
self,
company_id: str,
container_id: str,
) -> None:
"""
Unregister a serving container
"""
entry = self._get_serving_container_entry(company_id, container_id)
if entry:
url_key = self._get_url_key(entry.company_id, entry.endpoint_url)
self.redis.zrem(url_key, entry.key)
key = self._get_container_key(company_id, container_id)
res = self.redis.delete(key)
if res:
return
if not self.conf.get("container_auto_unregister", True):
raise errors.bad_request.ContainerNotRegistered(container=container_id)
def container_status_report(
self,
company_id: str,
report: StatusReportRequest,
ip: str = "",
) -> None:
"""
Serving container status report
"""
container_id = report.container_id
now = datetime.now(timezone.utc)
entry = self._get_serving_container_entry(company_id, container_id)
if entry:
ip = ip or entry.ip
register_time = entry.register_time
register_timeout = entry.register_timeout
else:
if not self.conf.get("container_auto_register", True):
raise errors.bad_request.ContainerNotRegistered(container=container_id)
ip = ip
register_time = now
register_timeout = int(
self.conf.get("default_container_timeout_sec", 10 * 60)
)
key = self._get_container_key(company_id, container_id)
entry = ServingContainerEntry(
**report.to_struct(),
key=key,
company_id=company_id,
ip=ip,
register_time=register_time,
register_timeout=register_timeout,
last_activity_time=now,
)
self._save_serving_container_entry(entry)
ServingStats.log_stats_to_es(entry)
def _get_all(
self,
company_id: str,
) -> Sequence[ServingContainerEntry]:
keys = list(self.redis.scan_iter(self._get_container_key(company_id, "*")))
entries = []
for keys in chunked_iter(keys, 1000):
data = self.redis.mget(keys)
if not data:
continue
for d in data:
try:
entries.append(ServingContainerEntry.from_json(d))
except Exception as ex:
log.error(f"Failed parsing container entry {str(ex)}")
return entries
@attr.s(auto_attribs=True)
class Counter:
class AggType(Enum):
avg = auto()
max = auto()
total = auto()
count = auto()
name: str
field: str
agg_type: AggType
float_precision: int = None
_max: Union[int, float, datetime] = attr.field(init=False, default=None)
_total: Union[int, float] = attr.field(init=False, default=0)
_count: int = attr.field(init=False, default=0)
def add(self, entry: ServingContainerEntry):
value = getattr(entry, self.field, None)
if value is None:
return
self._count += 1
if self.agg_type == self.AggType.max:
self._max = value if self._max is None else max(self._max, value)
else:
self._total += value
def __call__(self):
if self.agg_type == self.AggType.count:
return self._count
if self.agg_type == self.AggType.max:
return self._max
if self.agg_type == self.AggType.total:
return self._total
if not self._count:
return None
avg = self._total / self._count
return (
round(avg, self.float_precision) if self.float_precision else round(avg)
)
def _get_summary(self, entries: Sequence[ServingContainerEntry]) -> dict:
counters = [
self.Counter(
name="uptime_sec",
field="uptime_sec",
agg_type=self.Counter.AggType.max,
),
self.Counter(
name="requests",
field="requests_num",
agg_type=self.Counter.AggType.total,
),
self.Counter(
name="requests_min",
field="requests_min",
agg_type=self.Counter.AggType.avg,
float_precision=2,
),
self.Counter(
name="latency_ms",
field="latency_ms",
agg_type=self.Counter.AggType.avg,
),
self.Counter(
name="last_update",
field="last_activity_time",
agg_type=self.Counter.AggType.max,
),
]
for entry in entries:
for counter in counters:
counter.add(entry)
first_entry = entries[0]
ret = {
"endpoint": first_entry.endpoint_name,
"model": first_entry.model_name,
"url": first_entry.endpoint_url,
"instances": len(entries),
**{counter.name: counter() for counter in counters},
}
ret["last_update"] = ret.get("last_update")
return ret
def get_endpoints(self, company_id: str):
"""
Group instances by urls and return a summary for each url
Do not return data for "loading" instances that have no url
"""
entries = self._get_all(company_id)
by_url = bucketize(entries, key=attrgetter("endpoint_url"))
by_url.pop(None, None)
return [self._get_summary(url_entries) for url_entries in by_url.values()]
def _get_endpoint_entries(
self, company_id, endpoint_url: Union[str, None]
) -> Sequence[ServingContainerEntry]:
url_key = self._get_url_key(company_id, endpoint_url)
timestamp = int(time())
self.redis.zremrangebyscore(url_key, min=0, max=timestamp)
container_keys = {key.decode() for key in self.redis.zrange(url_key, 0, -1)}
if not container_keys:
return []
entries = []
found_keys = set()
data = self.redis.mget(container_keys) or []
for d in data:
try:
entry = ServingContainerEntry.from_json(d)
if entry.endpoint_url == endpoint_url:
entries.append(entry)
found_keys.add(entry.key)
except Exception as ex:
log.error(f"Failed parsing container entry {str(ex)}")
missing_keys = container_keys - found_keys
if missing_keys:
self.redis.zrem(url_key, *missing_keys)
return entries
def get_loading_instances(self, company_id: str):
entries = self._get_endpoint_entries(company_id, None)
return [
{
"id": entry.container_id,
"endpoint": entry.endpoint_name,
"url": entry.endpoint_url,
"model": entry.model_name,
"model_source": entry.model_source,
"model_version": entry.model_version,
"preprocess_artifact": entry.preprocess_artifact,
"input_type": entry.input_type,
"input_size": entry.input_size,
"uptime_sec": entry.uptime_sec,
"age_sec": int((datetime.now(timezone.utc) - entry.register_time).total_seconds()),
"last_update": entry.last_activity_time,
}
for entry in entries
]
def get_endpoint_details(self, company_id, endpoint_url: str) -> dict:
entries = self._get_endpoint_entries(company_id, endpoint_url)
if not entries:
raise errors.bad_request.NoContainersForUrl(url=endpoint_url)
instances = []
entry: ServingContainerEntry
for entry in entries:
instances.append(
{
"endpoint": entry.endpoint_name,
"model": entry.model_name,
"url": entry.endpoint_url,
}
)
def get_machine_stats_data(machine_stats: MachineStats) -> dict:
ret = {"cpu_count": 0, "gpu_count": 0}
if not machine_stats:
return ret
for value, field in (
(machine_stats.cpu_usage, "cpu_count"),
(machine_stats.gpu_usage, "gpu_count"),
):
if value is None:
continue
ret[field] = len(value) if isinstance(value, (list, tuple)) else 1
return ret
first_entry = entries[0]
return {
"endpoint": first_entry.endpoint_name,
"model": first_entry.model_name,
"url": first_entry.endpoint_url,
"preprocess_artifact": first_entry.preprocess_artifact,
"input_type": first_entry.input_type,
"input_size": first_entry.input_size,
"model_source": first_entry.model_source,
"model_version": first_entry.model_version,
"uptime_sec": max(e.uptime_sec for e in entries),
"last_update": max(e.last_activity_time for e in entries),
"instances": [
{
"id": entry.container_id,
"uptime_sec": entry.uptime_sec,
"requests": entry.requests_num,
"requests_min": entry.requests_min,
"latency_ms": entry.latency_ms,
"last_update": entry.last_activity_time,
"reference": [ref.to_struct() for ref in entry.reference]
if isinstance(entry.reference, list)
else entry.reference,
**get_machine_stats_data(entry.machine_stats),
}
for entry in entries
],
}

View File

@@ -0,0 +1,335 @@
from collections import defaultdict
from datetime import datetime, timezone
from enum import Enum
from typing import Tuple, Optional, Sequence
from elasticsearch import Elasticsearch
from apiserver.apimodels.serving import (
ServingContainerEntry,
GetEndpointMetricsHistoryRequest,
MetricType,
)
from apiserver.apierrors import errors
from apiserver.utilities.dicts import nested_get
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.es_factory import es_factory
class _AggregationType(Enum):
avg = "avg"
sum = "sum"
class ServingStats:
min_chart_interval = config.get("services.serving.min_chart_interval_sec", 40)
es: Elasticsearch = es_factory.connect("workers")
@classmethod
def _serving_stats_prefix(cls, company_id: str) -> str:
"""Returns the es index prefix for the company"""
return f"serving_stats_{company_id.lower()}_"
@staticmethod
def _get_es_index_suffix():
"""Get the index name suffix for storing current month data"""
return datetime.now(timezone.utc).strftime("%Y-%m")
@staticmethod
def _get_average_value(value) -> Tuple[Optional[float], Optional[int]]:
if value is None:
return None, None
if isinstance(value, (list, tuple)):
count = len(value)
if not count:
return None, None
return sum(value) / count, count
return value, 1
@classmethod
def log_stats_to_es(
cls,
entry: ServingContainerEntry,
) -> int:
"""
Actually writing the worker statistics to Elastic
:return: The amount of logged documents
"""
company_id = entry.company_id
es_index = (
f"{cls._serving_stats_prefix(company_id)}" f"{cls._get_es_index_suffix()}"
)
entry_data = entry.to_struct()
doc = {
"timestamp": es_factory.get_timestamp_millis(),
**{
field: entry_data.get(field)
for field in (
"container_id",
"company_id",
"endpoint_url",
"requests_num",
"requests_min",
"uptime_sec",
"latency_ms",
)
},
}
stats = entry_data.get("machine_stats")
if stats:
for category in ("cpu", "gpu"):
usage, num = cls._get_average_value(stats.get(f"{category}_usage"))
doc.update({f"{category}_usage": usage, f"{category}_num": num})
for category in ("memory", "gpu_memory"):
free, _ = cls._get_average_value(stats.get(f"{category}_free"))
used, _ = cls._get_average_value(stats.get(f"{category}_used"))
doc.update(
{
f"{category}_free": free,
f"{category}_used": used,
f"{category}_total": round((free or 0) + (used or 0), 3),
}
)
doc.update(
{
field: stats.get(field)
for field in ("disk_free_home", "network_rx", "network_tx")
}
)
cls.es.index(index=es_index, document=doc)
return 1
@staticmethod
def round_series(values: Sequence, koeff) -> list:
return [round(v * koeff, 2) if v else 0 for v in values]
_mb_to_gb = 1 / 1024
agg_fields = {
MetricType.requests: (
"requests_num",
"Number of Requests",
_AggregationType.sum,
None,
),
MetricType.requests_min: (
"requests_min",
"Requests per Minute",
_AggregationType.sum,
None,
),
MetricType.latency_ms: (
"latency_ms",
"Average Latency (ms)",
_AggregationType.avg,
None,
),
MetricType.cpu_count: ("cpu_num", "CPU Count", _AggregationType.sum, None),
MetricType.gpu_count: ("gpu_num", "GPU Count", _AggregationType.sum, None),
MetricType.cpu_util: (
"cpu_usage",
"Average CPU Load (%)",
_AggregationType.avg,
None,
),
MetricType.gpu_util: (
"gpu_usage",
"Average GPU Utilization (%)",
_AggregationType.avg,
None,
),
MetricType.ram_total: (
"memory_total",
"RAM Total (GB)",
_AggregationType.sum,
_mb_to_gb,
),
MetricType.ram_used: (
"memory_used",
"RAM Used (GB)",
_AggregationType.sum,
_mb_to_gb,
),
MetricType.ram_free: (
"memory_free",
"RAM Free (GB)",
_AggregationType.sum,
_mb_to_gb,
),
MetricType.gpu_ram_total: (
"gpu_memory_total",
"GPU RAM Total (GB)",
_AggregationType.sum,
_mb_to_gb,
),
MetricType.gpu_ram_used: (
"gpu_memory_used",
"GPU RAM Used (GB)",
_AggregationType.sum,
_mb_to_gb,
),
MetricType.gpu_ram_free: (
"gpu_memory_free",
"GPU RAM Free (GB)",
_AggregationType.sum,
_mb_to_gb,
),
MetricType.network_rx: (
"network_rx",
"Network Throughput RX (MBps)",
_AggregationType.sum,
None,
),
MetricType.network_tx: (
"network_tx",
"Network Throughput TX (MBps)",
_AggregationType.sum,
None,
),
}
@classmethod
def get_endpoint_metrics(
cls,
company_id: str,
metrics_request: GetEndpointMetricsHistoryRequest,
) -> dict:
from_date = metrics_request.from_date
to_date = metrics_request.to_date
if from_date >= to_date:
raise errors.bad_request.FieldsValueError(
"from_date must be less than to_date"
)
metric_type = metrics_request.metric_type
agg_data = cls.agg_fields.get(metric_type)
if not agg_data:
raise NotImplemented(f"Charts for {metric_type} not implemented")
agg_field, title, agg_type, multiplier = agg_data
if agg_type == _AggregationType.sum:
instance_sum_type = "sum_bucket"
else:
instance_sum_type = "avg_bucket"
interval = max(metrics_request.interval, cls.min_chart_interval)
endpoint_url = metrics_request.endpoint_url
hist_ret = {
"computed_interval": interval,
"total": {
"title": title,
"dates": [],
"values": [],
},
"instances": {},
}
must_conditions = [
QueryBuilder.term("company_id", company_id),
QueryBuilder.term("endpoint_url", endpoint_url),
QueryBuilder.dates_range(from_date, to_date),
]
query = {"bool": {"must": must_conditions}}
es_index = f"{cls._serving_stats_prefix(company_id)}*"
res = cls.es.search(
index=es_index,
size=0,
query=query,
aggs={"instances": {"terms": {"field": "container_id"}}},
)
instance_buckets = nested_get(res, ("aggregations", "instances", "buckets"))
if not instance_buckets:
return hist_ret
instance_keys = {ib["key"] for ib in instance_buckets}
must_conditions.append(QueryBuilder.terms("container_id", instance_keys))
query = {"bool": {"must": must_conditions}}
sample_func = "avg" if metric_type != MetricType.requests else "max"
aggs = {
"instances": {
"terms": {
"field": "container_id",
"size": max(len(instance_keys), 10),
},
"aggs": {
"sample": {sample_func: {"field": agg_field}},
},
},
"total_instances": {
instance_sum_type: {
"gap_policy": "insert_zeros",
"buckets_path": "instances>sample",
}
},
}
aggs = {
"dates": {
"date_histogram": {
"field": "timestamp",
"fixed_interval": f"{interval}s",
"extended_bounds": {
"min": int(from_date) * 1000,
"max": int(to_date) * 1000,
},
},
"aggs": aggs,
}
}
filter_path = None
if not metrics_request.instance_charts:
filter_path = "aggregations.dates.buckets.total_instances"
data = cls.es.search(
index=es_index,
size=0,
query=query,
aggs=aggs,
filter_path=filter_path,
)
agg_res = data.get("aggregations")
if not agg_res:
return hist_ret
dates_ = []
total = []
instances = defaultdict(list)
# remove last interval if it's incomplete. Allow 10% tolerance
last_valid_timestamp = (to_date - 0.9 * interval) * 1000
for point in agg_res["dates"]["buckets"]:
date_ = point["key"]
if date_ > last_valid_timestamp:
break
dates_.append(date_)
total.append(nested_get(point, ("total_instances", "value"), 0))
if metrics_request.instance_charts:
found_keys = set()
for instance in nested_get(point, ("instances", "buckets"), []):
instances[instance["key"]].append(
nested_get(instance, ("sample", "value"), 0)
)
found_keys.add(instance["key"])
for missing_key in instance_keys - found_keys:
instances[missing_key].append(0)
koeff = multiplier if multiplier else 1.0
hist_ret["total"]["dates"] = dates_
hist_ret["total"]["values"] = cls.round_series(total, koeff)
hist_ret["instances"] = {
key: {
"title": key,
"dates": dates_,
"values": cls.round_series(values, koeff),
}
for key, values in sorted(instances.items(), key=lambda p: p[0])
}
return hist_ret

View File

@@ -1,6 +1,6 @@
from datetime import datetime
import operator
from threading import Thread, Lock
from threading import Lock
from time import sleep
import attr
@@ -9,76 +9,83 @@ import psutil
from apiserver.utilities.threads_manager import ThreadsManager
class ResourceMonitor(Thread):
@attr.s(auto_attribs=True)
class Sample:
cpu_usage: float = 0.0
mem_used_gb: float = 0
mem_free_gb: float = 0
stat_threads = ThreadsManager("Statistics")
@classmethod
def _apply(cls, op, *samples):
return cls(
**{
field: op(*(getattr(sample, field) for sample in samples))
for field in attr.fields_dict(cls)
}
)
def min(self, sample):
return self._apply(min, self, sample)
def max(self, sample):
return self._apply(max, self, sample)
def avg(self, sample, count):
res = self._apply(lambda x: x * count, self)
res = self._apply(operator.add, res, sample)
res = self._apply(lambda x: x / (count + 1), res)
return res
def __init__(self, sample_interval_sec=5):
super(ResourceMonitor, self).__init__(daemon=True)
self.sample_interval_sec = sample_interval_sec
self._lock = Lock()
self._clear()
def _clear(self):
sample = self._get_sample()
self._avg = sample
self._min = sample
self._max = sample
self._clear_time = datetime.utcnow()
self._count = 1
@attr.s(auto_attribs=True)
class Sample:
cpu_usage: float = 0.0
mem_used_gb: float = 0
mem_free_gb: float = 0
@classmethod
def _get_sample(cls) -> Sample:
return cls.Sample(
def _apply(cls, op, *samples):
return cls(
**{
field: op(*(getattr(sample, field) for sample in samples))
for field in attr.fields_dict(cls)
}
)
def min(self, sample):
return self._apply(min, self, sample)
def max(self, sample):
return self._apply(max, self, sample)
def avg(self, sample, count):
res = self._apply(lambda x: x * count, self)
res = self._apply(operator.add, res, sample)
res = self._apply(lambda x: x / (count + 1), res)
return res
@classmethod
def get_current_sample(cls) -> "Sample":
return cls(
cpu_usage=psutil.cpu_percent(),
mem_used_gb=psutil.virtual_memory().used / (1024 ** 3),
mem_free_gb=psutil.virtual_memory().free / (1024 ** 3),
)
def run(self):
while not ThreadsManager.terminating:
sleep(self.sample_interval_sec)
sample = self._get_sample()
class ResourceMonitor:
class Accumulator:
def __init__(self):
sample = Sample.get_current_sample()
self.avg = sample
self.min = sample
self.max = sample
self.time = datetime.utcnow()
self.count = 1
with self._lock:
self._min = self._min.min(sample)
self._max = self._max.max(sample)
self._avg = self._avg.avg(sample, self._count)
self._count += 1
def add_sample(self, sample: Sample):
self.min = self.min.min(sample)
self.max = self.max.max(sample)
self.avg = self.avg.avg(sample, self.count)
self.count += 1
def get_stats(self) -> dict:
sample_interval_sec = 5
_lock = Lock()
accumulator = Accumulator()
@classmethod
@stat_threads.register("resource_monitor", daemon=True)
def start(cls):
while True:
sleep(cls.sample_interval_sec)
sample = Sample.get_current_sample()
with cls._lock:
cls.accumulator.add_sample(sample)
@classmethod
def get_stats(cls) -> dict:
""" Returns current resource statistics and clears internal resource statistics """
with self._lock:
min_ = attr.asdict(self._min)
max_ = attr.asdict(self._max)
avg = attr.asdict(self._avg)
interval = datetime.utcnow() - self._clear_time
self._clear()
with cls._lock:
min_ = attr.asdict(cls.accumulator.min)
max_ = attr.asdict(cls.accumulator.max)
avg = attr.asdict(cls.accumulator.avg)
interval = datetime.utcnow() - cls.accumulator.time
cls.accumulator = cls.Accumulator()
return {
"interval_sec": interval.total_seconds(),

View File

@@ -8,8 +8,7 @@ from typing import Sequence, Optional
import dpath
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter, Retry
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.bll.util import get_server_uuid
@@ -19,11 +18,10 @@ from apiserver.config.info import get_deployment_type
from apiserver.database.model import Company, User
from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.json import dumps
from apiserver.utilities.threads_manager import ThreadsManager
from apiserver.version import __version__ as current_version
from .resource_monitor import ResourceMonitor
from .resource_monitor import ResourceMonitor, stat_threads
log = config.logger(__file__)
@@ -31,21 +29,23 @@ worker_bll = WorkerBLL()
class StatisticsReporter:
threads = ThreadsManager("Statistics", resource_monitor=ResourceMonitor)
send_queue = queue.Queue()
supported = config.get("apiserver.statistics.supported", True)
@classmethod
def start(cls):
if not cls.supported:
return
ResourceMonitor.start()
cls.start_sender()
cls.start_reporter()
@classmethod
@threads.register("reporter", daemon=True)
@stat_threads.register("reporter", daemon=True)
def start_reporter(cls):
"""
Periodically send statistics reports for companies who have opted in.
Note: in trains we usually have only a single company
Note: in clearml we usually have only a single company
"""
if not cls.supported:
return
@@ -54,7 +54,7 @@ class StatisticsReporter:
hours=config.get("apiserver.statistics.report_interval_hours", 24)
)
sleep(report_interval.total_seconds())
while not ThreadsManager.terminating:
while True:
try:
for company in Company.objects(
defaults__stats_option__enabled=True
@@ -68,7 +68,7 @@ class StatisticsReporter:
sleep(report_interval.total_seconds())
@classmethod
@threads.register("sender", daemon=True)
@stat_threads.register("sender", daemon=True)
def start_sender(cls):
if not cls.supported:
return
@@ -85,7 +85,7 @@ class StatisticsReporter:
WarningFilter.attach()
while not ThreadsManager.terminating:
while True:
try:
report = cls.send_queue.get()
@@ -111,7 +111,7 @@ class StatisticsReporter:
"uuid": get_server_uuid(),
"queues": {"count": Queue.objects(company=company_id).count()},
"users": {"count": User.objects(company=company_id).count()},
"resources": cls.threads.resource_monitor.get_stats(),
"resources": ResourceMonitor.get_stats(),
"experiments": next(
iter(cls._get_experiments_stats(company_id).values()), {}
),
@@ -162,7 +162,7 @@ class StatisticsReporter:
def _get_cardinality_fields(categories: Sequence[dict]) -> dict:
names = {"cpu": "num_cores"}
return {
names[c["key"]]: safe_get(c, "count/value")
names[c["key"]]: nested_get(c, ("count", "value"))
for c in categories
if c["key"] in names
}
@@ -175,21 +175,21 @@ class StatisticsReporter:
}
return {
names[m["key"]]: {
"min": safe_get(m, "min/value"),
"max": safe_get(m, "max/value"),
"avg": safe_get(m, "avg/value"),
"min": nested_get(m, ("min", "value")),
"max": nested_get(m, ("max", "value")),
"avg": nested_get(m, ("avg", "value")),
}
for m in metrics
if m["key"] in names
}
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
buckets = nested_get(res, ("aggregations", "workers", "buckets"), default=[])
return {
b["key"]: {
key: {
"interval_sec": agent_resource_threshold_sec,
**_get_cardinality_fields(safe_get(b, "categories/buckets", [])),
**_get_metric_fields(safe_get(b, "metrics/buckets", [])),
**_get_cardinality_fields(nested_get(b, ("categories", "buckets"), [])),
**_get_metric_fields(nested_get(b, ("metrics", "buckets"), [])),
}
}
for b in buckets
@@ -227,7 +227,7 @@ class StatisticsReporter:
},
}
res = cls._run_worker_stats_query(company_id, es_req)
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
buckets = nested_get(res, ("aggregations", "workers", "buckets"), default=[])
return {
b["key"]: {"last_activity_time": b["last_activity_time"]["value"]}
for b in buckets
@@ -254,6 +254,14 @@ class StatisticsReporter:
**({"last_worker": {"$in": workers}} if workers else {}),
}
},
{
"$project": {
"last_worker": 1,
"last_update": 1,
"started": 1,
"last_iteration": 1,
}
},
{
"$group": {
"_id": "$last_worker" if workers else None,

View File

@@ -0,0 +1,273 @@
import json
import os
import tempfile
from copy import copy
from datetime import datetime
from typing import Optional, Sequence
import attr
from boltons.cacheutils import cachedproperty
from clearml.backend_config.bucket_config import (
S3BucketConfigurations,
AzureContainerConfigurations,
GSBucketConfigurations,
AzureContainerConfig,
GSBucketConfig,
S3BucketConfig,
)
from apiserver.apierrors import errors
from apiserver.apimodels.storage import SetSettingsRequest
from apiserver.config_repo import config
from apiserver.database.model.storage_settings import (
StorageSettings,
GoogleBucketSettings,
AWSSettings,
AzureStorageSettings,
GoogleStorageSettings,
)
from apiserver.database.utils import id as db_id
log = config.logger(__file__)
class StorageBLL:
default_aws_configs: S3BucketConfigurations = None
conf = config.get("services.storage_credentials")
@cachedproperty
def _default_aws_configs(self) -> S3BucketConfigurations:
return S3BucketConfigurations.from_config(self.conf.get("aws.s3"))
@cachedproperty
def _default_azure_configs(self) -> AzureContainerConfigurations:
return AzureContainerConfigurations.from_config(self.conf.get("azure.storage"))
@cachedproperty
def _default_gs_configs(self) -> GSBucketConfigurations:
return GSBucketConfigurations.from_config(self.conf.get("google.storage"))
def get_azure_settings_for_company(
self,
company_id: str,
db_settings: StorageSettings = None,
query_db: bool = True,
) -> AzureContainerConfigurations:
if not db_settings and query_db:
db_settings = (
StorageSettings.objects(company=company_id).only("azure").first()
)
if not db_settings or not db_settings.azure:
return copy(self._default_azure_configs)
azure = db_settings.azure
return AzureContainerConfigurations(
container_configs=[
AzureContainerConfig(**entry.to_proper_dict())
for entry in (azure.containers or [])
]
)
def get_gs_settings_for_company(
self,
company_id: str,
db_settings: StorageSettings = None,
query_db: bool = True,
json_string: bool = False,
) -> GSBucketConfigurations:
if not db_settings and query_db:
db_settings = (
StorageSettings.objects(company=company_id).only("google").first()
)
if not db_settings or not db_settings.google:
if not json_string:
return copy(self._default_gs_configs)
if self._default_gs_configs._buckets:
buckets = [
attr.evolve(
b,
credentials_json=self._assure_json_string(b.credentials_json),
)
for b in self._default_gs_configs._buckets
]
else:
buckets = self._default_gs_configs._buckets
return GSBucketConfigurations(
buckets=buckets,
default_project=self._default_gs_configs._default_project,
default_credentials=self._assure_json_string(
self._default_gs_configs._default_credentials
),
)
def get_bucket_config(bc: GoogleBucketSettings) -> GSBucketConfig:
data = bc.to_proper_dict()
if not json_string and bc.credentials_json:
data["credentials_json"] = self._assure_json_file(bc.credentials_json)
return GSBucketConfig(**data)
google = db_settings.google
buckets_configs = [get_bucket_config(b) for b in (google.buckets or [])]
return GSBucketConfigurations(
buckets=buckets_configs,
default_project=google.project,
default_credentials=google.credentials_json
if json_string
else self._assure_json_file(google.credentials_json),
)
def get_aws_settings_for_company(
self,
company_id: str,
db_settings: StorageSettings = None,
query_db: bool = True,
) -> S3BucketConfigurations:
if not db_settings and query_db:
db_settings = (
StorageSettings.objects(company=company_id).only("aws").first()
)
if not db_settings or not db_settings.aws:
return copy(self._default_aws_configs)
aws = db_settings.aws
buckets_configs = S3BucketConfig.from_list(
[b.to_proper_dict() for b in (aws.buckets or [])]
)
return S3BucketConfigurations(
buckets=buckets_configs,
default_key=aws.key,
default_secret=aws.secret,
default_region=aws.region,
default_use_credentials_chain=aws.use_credentials_chain,
default_token=aws.token,
default_extra_args={},
)
def _assure_json_file(self, name_or_content: str) -> str:
if not name_or_content:
return name_or_content
if name_or_content.endswith(".json") or os.path.exists(name_or_content):
return name_or_content
try:
json.loads(name_or_content)
except Exception:
return name_or_content
with tempfile.NamedTemporaryFile(
mode="wt", delete=False, suffix=".json"
) as tmp:
tmp.write(name_or_content)
return tmp.name
def _assure_json_string(self, name_or_content: str) -> Optional[str]:
if not name_or_content:
return name_or_content
try:
json.loads(name_or_content)
return name_or_content
except Exception:
pass
try:
with open(name_or_content) as fp:
return fp.read()
except Exception:
return None
def get_company_settings(self, company_id: str) -> dict:
db_settings = StorageSettings.objects(company=company_id).first()
aws = self.get_aws_settings_for_company(company_id, db_settings, query_db=False)
aws_dict = {
"key": aws._default_key,
"secret": aws._default_secret,
"token": aws._default_token,
"region": aws._default_region,
"use_credentials_chain": aws._default_use_credentials_chain,
"buckets": [attr.asdict(b) for b in aws._buckets],
}
gs = self.get_gs_settings_for_company(
company_id, db_settings, query_db=False, json_string=True
)
gs_dict = {
"project": gs._default_project,
"credentials_json": gs._default_credentials,
"buckets": [attr.asdict(b) for b in gs._buckets],
}
azure = self.get_azure_settings_for_company(company_id, db_settings)
azure_dict = {
"containers": [attr.asdict(ac) for ac in azure._container_configs],
}
return {
"aws": aws_dict,
"google": gs_dict,
"azure": azure_dict,
"last_update": db_settings.last_update if db_settings else None,
}
def set_company_settings(
self, company_id: str, settings: SetSettingsRequest
) -> int:
update_dict = {}
if settings.aws:
update_dict["aws"] = {
**{
k: v
for k, v in settings.aws.to_struct().items()
if k in AWSSettings.get_fields()
}
}
if settings.azure:
update_dict["azure"] = {
**{
k: v
for k, v in settings.azure.to_struct().items()
if k in AzureStorageSettings.get_fields()
}
}
if settings.google:
update_dict["google"] = {
**{
k: v
for k, v in settings.google.to_struct().items()
if k in GoogleStorageSettings.get_fields()
}
}
cred_json = update_dict["google"].get("credentials_json")
if cred_json:
try:
json.loads(cred_json)
except Exception as ex:
raise errors.bad_request.ValidationError(
f"Invalid json credentials: {str(ex)}"
)
if not update_dict:
raise errors.bad_request.ValidationError("No settings were provided")
settings = StorageSettings.objects(company=company_id).only("id").first()
settings_id = settings.id if settings else db_id()
return StorageSettings.objects(id=settings_id).update(
upsert=True,
id=settings_id,
company=company_id,
last_update=datetime.utcnow(),
**update_dict,
)
def reset_company_settings(self, company_id: str, keys: Sequence[str]) -> int:
return StorageSettings.objects(company=company_id).update(
last_update=datetime.utcnow(), **{f"unset__{k}": 1 for k in keys}
)

View File

@@ -1,7 +1,5 @@
from .task_bll import TaskBLL
from .utils import (
ChangeStatusRequest,
update_project_time,
validate_status_change,
split_by,
)

View File

@@ -1,11 +1,11 @@
from hashlib import md5
from operator import itemgetter
from typing import Sequence
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.timing_context import TimingContext
from apiserver.database.utils import hash_field_name
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
@@ -15,7 +15,7 @@ def get_artifact_id(artifact: dict):
Calculate id from 'key' and 'mode' fields
Return hash on on the id so that it will not contain mongo illegal characters
"""
key_hash: str = md5(artifact["key"].encode()).hexdigest()
key_hash: str = hash_field_name(artifact["key"])
mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
return f"{key_hash}_{mode}"
@@ -40,7 +40,7 @@ def artifacts_unprepare_from_saved(fields):
nested_set(
fields,
artifacts_field,
value=sorted(artifacts.values(), key=itemgetter("key", "mode")),
value=sorted(artifacts.values(), key=itemgetter("key")),
)
@@ -49,49 +49,45 @@ class Artifacts:
def add_or_update_artifacts(
cls,
company_id: str,
identity: Identity,
task_id: str,
artifacts: Sequence[ApiArtifact],
force: bool,
) -> int:
with TimingContext("mongo", "update_artifacts"):
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
force=force,
)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
artifacts = {
get_artifact_id(a): Artifact(**a)
for a in (api_artifact.to_struct() for api_artifact in artifacts)
}
artifacts = {
get_artifact_id(a): Artifact(**a)
for a in (api_artifact.to_struct() for api_artifact in artifacts)
}
update_cmds = {
f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items()
}
return update_task(task, update_cmds=update_cmds)
update_cmds = {
f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items()
}
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
@classmethod
def delete_artifacts(
cls,
company_id: str,
identity: Identity,
task_id: str,
artifact_ids: Sequence[ArtifactId],
force: bool,
) -> int:
with TimingContext("mongo", "delete_artifacts"):
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
force=force,
)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
artifact_ids = [
get_artifact_id(a)
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
]
delete_cmds = {
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
}
artifact_ids = [
get_artifact_id(a)
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
]
delete_cmds = {
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
}
return update_task(task, update_cmds=delete_cmds)
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)

View File

@@ -15,7 +15,7 @@ from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.config_repo import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.timing_context import TimingContext
from apiserver.service_repo.auth import Identity
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
@@ -32,7 +32,10 @@ class HyperParams:
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
only = ("id", "hyperparams")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
company_id=company_id,
task_ids=task_ids,
only=only,
allow_public=True,
)
return {
@@ -64,78 +67,84 @@ class HyperParams:
def delete_params(
cls,
company_id: str,
identity: Identity,
task_id: str,
hyperparams: Sequence[HyperParamKey],
force: bool,
) -> int:
with TimingContext("mongo", "delete_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
)
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
identity=identity,
)
with_param, without_param = iterutils.partition(
hyperparams, key=lambda p: bool(p.name)
)
sections_to_delete = {p.section for p in without_param}
delete_cmds = {
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
for section in sections_to_delete
}
with_param, without_param = iterutils.partition(
hyperparams, key=lambda p: bool(p.name)
)
sections_to_delete = {p.section for p in without_param}
delete_cmds = {
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
for section in sections_to_delete
}
for item in with_param:
section = ParameterKeyEscaper.escape(item.section)
if item.section in sections_to_delete:
raise errors.bad_request.FieldsConflict(
"Cannot delete section field if the whole section was scheduled for deletion"
)
name = ParameterKeyEscaper.escape(item.name)
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
for item in with_param:
section = ParameterKeyEscaper.escape(item.section)
if item.section in sections_to_delete:
raise errors.bad_request.FieldsConflict(
"Cannot delete section field if the whole section was scheduled for deletion"
)
name = ParameterKeyEscaper.escape(item.name)
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
return update_task(
task, update_cmds=delete_cmds, set_last_update=not properties_only
)
return update_task(
task,
user_id=identity.user,
update_cmds=delete_cmds,
set_last_update=not properties_only,
)
@classmethod
def edit_params(
cls,
company_id: str,
identity: Identity,
task_id: str,
hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str,
force: bool,
) -> int:
with TimingContext("mongo", "edit_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
)
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
identity=identity,
)
update_cmds = dict()
hyperparams = cls._db_dicts_from_list(hyperparams)
if replace_hyperparams == ReplaceHyperparams.all:
update_cmds["set__hyperparams"] = hyperparams
elif replace_hyperparams == ReplaceHyperparams.section:
for section, value in hyperparams.items():
update_cmds = dict()
hyperparams = cls._db_dicts_from_list(hyperparams)
if replace_hyperparams == ReplaceHyperparams.all:
update_cmds["set__hyperparams"] = hyperparams
elif replace_hyperparams == ReplaceHyperparams.section:
for section, value in hyperparams.items():
update_cmds[f"set__hyperparams__{mongoengine_safe(section)}"] = value
else:
for section, section_params in hyperparams.items():
for name, value in section_params.items():
update_cmds[
f"set__hyperparams__{mongoengine_safe(section)}"
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
] = value
else:
for section, section_params in hyperparams.items():
for name, value in section_params.items():
update_cmds[
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
] = value
return update_task(
task, update_cmds=update_cmds, set_last_update=not properties_only
)
return update_task(
task,
user_id=identity.user,
update_cmds=update_cmds,
set_last_update=not properties_only,
)
@classmethod
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
@@ -160,7 +169,10 @@ class HyperParams:
else:
only.append("configuration")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
company_id=company_id,
task_ids=task_ids,
only=only,
allow_public=True,
)
return {
@@ -175,71 +187,76 @@ class HyperParams:
@classmethod
def get_configuration_names(
cls, company_id: str, task_ids: Sequence[str]
cls, company_id: str, task_ids: Sequence[str], skip_empty: bool
) -> Dict[str, list]:
with TimingContext("mongo", "get_configuration_names"):
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"_id": {"$in": task_ids},
}
},
{"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"},
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
tasks = Task.aggregate(pipeline)
return {
task["_id"]: {
"names": sorted(
ParameterKeyEscaper.unescape(name) for name in task["names"]
)
skip_empty_condition = {"$match": {"items.v.value": {"$nin": [None, ""]}}}
pipeline = [
{
"$match": {
"company": {"$in": ["", company_id]},
"_id": {"$in": task_ids},
}
for task in tasks
},
{"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"},
*([skip_empty_condition] if skip_empty else []),
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
tasks = Task.aggregate(pipeline)
return {
task["_id"]: {
"names": sorted(
ParameterKeyEscaper.unescape(name) for name in task["names"]
)
}
for task in tasks
}
@classmethod
def edit_configuration(
cls,
company_id: str,
identity: Identity,
task_id: str,
configuration: Sequence[Configuration],
replace_configuration: bool,
force: bool,
) -> int:
with TimingContext("mongo", "edit_configuration"):
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force
)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
update_cmds = dict()
configuration = {
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
for c in configuration
}
if replace_configuration:
update_cmds["set__configuration"] = configuration
else:
for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
update_cmds = dict()
configuration = {
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
for c in configuration
}
if replace_configuration:
update_cmds["set__configuration"] = configuration
else:
for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
return update_task(task, update_cmds=update_cmds)
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
@classmethod
def delete_configuration(
cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
cls,
company_id: str,
identity: Identity,
task_id: str,
configuration: Sequence[str],
force: bool,
) -> int:
with TimingContext("mongo", "delete_configuration"):
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force
)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
return update_task(task, update_cmds=delete_cmds)
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)

View File

@@ -1,7 +1,7 @@
from datetime import timedelta, datetime
from time import sleep
from apiserver.bll.task import update_project_time
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model.task.task import TaskStatus, Task
from apiserver.utilities.threads_manager import ThreadsManager
@@ -39,7 +39,7 @@ class NonResponsiveTasksWatchdog:
@threads.register("non_responsive_tasks_watchdog", daemon=True)
def start(cls):
sleep(cls.settings.watch_interval_sec)
while not ThreadsManager.terminating:
while True:
watch_interval = cls.settings.watch_interval_sec
if cls.settings.enabled:
try:
@@ -85,6 +85,7 @@ class NonResponsiveTasksWatchdog:
status_changed=now,
last_update=now,
last_change=now,
last_changed_by="__apiserver__",
)
if updated:
project_ids.add(task.project)

View File

@@ -1,11 +1,10 @@
import itertools
from typing import Sequence, Tuple
from typing import Sequence, Tuple, Optional
import dpath
from apiserver.apierrors import errors
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get, nested_delete, nested_set
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
@@ -14,7 +13,7 @@ hyperparams_legacy_type = "legacy"
tf_define_section = "TF_DEFINE"
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
def split_param_name(full_name: str, default_section: str) -> Tuple[Optional[str], str]:
"""
Return parameter section and name. The section is either TF_DEFINE or the default one
"""
@@ -62,7 +61,7 @@ def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
return removed
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[dict]:
"""
Remove the legacy params from the data dict and return the number of removed params
If the path not found then return 0
@@ -71,8 +70,10 @@ def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]
return []
if with_sections:
return itertools.chain.from_iterable(
_get_legacy_params(section_data) for section_data in data.values()
return list(
itertools.chain.from_iterable(
_get_legacy_params(section_data) for section_data in data.values()
)
)
return [
@@ -86,15 +87,15 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
Escape all the section and param names for hyper params and configuration to make it mongo sage
"""
for old_params_field, new_params_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None),
(("execution", "parameters"), "hyperparams", hyperparams_default_section),
(("execution", "model_desc"), "configuration", None),
):
legacy_params = safe_get(fields, old_params_field)
legacy_params = nested_get(fields, old_params_field)
if legacy_params is None:
continue
if (
not safe_get(fields, new_params_field)
not fields.get(new_params_field)
and previous_task
and previous_task[new_params_field]
):
@@ -117,21 +118,34 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
if section is not None:
new_param["section"] = section
dpath.new(fields, new_path, new_param)
dpath.delete(fields, old_params_field)
nested_set(fields, new_path, new_param)
nested_delete(fields, old_params_field)
for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field)
if params:
escaped_params = {
ParameterKeyEscaper.escape(key): {
ParameterKeyEscaper.escape(k): v for k, v in value.items()
}
if isinstance(value, dict)
else value
for key, value in params.items()
def ensure_non_empty(k: str, desc: str) -> str:
if not k:
raise errors.bad_request.ValidationError(
f"Empty {desc} name is not allowed"
)
return k
params = fields.get("hyperparams")
if params:
escaped_params = {
ParameterKeyEscaper.escape(ensure_non_empty(key, "section")): {
ParameterKeyEscaper.escape(ensure_non_empty(k, "parameter")): v
for k, v in value.items()
}
dpath.set(fields, param_field, escaped_params)
for key, value in params.items()
}
fields["hyperparams"] = escaped_params
params = fields.get("configuration")
if params:
escaped_params = {
ParameterKeyEscaper.escape(ensure_non_empty(key, "configuration")): value
for key, value in params.items()
}
fields["configuration"] = escaped_params
def params_unprepare_from_saved(fields, copy_to_legacy=False):
@@ -140,7 +154,7 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
"""
for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field)
params = fields.get(param_field)
if params:
unescaped_params = {
ParameterKeyEscaper.unescape(key): {
@@ -150,18 +164,18 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
else value
for key, value in params.items()
}
dpath.set(fields, param_field, unescaped_params)
fields[param_field] = unescaped_params
if copy_to_legacy:
for new_params_field, old_params_field, use_sections in (
(f"hyperparams", "execution/parameters", True),
(f"configuration", "execution/model_desc", False),
("hyperparams", ("execution", "parameters"), True),
("configuration", ("execution", "model_desc"), False),
):
legacy_params = _get_legacy_params(
safe_get(fields, new_params_field), with_sections=use_sections
fields.get(new_params_field), with_sections=use_sections
)
if legacy_params:
dpath.new(
nested_set(
fields,
old_params_field,
{_get_full_param_name(p): p["value"] for p in legacy_params},
@@ -174,7 +188,7 @@ def _process_path(path: str):
Need to unescape and apply a full mongo escaping
"""
parts = path.split(".")
if len(parts) < 2 or len(parts) > 3:
if len(parts) < 2 or len(parts) > 4:
raise errors.bad_request.ValidationError("invalid task field", path=path)
return ".".join(
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
@@ -184,7 +198,8 @@ def _process_path(path: str):
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
for old_prefix, new_prefix in (
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
("execution.model_desc", f"configuration"),
("execution.model_desc", "configuration"),
("execution.docker_cmd", "container"),
):
path: str
paths = [path.replace(old_prefix, new_prefix) for path in paths]

View File

@@ -1,17 +1,18 @@
from collections import OrderedDict
from datetime import datetime
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
from typing import Collection, Sequence, Tuple, Optional, Dict
import dpath
import six
from mongoengine import Q
from redis import StrictRedis
from six import string_types
import apiserver.database.utils as dbutils
from apiserver.apierrors import errors
from apiserver.apierrors import errors, APIError
from apiserver.apimodels.tasks import TaskInputModel
from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
@@ -21,21 +22,31 @@ from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
Task,
TaskStatus,
TaskStatusMessage,
TaskSystemTags,
ArtifactModes,
external_task_types,
ModelItem,
Models,
DEFAULT_ARTIFACT_MODE,
TaskModelNames,
TaskModelTypes,
)
from apiserver.database.model import EntityVisibility
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.database.model.queue import Queue
from apiserver.database.utils import (
get_company_or_none_constraint,
id as create_id,
)
from apiserver.es_factory import es_factory
from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from apiserver.redis_manager import redman
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from apiserver.utilities.dicts import nested_set
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import ChangeStatusRequest, validate_status_change, update_project_time
from .utils import (
ChangeStatusRequest,
deleted_prefix,
get_last_metric_updates,
)
log = config.logger(__file__)
org_bll = OrgBLL()
@@ -44,48 +55,17 @@ project_bll = ProjectBLL()
class TaskBLL:
def __init__(self, events_es=None):
self.events_es = (
events_es if events_es is not None else es_factory.connect("events")
)
@classmethod
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
"""
Return the list of unique task types used by company and public tasks
If project ids passed then only tasks from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
query &= Q(project__in=project_ids)
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
@staticmethod
def get_task_with_access(
task_id, company_id, only=None, allow_public=False, requires_write_access=False
) -> Task:
"""
Gets a task that has a required write access
:except errors.bad_request.InvalidTaskId: if the task is not found
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
"""
with translate_errors_context():
query = dict(id=task_id, company=company_id)
with TimingContext("mongo", "task_with_access"):
if requires_write_access:
task = Task.get_for_writing(_only=only, **query)
else:
task = Task.get(_only=only, **query, include_public=allow_public)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
return task
def __init__(self, events_es=None, redis=None):
self.events_es = events_es or es_factory.connect("events")
self.redis: StrictRedis = redis or redman.connection("apiserver")
@staticmethod
def get_by_id(
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
company_id,
task_id,
required_status=None,
only_fields=None,
allow_public=False,
):
if only_fields:
if isinstance(only_fields, string_types):
@@ -94,15 +74,14 @@ class TaskBLL:
only_fields = list(only_fields)
only_fields = only_fields + ["status"]
with TimingContext("mongo", "task_by_id_all"):
tasks = Task.get_many(
company=company_id,
query=Q(id=task_id),
allow_public=allow_public,
override_projection=only_fields,
return_dicts=False,
)
task = None if not tasks else tasks[0]
tasks = Task.get_many(
company=company_id,
query=Q(id=task_id),
allow_public=allow_public,
override_projection=only_fields,
return_dicts=False,
)
task = None if not tasks else tasks[0]
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
@@ -117,7 +96,7 @@ class TaskBLL:
company_id, task_ids, only=None, allow_public=False, return_tasks=True
) -> Optional[Sequence[Task]]:
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
with translate_errors_context(), TimingContext("mongo", "task_exists"):
with translate_errors_context():
ids = set(task_ids)
q = Task.get_many(
company=company_id,
@@ -137,33 +116,34 @@ class TaskBLL:
return list(q)
@staticmethod
def create(call: APICall, fields: dict):
identity = call.identity
def create(company: str, user: str, fields: dict):
now = datetime.utcnow()
return Task(
id=create_id(),
user=identity.user,
company=identity.company,
user=user,
company=company,
created=now,
last_update=now,
last_change=now,
last_changed_by=user,
**fields,
)
@staticmethod
def validate_execution_model(task, allow_only_public=False):
if not task.execution or not task.execution.model:
def validate_input_models(task, allow_only_public=False):
if not task.models.input:
return
company = None if allow_only_public else task.company
model_id = task.execution.model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(company)
).first()
if not model:
raise errors.bad_request.InvalidModelId(model=model_id)
model_ids = set(m.model for m in task.models.input)
models = Model.objects(
Q(id__in=model_ids) & get_company_or_none_constraint(company)
).only("id")
missing = model_ids - {m.id for m in models}
if missing:
raise errors.bad_request.InvalidModelId(models=missing)
return model
return
@classmethod
def clone_task(
@@ -179,26 +159,65 @@ class TaskBLL:
system_tags: Optional[Sequence[str]] = None,
hyperparams: Optional[dict] = None,
configuration: Optional[dict] = None,
container: Optional[dict] = None,
execution_overrides: Optional[dict] = None,
input_models: Optional[Sequence[TaskInputModel]] = None,
validate_references: bool = False,
new_project_name: str = None,
hyperparams_overrides: Optional[dict] = None,
configuration_overrides: Optional[dict] = None,
) -> Tuple[Task, dict]:
validate_tags(tags, system_tags)
params_dict = {
field: value
for field, value in (
("hyperparams", hyperparams),
("configuration", configuration),
)
if value is not None
}
task: Task = cls.get_by_id(
company_id=company_id, task_id=task_id, allow_public=True
)
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
params_dict = {}
if hyperparams:
params_dict["hyperparams"] = hyperparams
elif hyperparams_overrides:
updated_hyperparams = {
sec: {k: value for k, value in sec_data.items()}
for sec, sec_data in (task.hyperparams or {}).items()
}
for section, section_data in hyperparams_overrides.items():
for key, value in section_data.items():
nested_set(updated_hyperparams, (section, key), value)
params_dict["hyperparams"] = updated_hyperparams
if configuration:
params_dict["configuration"] = configuration
elif configuration_overrides:
updated_configuration = {
k: value for k, value in (task.configuration or {}).items()
}
for key, value in configuration_overrides.items():
updated_configuration[key] = value
params_dict["configuration"] = updated_configuration
now = datetime.utcnow()
if input_models:
input_models = [
ModelItem(model=m.model, name=m.name, updated=now) for m in input_models
]
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
if execution_overrides:
execution_model_overriden = execution_overrides.get("model") is not None
execution_model = execution_overrides.pop("model", None)
if not input_models and execution_model:
input_models = [
ModelItem(
model=execution_model,
name=TaskModelNames[TaskModelTypes.input],
updated=now,
)
]
docker_cmd = execution_overrides.pop("docker_cmd", None)
if not container and docker_cmd:
image, _, arguments = docker_cmd.partition(" ")
container = {"image": image, "arguments": arguments}
artifacts_prepare_for_save({"execution": execution_overrides})
params_dict["execution"] = {}
@@ -207,6 +226,8 @@ class TaskBLL:
if legacy_value is not None:
params_dict["execution"] = legacy_value
escape_dict_field(execution_overrides, "model_labels")
execution_dict.update(execution_overrides)
params_prepare_for_save(params_dict, previous_task=task)
@@ -216,7 +237,7 @@ class TaskBLL:
execution_dict["artifacts"] = {
k: a
for k, a in artifacts.items()
if a.get("mode") != ArtifactModes.output
if a.get("mode", DEFAULT_ARTIFACT_MODE) != ArtifactModes.output
}
execution_dict.pop("queue", None)
@@ -227,12 +248,10 @@ class TaskBLL:
project_name=new_project_name,
user=user_id,
company=company_id,
description="Auto-generated while cloning",
description="",
)
new_project_data = {"id": project, "name": new_project_name}
now = datetime.utcnow()
def clean_system_tags(input_tags: Sequence[str]) -> Sequence[str]:
if not input_tags:
return input_tags
@@ -240,54 +259,70 @@ class TaskBLL:
return [
tag
for tag in input_tags
if tag not in [TaskSystemTags.development, EntityVisibility.archived.value]
if tag
not in [TaskSystemTags.development, EntityVisibility.archived.value]
]
with TimingContext("mongo", "clone task"):
new_task = Task(
id=create_id(),
user=user_id,
company=company_id,
created=now,
last_update=now,
last_change=now,
name=name or task.name,
comment=comment or task.comment,
parent=parent or task.parent,
project=project or task.project,
tags=tags or task.tags,
system_tags=system_tags or clean_system_tags(task.system_tags),
type=task.type,
script=task.script,
output=Output(destination=task.output.destination)
if task.output
else None,
execution=execution_dict,
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)
cls.validate(
new_task,
validate_model=validate_references or execution_model_overriden,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
new_task.save()
def ensure_int_labels(execution: dict) -> dict:
if not execution:
return execution
if task.project == new_task.project:
updated_tags = tags
updated_system_tags = system_tags
else:
updated_tags = new_task.tags
updated_system_tags = new_task.system_tags
org_bll.update_tags(
company_id,
Tags.Task,
project=new_task.project,
tags=updated_tags,
system_tags=updated_system_tags,
)
update_project_time(new_task.project)
model_labels = execution.get("model_labels")
if model_labels:
execution["model_labels"] = {k: int(v) for k, v in model_labels.items()}
return execution
parent_task = (
task.parent
if task.parent and not task.parent.startswith(deleted_prefix)
else task.id
)
new_task = Task(
id=create_id(),
user=user_id,
company=company_id,
created=now,
last_update=now,
last_change=now,
last_changed_by=user_id,
name=name or task.name,
comment=comment or task.comment,
parent=parent or parent_task,
project=project or task.project,
tags=tags or task.tags,
system_tags=system_tags or clean_system_tags(task.system_tags),
type=task.type,
script=task.script,
output=Output(destination=task.output.destination) if task.output else None,
models=Models(input=input_models or task.models.input),
container=escape_dict(container) or task.container,
execution=ensure_int_labels(execution_dict),
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)
cls.validate(
new_task,
validate_models=validate_references or input_models,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
new_task.save()
if task.project == new_task.project:
updated_tags = tags
updated_system_tags = system_tags
else:
updated_tags = new_task.tags
updated_system_tags = new_task.system_tags
org_bll.update_tags(
company_id,
Tags.Task,
projects=[new_task.project],
tags=updated_tags,
system_tags=updated_system_tags,
)
update_project_time(new_task.project)
return new_task, new_project_data
@@ -295,7 +330,7 @@ class TaskBLL:
def validate(
cls,
task: Task,
validate_model=True,
validate_models=True,
validate_parent=True,
validate_project=True,
):
@@ -307,6 +342,7 @@ class TaskBLL:
if (
validate_parent
and task.parent
and not task.parent.startswith(deleted_prefix)
and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
)
@@ -318,60 +354,21 @@ class TaskBLL:
if validate_project and not project:
raise errors.bad_request.InvalidProjectId(id=task.project)
if validate_model:
cls.validate_execution_model(task)
@staticmethod
def get_unique_metric_variants(company_id, project_ids=None):
pipeline = [
{
"$match": dict(
company={"$in": [None, "", company_id]},
**({"project": {"$in": project_ids}} if project_ids else {}),
)
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
{"$unwind": "$metrics"},
{
"$project": {
"metric": "$metrics.k",
"variants": {"$objectToArray": "$metrics.v"},
}
},
{"$unwind": "$variants"},
{
"$group": {
"_id": {
"metric": "$variants.v.metric",
"variant": "$variants.v.variant",
},
"metrics": {
"$addToSet": {
"metric": "$variants.v.metric",
"metric_hash": "$metric",
"variant": "$variants.v.variant",
"variant_hash": "$variants.k",
}
},
}
},
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
with translate_errors_context():
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
if validate_models:
cls.validate_input_models(task)
@staticmethod
def set_last_update(
task_ids: Collection[str],
company_id: str,
user_id: str,
last_update: datetime,
**extra_updates,
):
tasks = Task.objects(id__in=task_ids, company=company_id).only(
"status", "started"
)
count = 0
for task in tasks:
updates = extra_updates
if task.status == TaskStatus.in_progress and task.started:
@@ -381,21 +378,24 @@ class TaskBLL:
).total_seconds(),
**extra_updates,
}
Task.objects(id=task.id, company=company_id).update(
count += Task.objects(id=task.id, company=company_id).update(
upsert=False,
last_update=last_update,
last_change=last_update,
last_changed_by=user_id,
**updates,
)
return count
@staticmethod
def update_statistics(
task_id: str,
company_id: str,
user_id: str,
last_update: datetime = None,
last_iteration: int = None,
last_iteration_max: int = None,
last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
last_scalar_events: Dict[str, Dict[str, dict]] = None,
last_events: Dict[str, Dict[str, dict]] = None,
**extra_updates,
):
@@ -420,25 +420,21 @@ class TaskBLL:
elif last_iteration_max is not None:
extra_updates.update(max__last_iteration=last_iteration_max)
if last_scalar_values is not None:
def op_path(op, *path):
return "__".join((op, "last_metrics") + path)
for path, value in last_scalar_values:
if path[-1] == "min_value":
extra_updates[op_path("min", *path[:-1], "min_value")] = value
elif path[-1] == "max_value":
extra_updates[op_path("max", *path[:-1], "max_value")] = value
else:
extra_updates[op_path("set", *path)] = value
raw_updates = {}
if last_scalar_events is not None:
get_last_metric_updates(
task_id=task_id,
last_scalar_events=last_scalar_events,
raw_updates=raw_updates,
extra_updates=extra_updates,
)
if last_events is not None:
def events_per_type(metric_data: Dict[str, dict]) -> Dict[str, EventStats]:
def events_per_type(metric_data_: Dict[str, dict]) -> Dict[str, EventStats]:
return {
event_type: EventStats(last_update=event["timestamp"])
for event_type, event in metric_data.items()
for event_type, event in metric_data_.items()
}
metric_stats = {
@@ -449,241 +445,67 @@ class TaskBLL:
}
extra_updates["metric_stats"] = metric_stats
TaskBLL.set_last_update(
ret = TaskBLL.set_last_update(
task_ids=[task_id],
company_id=company_id,
user_id=user_id,
last_update=last_update,
**extra_updates,
)
if ret and raw_updates:
Task.objects(id=task_id).update_one(__raw__=[{"$set": raw_updates}])
@classmethod
def model_set_ready(
cls,
model_id: str,
company_id: str,
publish_task: bool,
force_publish_task: bool = False,
) -> tuple:
with translate_errors_context():
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
elif model.ready:
raise errors.bad_request.ModelIsReady(**query)
published_task_data = {}
if model.task and publish_task:
task = (
Task.objects(id=model.task, company=company_id)
.only("id", "status")
.first()
)
if task and task.status != TaskStatus.published:
published_task_data["data"] = cls.publish_task(
task_id=model.task,
company_id=company_id,
publish_model=False,
force=force_publish_task,
)
published_task_data["id"] = model.task
updated = model.update(upsert=False, ready=True)
return updated, published_task_data
@classmethod
def publish_task(
cls,
task_id: str,
company_id: str,
publish_model: bool,
force: bool,
status_reason: str = "",
status_message: str = "",
) -> dict:
task = cls.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
if not force:
validate_status_change(task.status, TaskStatus.published)
previous_task_status = task.status
output = task.output or Output()
publish_failed = False
try:
# set state to publishing
task.status = TaskStatus.publishing
task.save()
# publish task models
if task.output.model and publish_model:
output_model = (
Model.objects(id=task.output.model)
.only("id", "task", "ready")
.first()
)
if output_model and not output_model.ready:
cls.model_set_ready(
model_id=task.output.model,
company_id=company_id,
publish_task=False,
)
# set task status to published, and update (or set) it's new output (view and models)
return ChangeStatusRequest(
task=task,
new_status=TaskStatus.published,
force=force,
status_reason=status_reason,
status_message=status_message,
).execute(published=datetime.utcnow(), output=output)
except Exception as ex:
publish_failed = True
raise ex
finally:
if publish_failed:
task.status = previous_task_status
task.save()
@classmethod
def stop_task(
cls,
task_id: str,
company_id: str,
user_name: str,
status_reason: str,
force: bool,
) -> dict:
"""
Stop a running task. Requires task status 'in_progress' and
execution_progress 'running', or force=True. Development task or
task that has no associated worker is stopped immediately.
For a non-development task with worker only the status message
is set to 'stopping' to allow the worker to stop the task and report by itself
:return: updated task fields
"""
task = cls.get_task_with_access(
task_id,
company_id=company_id,
only=(
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
),
requires_write_access=True,
)
def is_run_by_worker(t: Task) -> bool:
"""Checks if there is an active worker running the task"""
update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
return (
t.last_worker
and t.last_update
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
)
if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task.status
status_message = TaskStatusMessage.stopping
return ChangeStatusRequest(
task=task,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force,
).execute()
return ret
@staticmethod
def get_aggregated_project_parameters(
company_id,
project_ids: Sequence[str] = None,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"hyperparams": {"$exists": True, "$gt": {}},
**({"project": {"$in": project_ids}} if project_ids else {}),
}
},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
{
"$project": {
"total": 1,
"results": {"$slice": ["$results", page * page_size, page_size]},
}
},
]
with translate_errors_context():
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
{
"section": ParameterKeyEscaper.unescape(
dpath.get(r, "_id/section")
),
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
def remove_task_from_all_queues(
company_id: str, task_id: str, exclude: str = None
) -> int:
more = {}
if exclude:
more["id__ne"] = exclude
return Queue.objects(company=company_id, entries__task=task_id, **more).update(
pull__entries__task=task_id, last_update=datetime.utcnow()
)
@classmethod
def dequeue_and_change_status(
cls, task: Task, company_id: str, status_message: str, status_reason: str,
cls,
task: Task,
company_id: str,
user_id: str,
status_message: str,
status_reason: str,
remove_from_all_queues=False,
new_status=None,
new_status_for_aborted_task=None,
):
cls.dequeue(task, company_id)
try:
cls.dequeue(task, company_id=company_id, user_id=user_id, silent_fail=True)
except APIError:
# dequeue may fail if the queue was deleted
pass
if remove_from_all_queues:
cls.remove_task_from_all_queues(company_id=company_id, task_id=task.id)
if task.status not in [TaskStatus.queued, TaskStatus.in_progress]:
return {"updated": 0}
if new_status_for_aborted_task and task.status == TaskStatus.in_progress:
new_status = new_status_for_aborted_task
return ChangeStatusRequest(
task=task,
new_status=TaskStatus.created,
new_status=new_status or task.enqueue_status or TaskStatus.created,
status_reason=status_reason,
status_message=status_message,
).execute(unset__execution__queue=1)
user_id=user_id,
force=True,
).execute(enqueue_status=None)
@classmethod
def dequeue(cls, task: Task, company_id: str, silent_fail=False):
def dequeue(cls, task: Task, company_id: str, user_id: str, silent_fail=False):
"""
Dequeue the task from the queue
:param task: task to dequeue
@@ -710,6 +532,9 @@ class TaskBLL:
return {
"removed": queue_bll.remove_task(
company_id=company_id, queue_id=task.execution.queue, task_id=task.id
company_id=company_id,
user_id=user_id,
queue_id=task.execution.queue,
task_id=task.id,
)
}

View File

@@ -0,0 +1,374 @@
from datetime import datetime
from itertools import chain
from operator import attrgetter
from typing import Sequence, Set, Tuple, Union
import attr
from boltons.iterutils import partition, bucketize, first, chunked_iter
from furl import furl
from mongoengine import NotUniqueError
from pymongo.errors import DuplicateKeyError
from apiserver.apierrors import errors
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_bll import PlotFields
from apiserver.bll.task.utils import deleted_prefix
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
from apiserver.database.model.url_to_delete import (
StorageType,
UrlToDelete,
FileType,
DeletionStatus,
)
from apiserver.database.utils import id as db_id
log = config.logger(__file__)
event_bll = EventBLL()
@attr.s(auto_attribs=True)
class TaskUrls:
model_urls: Sequence[str]
artifact_urls: Sequence[str]
event_urls: Sequence[str] = [] # left here is in order not to break the api
def __add__(self, other: "TaskUrls"):
if not other:
return self
return TaskUrls(
model_urls=list(set(self.model_urls) | set(other.model_urls)),
artifact_urls=list(set(self.artifact_urls) | set(other.artifact_urls)),
)
@attr.s(auto_attribs=True)
class CleanupResult:
"""
Counts of objects modified in task cleanup operation
"""
updated_children: int
updated_models: int
deleted_models: int
deleted_model_ids: Set[str]
urls: TaskUrls = None
def to_res_dict(self, return_file_urls: bool) -> dict:
remove_fields = ["deleted_model_ids"]
if not return_file_urls:
remove_fields.append("urls")
# noinspection PyTypeChecker
res = attr.asdict(
self, filter=lambda attrib, value: attrib.name not in remove_fields
)
if not return_file_urls:
res["urls"] = None
return res
def __add__(self, other: "CleanupResult"):
if not other:
return self
return CleanupResult(
updated_children=self.updated_children + other.updated_children,
updated_models=self.updated_models + other.updated_models,
deleted_models=self.deleted_models + other.deleted_models,
urls=self.urls + other.urls if self.urls else other.urls,
deleted_model_ids=self.deleted_model_ids | other.deleted_model_ids,
)
@staticmethod
def empty():
return CleanupResult(
updated_children=0,
updated_models=0,
deleted_models=0,
deleted_model_ids=set(),
)
def collect_plot_image_urls(
company: str, task_or_model: Union[str, Sequence[str]]
) -> Set[str]:
urls = set()
task_ids = task_or_model if isinstance(task_or_model, list) else [task_or_model]
for tasks in chunked_iter(task_ids, 100):
next_scroll_id = None
while True:
events, next_scroll_id = event_bll.get_plot_image_urls(
company_id=company, task_ids=tasks, scroll_id=next_scroll_id
)
if not events:
break
for event in events:
event_urls = event.get(PlotFields.source_urls)
if event_urls:
urls.update(set(event_urls))
return urls
def collect_debug_image_urls(
company: str, task_or_model: Union[str, Sequence[str]]
) -> Set[str]:
"""
Return the set of unique image urls
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
"""
urls = set()
task_ids = task_or_model if isinstance(task_or_model, list) else [task_or_model]
for tasks in chunked_iter(task_ids, 100):
after_key = None
while True:
res, after_key = event_bll.get_debug_image_urls(
company_id=company,
task_ids=tasks,
after_key=after_key,
)
urls.update(res)
if not after_key:
break
return urls
supported_storage_types = {
"s3://": StorageType.s3,
"azure://": StorageType.azure,
"gs://": StorageType.gs,
}
supported_storage_types.update(
{
p: StorageType.fileserver
for p in config.get(
"services.async_urls_delete.fileserver.url_prefixes",
["https://", "http://"],
)
}
)
def schedule_for_delete(
company: str,
user: str,
task_id: str,
urls: Set[str],
can_delete_folders: bool,
) -> Set[str]:
urls_per_storage = bucketize(
urls,
key=lambda u: first(
type_
for prefix, type_ in supported_storage_types.items()
if u.startswith(prefix)
),
)
urls_per_storage.pop(None, None)
processed_urls = set()
for storage_type, storage_urls in urls_per_storage.items():
delete_folders = (storage_type == StorageType.fileserver) and can_delete_folders
scheduled_to_delete = set()
for url in storage_urls:
folder = None
if delete_folders:
try:
parsed = furl(url)
if parsed.path and len(parsed.path.segments) > 1:
folder = parsed.remove(
args=True, fragment=True, path=parsed.path.segments[-1]
).url.rstrip("/")
except Exception as ex:
pass
to_delete = folder or url
if to_delete in scheduled_to_delete:
processed_urls.add(url)
continue
try:
UrlToDelete(
id=db_id(),
company=company,
user=user,
url=to_delete,
task=task_id,
created=datetime.utcnow(),
storage_type=storage_type,
type=FileType.folder if folder else FileType.file,
).save()
except (DuplicateKeyError, NotUniqueError):
existing = UrlToDelete.objects(company=company, url=to_delete).first()
if existing:
existing.update(
user=user,
task=task_id,
created=datetime.utcnow(),
retry_count=0,
unset__last_failure_time=1,
unset__last_failure_reason=1,
status=DeletionStatus.created,
)
processed_urls.add(url)
scheduled_to_delete.add(to_delete)
return processed_urls
def delete_task_events_and_collect_urls(
company: str, task_ids: Sequence[str], wait_for_delete: bool, model=False
) -> Set[str]:
event_urls = collect_debug_image_urls(company, task_ids) | collect_plot_image_urls(
company, task_ids
)
event_bll.delete_task_events(
company, task_ids, model=model, wait_for_delete=wait_for_delete
)
return event_urls
def cleanup_task(
company: str,
user: str,
task: Task,
force: bool = False,
update_children=True,
delete_output_models=True,
) -> CleanupResult:
"""
Validate task deletion and delete/modify all its output.
:param task: task object
:param force: whether to delete task with published outputs
:return: count of delete and modified items
"""
published_models, draft_models, in_use_model_ids = verify_task_children_and_ouptuts(
task, force
)
artifact_urls = (
{
a.uri
for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri
}
if task.execution and task.execution.artifacts
else {}
)
model_urls = {m.uri for m in draft_models if m.uri and m.id not in in_use_model_ids}
deleted_task_id = f"{deleted_prefix}{task.id}"
updated_children = 0
now = datetime.utcnow()
if update_children:
updated_children = Task.objects(parent=task.id).update(
parent=deleted_task_id,
last_change=now,
last_changed_by=user,
)
deleted_models = 0
updated_models = 0
deleted_model_ids = set()
for models, allow_delete in ((draft_models, True), (published_models, False)):
if not models:
continue
if delete_output_models and allow_delete:
model_ids = list({m.id for m in models if m.id not in in_use_model_ids})
if model_ids:
deleted_models += Model.objects(id__in=model_ids).delete()
deleted_model_ids.update(model_ids)
if in_use_model_ids:
Model.objects(id__in=list(in_use_model_ids)).update(
unset__task=1,
set__last_change=now,
set__last_changed_by=user,
)
continue
if update_children:
updated_models += Model.objects(id__in=[m.id for m in models]).update(
task=deleted_task_id,
last_change=now,
last_changed_by=user,
)
else:
Model.objects(id__in=[m.id for m in models]).update(
unset__task=1,
set__last_change=now,
set__last_changed_by=user,
)
return CleanupResult(
deleted_models=deleted_models,
updated_children=updated_children,
updated_models=updated_models,
urls=TaskUrls(
artifact_urls=list(artifact_urls),
model_urls=list(model_urls),
),
deleted_model_ids=deleted_model_ids,
)
def verify_task_children_and_ouptuts(
task, force: bool
) -> Tuple[Sequence[Model], Sequence[Model], Set[str]]:
if not force:
published_children_count = Task.objects(
parent=task.id, status=TaskStatus.published
).count()
if published_children_count:
raise errors.bad_request.TaskCannotBeDeleted(
"has children, use force=True",
task=task.id,
children=published_children_count,
)
model_fields = ["id", "ready", "uri"]
published_models, draft_models = partition(
Model.objects(task=task.id).only(*model_fields),
key=attrgetter("ready"),
)
if not force and published_models:
raise errors.bad_request.TaskCannotBeDeleted(
"has output models, use force=True",
task=task.id,
models=len(published_models),
)
if task.models and task.models.output:
model_ids = [m.model for m in task.models.output]
for output_model in Model.objects(id__in=model_ids).only(*model_fields):
if output_model.ready:
if not force:
raise errors.bad_request.TaskCannotBeDeleted(
"has output model, use force=True",
task=task.id,
model=output_model.id,
)
published_models.append(output_model)
else:
draft_models.append(output_model)
in_use_model_ids = {}
if draft_models:
model_ids = {m.id for m in draft_models}
dependent_tasks = Task.objects(models__input__model__in=list(model_ids)).only(
"id", "models"
)
in_use_model_ids = model_ids & {
m.model
for m in chain.from_iterable(
t.models.input for t in dependent_tasks if t.models
)
}
return published_models, draft_models, in_use_model_ids

View File

@@ -0,0 +1,612 @@
from datetime import datetime
from typing import Callable, Any, Tuple, Union, Sequence
from apiserver.apierrors import errors, APIError
from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import (
TaskBLL,
validate_status_change,
ChangeStatusRequest,
)
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
TaskStatus,
Task,
TaskSystemTags,
TaskStatusMessage,
ArtifactModes,
Execution,
DEFAULT_LAST_ITERATION,
TaskType,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_set
log = config.logger(__file__)
queue_bll = QueueBLL()
def _get_pipeline_steps_for_controller_task(
task: Task, company_id: str, only: Sequence[str] = None
) -> Sequence[Task]:
if not task or task.type != TaskType.controller:
return []
query = Task.objects(company=company_id, parent=task.id)
if only:
query = query.only(*only)
return list(query)
def archive_task(
task: Union[str, Task],
company_id: str,
identity: Identity,
status_message: str,
status_reason: str,
include_pipeline_steps: bool,
) -> int:
"""
Deque and archive task
Return 1 if successful
"""
user_id = identity.user
fields = (
"id",
"company",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
"type",
)
if isinstance(task, str):
task = get_task_with_write_access(
task,
company_id=company_id,
identity=identity,
only=fields,
)
def archive_task_core(task_: Task) -> int:
try:
TaskBLL.dequeue_and_change_status(
task_,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
new_status_for_aborted_task=TaskStatus.stopped,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
return task_.update(
status_message=status_message,
status_reason=status_reason,
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(
task, company_id, only=fields
)
):
for step in step_tasks:
archive_task_core(step)
return archive_task_core(task)
def unarchive_task(
task_id: str,
company_id: str,
identity: Identity,
status_message: str,
status_reason: str,
include_pipeline_steps: bool,
) -> int:
"""
Unarchive task. Return 1 if successful
"""
fields = ("id", "type")
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=fields,
)
def unarchive_task_core(task_: Task) -> int:
return task_.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=identity.user,
)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(
task, company_id, only=fields
)
):
for step in step_tasks:
unarchive_task_core(step)
return unarchive_task_core(task)
def dequeue_task(
task_id: str,
company_id: str,
identity: Identity,
status_message: str,
status_reason: str,
remove_from_all_queues: bool = False,
new_status=None,
) -> Tuple[int, dict]:
if new_status and new_status not in get_options(TaskStatus):
raise errors.bad_request.ValidationError(f"Invalid task status: {new_status}")
# get the task without write access to make sure that it actually exists
task = Task.get(
id=task_id,
company=company_id,
_only=("id",),
include_public=True,
)
if not task:
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
return 1, {"updated": 0}
user_id = identity.user
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=(
"id",
"company",
"execution",
"status",
"project",
"enqueue_status",
),
)
res = TaskBLL.dequeue_and_change_status(
task,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=remove_from_all_queues,
new_status=new_status,
)
return 1, res
def enqueue_task(
task_id: str,
company_id: str,
identity: Identity,
queue_id: str,
status_message: str,
status_reason: str,
queue_name: str = None,
validate: bool = False,
force: bool = False,
update_execution_queue: bool = True,
) -> Tuple[int, dict]:
if queue_id and queue_name:
raise errors.bad_request.ValidationError(
"Either queue id or queue name should be provided"
)
task = get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=identity
)
if not update_execution_queue:
if not (
task.status == TaskStatus.queued and task.execution and task.execution.queue
):
raise errors.bad_request.ValidationError(
"Cannot skip setting execution queue for a task "
"that is not enqueued or does not have execution queue set"
)
if queue_name:
queue = queue_bll.get_by_name(
company_id=company_id, queue_name=queue_name, only=("id",)
)
if not queue:
queue = queue_bll.create(company_id=company_id, name=queue_name)
queue_id = queue.id
if not queue_id:
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
user_id = identity.user
if validate:
TaskBLL.validate(task)
before_enqueue_status = task.status
if task.status == TaskStatus.queued and task.enqueue_status:
before_enqueue_status = task.enqueue_status
res = ChangeStatusRequest(
task=task,
new_status=TaskStatus.queued,
status_reason=status_reason,
status_message=status_message,
force=force,
user_id=user_id,
).execute(enqueue_status=before_enqueue_status)
try:
queue_bll.add_task(company_id=company_id, queue_id=queue_id, task_id=task.id)
except Exception:
# failed enqueueing, revert to previous state
ChangeStatusRequest(
task=task,
current_status_override=TaskStatus.queued,
new_status=task.status,
force=True,
status_reason="failed enqueueing",
user_id=user_id,
).execute(enqueue_status=None)
raise
# set the current queue ID in the task
if update_execution_queue:
if task.execution:
Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
else:
Task.objects(id=task_id).update(
execution=Execution(queue=queue_id), multi=False
)
nested_set(res, ("fields", "execution.queue"), queue_id)
# make sure that the task is not queued in any other queue
TaskBLL.remove_task_from_all_queues(
company_id=company_id, task_id=task_id, exclude=queue_id
)
return 1, res
def move_tasks_to_trash(tasks: Sequence[str]) -> int:
try:
collection_name = Task._get_collection_name()
trash_collection_name = f"{collection_name}__trash"
Task.aggregate(
[
{"$match": {"_id": {"$in": tasks}}},
{
"$merge": {
"into": trash_collection_name,
"on": "_id",
"whenMatched": "replace",
"whenNotMatched": "insert",
}
},
],
allow_disk_use=True,
)
except Exception as ex:
log.error(f"Error copying tasks to trash {str(ex)}")
return Task.objects(id__in=tasks).delete()
def delete_task(
task_id: str,
company_id: str,
identity: Identity,
move_to_trash: bool,
force: bool,
delete_output_models: bool,
status_message: str,
status_reason: str,
include_pipeline_steps: bool,
) -> Tuple[int, Task, CleanupResult]:
user_id = identity.user
task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if (
task.status != TaskStatus.created
and EntityVisibility.archived.value not in task.system_tags
and not force
):
raise errors.bad_request.TaskCannotBeDeleted(
"due to status, use force=True",
task=task.id,
expected=TaskStatus.created,
current=task.status,
)
def delete_task_core(task_: Task, force_: bool) -> CleanupResult:
try:
TaskBLL.dequeue_and_change_status(
task_,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
res = cleanup_task(
company=company_id,
user=user_id,
task=task_,
force=force_,
delete_output_models=delete_output_models,
)
if move_to_trash:
# make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation
task_.last_update = datetime.utcnow()
task_.save()
else:
task_.delete()
return res
task_ids = [task.id]
cleanup_res = CleanupResult.empty()
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id)
):
for step in step_tasks:
cleanup_res += delete_task_core(step, True)
task_ids.append(step.id)
cleanup_res = delete_task_core(task, force)
if move_to_trash:
move_tasks_to_trash(task_ids)
update_project_time(task.project)
return 1, task, cleanup_res
def reset_task(
task_id: str,
company_id: str,
identity: Identity,
force: bool,
delete_output_models: bool,
clear_all: bool,
) -> Tuple[dict, CleanupResult, dict]:
user_id = identity.user
task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if not force and task.status == TaskStatus.published:
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
dequeued = {}
updates = {}
try:
dequeued = TaskBLL.dequeue(
task, company_id=company_id, user_id=user_id, silent_fail=True
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
TaskBLL.remove_task_from_all_queues(company_id=company_id, task_id=task.id)
cleaned_up = cleanup_task(
company=company_id,
user=user_id,
task=task,
force=force,
update_children=False,
delete_output_models=delete_output_models,
)
updates.update(
set__last_iteration=DEFAULT_LAST_ITERATION,
set__last_metrics={},
set__unique_metrics=[],
set__metric_stats={},
set__models__output=[],
set__runtime={},
unset__output__result=1,
unset__output__error=1,
unset__last_worker=1,
unset__last_worker_report=1,
unset__started=1,
unset__completed=1,
unset__published=1,
unset__active_duration=1,
unset__enqueue_status=1,
)
if clear_all:
updates.update(
set__execution=Execution(),
unset__script=1,
)
else:
updates.update(unset__execution__queue=1)
if task.execution and task.execution.artifacts:
updates.update(
set__execution__artifacts={
key: artifact
for key, artifact in task.execution.artifacts.items()
if artifact.mode == ArtifactModes.input
}
)
res = ChangeStatusRequest(
task=task,
new_status=TaskStatus.created,
force=force,
status_reason="reset",
status_message="reset",
user_id=user_id,
).execute(
**updates,
)
return dequeued, cleaned_up, res
def publish_task(
task_id: str,
company_id: str,
identity: Identity,
force: bool,
publish_model_func: Callable[[str, str, Identity], Any] = None,
status_message: str = "",
status_reason: str = "",
) -> dict:
user_id = identity.user
task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if not force:
validate_status_change(task.status, TaskStatus.published)
previous_task_status = task.status
output = task.output or Output()
publish_failed = False
try:
# set state to publishing
task.status = TaskStatus.publishing
task.save()
# publish task models
if task.models and task.models.output and publish_model_func:
model_id = task.models.output[-1].model
model = (
Model.objects(id=model_id, company=company_id)
.only("id", "ready")
.first()
)
if model and not model.ready:
publish_model_func(model.id, company_id, identity)
# set task status to published, and update (or set) it's new output (view and models)
return ChangeStatusRequest(
task=task,
new_status=TaskStatus.published,
force=force,
status_reason=status_reason,
status_message=status_message,
user_id=user_id,
).execute(published=datetime.utcnow(), output=output)
except Exception as ex:
publish_failed = True
raise ex
finally:
if publish_failed:
task.status = previous_task_status
task.save()
def stop_task(
task_id: str,
company_id: str,
identity: Identity,
user_name: str,
status_reason: str,
force: bool,
include_pipeline_steps: bool,
) -> dict:
"""
Stop a running task. Requires task status 'in_progress' and
execution_progress 'running', or force=True. Development task or
task that has no associated worker is stopped immediately.
For a non-development task with worker only the status message
is set to 'stopping' to allow the worker to stop the task and report by itself
:return: updated task fields
"""
user_id = identity.user
fields = (
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
"execution.queue",
"type",
)
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=fields,
)
def is_run_by_worker(t: Task) -> bool:
"""Checks if there is an active worker running the task"""
update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
return (
t.last_worker
and t.last_update
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
)
def stop_task_core(task_: Task, force_: bool):
is_queued = task_.status == TaskStatus.queued
set_stopped = (
is_queued
or TaskSystemTags.development in task_.system_tags
or not is_run_by_worker(task_)
)
if set_stopped:
if is_queued:
try:
TaskBLL.dequeue(
task_, company_id=company_id, user_id=user_id, silent_fail=True
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task_.status
status_message = TaskStatusMessage.stopping
return ChangeStatusRequest(
task=task_,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force_,
user_id=user_id,
).execute()
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(
task, company_id, only=fields
)
):
for step in step_tasks:
stop_task_core(step, True)
return stop_task_core(task, force)

View File

@@ -1,18 +1,23 @@
from datetime import datetime
from typing import TypeVar, Callable, Tuple, Sequence, Union
from typing import Sequence
import attr
import six
from mongoengine import Q
from mongoengine.base import UPDATE_OPERATORS
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.project import Project
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
from apiserver.timing_context import TimingContext
from apiserver.service_repo.auth import Identity
from apiserver.utilities.attrs import typed_attrs
valid_statuses = get_options(TaskStatus)
deleted_prefix = "__DELETED__"
@typed_attrs
@@ -26,6 +31,7 @@ class ChangeStatusRequest(object):
force = attr.ib(type=bool, default=False)
allow_same_state_transition = attr.ib(type=bool, default=True)
current_status_override = attr.ib(default=None)
user_id = attr.ib(type=str, default=None)
def execute(self, **kwargs):
current_status = self.current_status_override or self.task.status
@@ -44,6 +50,7 @@ class ChangeStatusRequest(object):
status_changed=now,
last_update=now,
last_change=now,
last_changed_by=self.user_id,
)
if self.new_status == TaskStatus.queued:
@@ -54,7 +61,7 @@ class ChangeStatusRequest(object):
fields.update({safe_mongoengine_key(k): v for k, v in kwargs.items()})
with translate_errors_context(), TimingContext("mongo", "task_status"):
with translate_errors_context():
# atomic change of task status by querying the task with the EXPECTED status before modifying it
params = fields.copy()
params.update(control)
@@ -72,8 +79,16 @@ class ChangeStatusRequest(object):
update_project_time(project_id)
# make sure that _raw_ queries are not returned back to the client
fields.pop("__raw__", None)
def is_mongo_operator(field: str) -> bool:
head, _, tail = field.partition("__")
return tail and (head in UPDATE_OPERATORS)
# make sure to not return _raw_ queries or any of the update operators
fields = {
key: value
for key, value in fields.items()
if not (key == "__raw__" or is_mongo_operator(key))
}
return dict(updated=updated, fields=fields)
@@ -105,7 +120,7 @@ def validate_status_change(current_status, new_status):
state_machine = {
TaskStatus.created: {TaskStatus.queued, TaskStatus.in_progress},
TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress},
TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress, TaskStatus.stopped},
TaskStatus.in_progress: {
TaskStatus.stopped,
TaskStatus.failed,
@@ -116,6 +131,7 @@ state_machine = {
TaskStatus.closed,
TaskStatus.created,
TaskStatus.failed,
TaskStatus.queued,
TaskStatus.in_progress,
TaskStatus.published,
TaskStatus.publishing,
@@ -128,7 +144,12 @@ state_machine = {
TaskStatus.publishing,
TaskStatus.stopped,
},
TaskStatus.failed: {TaskStatus.created, TaskStatus.stopped, TaskStatus.published},
TaskStatus.failed: {
TaskStatus.created,
TaskStatus.stopped,
TaskStatus.published,
TaskStatus.queued,
},
TaskStatus.publishing: {TaskStatus.published},
TaskStatus.published: set(),
TaskStatus.completed: {
@@ -153,41 +174,78 @@ def get_possible_status_changes(current_status):
return possible
def update_project_time(project_ids: Union[str, Sequence[str]]):
if not project_ids:
return
def get_many_tasks_for_writing(
company_id: str,
identity: Identity,
query: Q = None,
only: Sequence = None,
throw_on_forbidden: bool = True,
) -> Sequence[Task]:
if only:
missing = [f for f in ("company",) if f not in only]
if missing:
only = [*only, *missing]
if isinstance(project_ids, str):
project_ids = [project_ids]
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
T = TypeVar("T")
def split_by(
condition: Callable[[T], bool], items: Sequence[T]
) -> Tuple[Sequence[T], Sequence[T]]:
"""
split "items" to two lists by "condition"
"""
applied = zip(map(condition, items), items)
return (
[item for cond, item in applied if cond],
[item for cond, item in applied if not cond],
result = list(
Task.get_many(
company=company_id,
query=query,
override_projection=only,
allow_public=True,
return_dicts=False,
)
)
if not company_id:
return result
forbidden_tasks = {task.id for task in result if not task.company}
if forbidden_tasks:
if throw_on_forbidden:
raise errors.forbidden.NoWritePermission(
f"cannot modify public task(s), ids={tuple(forbidden_tasks)}"
)
result = [task for task in result if task.id not in forbidden_tasks]
return result
def get_task_with_write_access(
task_id: str,
company_id: str,
identity: Identity,
only=None,
) -> Task:
"""
Gets a task that has a required write access
:except errors.bad_request.InvalidTaskId: if the task is not found
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
"""
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(_only=only, **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
return task
def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
company_id: str,
task_id: str,
identity: Identity,
allow_all_statuses: bool = False,
force: bool = False,
) -> Task:
"""
Loads only task id and return the task only if it is updatable (status == 'created')
"""
task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
only=("id", "status"),
identity=identity,
)
if allow_all_statuses:
return task
@@ -202,9 +260,152 @@ def get_task_for_update(
return task
def update_task(task: Task, update_cmds: dict, set_last_update: bool = True):
def update_task(
task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True
):
now = datetime.utcnow()
last_updates = dict(last_change=now)
last_updates = dict(last_change=now, last_changed_by=user_id)
if set_last_update:
last_updates.update(last_update=now)
return task.update(**update_cmds, **last_updates)
def get_last_metric_updates(
task_id: str,
last_scalar_events: dict,
raw_updates: dict,
extra_updates: dict,
model_events: bool = False,
):
max_values = config.get("services.tasks.max_last_metrics", 2000)
total_metrics = set()
if max_values:
query = dict(id=task_id)
to_add = sum(len(v) for m, v in last_scalar_events.items())
if to_add <= max_values:
query[f"unique_metrics__{max_values - to_add}__exists"] = True
db_cls = Model if model_events else Task
task = db_cls.objects(**query).only("unique_metrics").first()
if task and task.unique_metrics:
total_metrics = set(task.unique_metrics)
new_metrics = []
def add_last_metric_mean_update(
metric_path: str,
metric_count: int,
metric_total: float,
):
"""
Update new mean field based on the value in db and new data
The count field is updated here too and not with inc__ so that
it will not get updated in the db earlier than the corresponding mean
"""
metric_path = metric_path.replace("__", ".")
mean_value_field = f"{metric_path}.mean_value"
count_field = f"{metric_path}.count"
raw_updates[mean_value_field] = {
"$round": [
{
"$divide": [
{
"$add": [
{
"$multiply": [
{"$ifNull": [f"${mean_value_field}", 0]},
{"$ifNull": [f"${count_field}", 0]},
]
},
metric_total,
]
},
{
"$add": [
{"$ifNull": [f"${count_field}", 0]},
metric_count,
]
},
]
},
2,
]
}
raw_updates[count_field] = {
"$add": [
{"$ifNull": [f"${count_field}", 0]},
metric_count,
]
}
def add_last_metric_conditional_update(
metric_path: str, metric_value, iter_value: int, is_min: bool, is_first: bool
):
"""
Build an aggregation for an atomic update of the min or max value and the corresponding iteration
"""
if is_first:
field_prefix = "first"
op = None
elif is_min:
field_prefix = "min"
op = "$gt"
else:
field_prefix = "max"
op = "$lt"
value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
exists = {"$lte": [f"${value_field}", None]}
if op:
condition = {
"$or": [
exists,
{op: [f"${value_field}", metric_value]},
]
}
else:
condition = exists
raw_updates[value_field] = {
"$cond": [condition, metric_value, f"${value_field}"]
}
value_iteration_field = (
f"{metric_path}__{field_prefix}_value_iteration".replace("__", ".")
)
raw_updates[value_iteration_field] = {
"$cond": [condition, iter_value, f"${value_iteration_field}"]
}
for metric_key, metric_data in last_scalar_events.items():
for variant_key, variant_data in metric_data.items():
metric = f"{variant_data.get('metric')}/{variant_data.get('variant')}"
if max_values:
if len(total_metrics) >= max_values and metric not in total_metrics:
continue
total_metrics.add(metric)
new_metrics.append(metric)
path = f"last_metrics__{metric_key}__{variant_key}"
for key, value in variant_data.items():
if key in ("min_value", "max_value", "first_value"):
add_last_metric_conditional_update(
metric_path=path,
metric_value=value,
iter_value=variant_data.get(f"{key}_iter", 0),
is_min=(key == "min_value"),
is_first=(key == "first_value"),
)
elif key in ("metric", "variant", "value"):
extra_updates[f"set__{path}__{key}"] = value
count = variant_data.get("count")
total = variant_data.get("total")
if count is not None and total is not None:
add_last_metric_mean_update(
metric_path=path,
metric_count=count,
metric_total=total,
)
if new_metrics:
extra_updates["add_to_set__unique_metrics"] = new_metrics

View File

@@ -1,5 +1,8 @@
from datetime import datetime
from apiserver.apierrors import errors
from apiserver.apimodels.users import CreateRequest
from apiserver.config.info import get_version
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.user import User
@@ -12,7 +15,11 @@ class UserBLL:
if user_id and User.objects(id=user_id).only("id"):
raise errors.bad_request.UserIdExists(id=user_id)
user = User(**request.to_struct())
user = User(
**request.to_struct(),
created=datetime.utcnow(),
created_in_version=get_version(),
)
user.save(force_insert=True)
@staticmethod

View File

@@ -1,84 +1,24 @@
import functools
import itertools
from concurrent.futures.thread import ThreadPoolExecutor
from operator import itemgetter
from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set, Iterable
from datetime import datetime
from typing import (
Optional,
Callable,
Iterable,
Tuple,
Sequence,
TypeVar,
Union,
)
from boltons import iterutils
from apiserver.database.model import AttributedDocument
from apiserver.apierrors import APIError
from apiserver.database.model.project import Project
from apiserver.database.model.settings import Settings
def extract_properties_to_lists(
key_names: Sequence[str],
data: Sequence[dict],
extract_func: Optional[Callable[[dict], Tuple]] = None,
) -> dict:
"""
Given a list of dictionaries and names of dictionary keys
builds a dictionary with the requested keys and values lists
:param key_names: names of the keys in the resulting dictionary
:param data: sequence of dictionaries to extract values from
:param extract_func: the optional callable that extracts properties
from a dictionary and put them in a tuple in the order corresponding to
key_names. If not specified then properties are extracted according to key_names
"""
value_sequences = zip(*map(extract_func or itemgetter(*key_names), data))
return dict(zip(key_names, map(list, value_sequences)))
class SetFieldsResolver:
"""
The class receives set fields dictionary
and for the set fields that require 'min' or 'max'
operation replace them with a simple set in case the
DB document does not have these fields set
"""
SET_MODIFIERS = ("min", "max")
def __init__(self, set_fields: Dict[str, Any]):
self.orig_fields = {}
self.fields = {}
self.add_fields(**set_fields)
def add_fields(self, **set_fields: Any):
self.orig_fields.update(set_fields)
self.fields.update(
{
f: fname
for f, modifier, dunder, fname in (
(f,) + f.partition("__") for f in set_fields.keys()
)
if dunder and modifier in self.SET_MODIFIERS
}
)
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
if name in self.fields and doc.get_field_value(self.fields[name]) is None:
return self.fields[name]
return name
def get_fields(self, doc: AttributedDocument):
"""
For the given document return the set fields instructions
with min/max operations replaced with a single set in case
the document does not have the field set
"""
return {
self._get_updated_name(doc, name): value
for name, value in self.orig_fields.items()
}
def get_names(self) -> Set[str]:
"""
Returns the names of the fields that had min/max modifiers
in the format suitable for projection (dot separated)
"""
return set(name.replace("__", ".") for name in self.fields.values())
@functools.lru_cache()
def get_server_uuid() -> Optional[str]:
return Settings.get_by_key("server.uuid")
@@ -115,3 +55,38 @@ def parallel_chunked_decorator(func: Callable = None, chunk_size: int = 100):
)
return wrapper
T = TypeVar("T")
def run_batch_operation(
func: Callable[[str], T], ids: Sequence[str]
) -> Tuple[Sequence[Tuple[str, T]], Sequence[dict]]:
results = list()
failures = list()
for _id in ids:
try:
results.append((_id, func(_id)))
except APIError as err:
failures.append(
{
"id": _id,
"error": {
"codes": [err.code, err.subcode],
"msg": err.msg,
"data": err.error_data,
},
}
)
return results, failures
def update_project_time(project_ids: Union[str, Sequence[str]]):
if not project_ids:
return
if isinstance(project_ids, str):
project_ids = [project_ids]
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())

View File

@@ -1,15 +1,18 @@
import itertools
import re
from datetime import datetime, timedelta
from time import time
from typing import Sequence, Set, Optional
import attr
import elasticsearch.helpers
from boltons.iterutils import partition, chunked_iter
from pyhocon import ConfigTree
from apiserver.es_factory import es_factory
from apiserver.apierrors import APIError
from apiserver.apierrors.errors import bad_request, server_error
from apiserver.apimodels.workers import (
DEFAULT_TIMEOUT,
IdNameEntry,
WorkerEntry,
StatusReportRequest,
@@ -25,16 +28,18 @@ from apiserver.database.model.project import Project
from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
from .stats import WorkerStats
log = config.logger(__file__)
class WorkerBLL:
_key_regex_trans = str.maketrans({"*": ".*", "?": ".?"})
def __init__(self, es=None, redis=None):
self.es_client = es or es_factory.connect("workers")
self.config = config.get("services.workers", ConfigTree())
self.redis = redis or redman.connection("workers")
self._stats = WorkerStats(self.es_client)
@@ -51,6 +56,7 @@ class WorkerBLL:
queues: Sequence[str] = None,
timeout: int = 0,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> WorkerEntry:
"""
Register a worker
@@ -66,7 +72,7 @@ class WorkerBLL:
"""
key = WorkerBLL._get_worker_key(company_id, user_id, worker)
timeout = timeout or DEFAULT_TIMEOUT
timeout = timeout or int(self.config.get("default_worker_timeout_sec", 10 * 60))
queues = queues or []
with translate_errors_context():
@@ -76,7 +82,7 @@ class WorkerBLL:
raise bad_request.InvalidUserId(**query)
company = Company.objects(id=company_id).only("id", "name").first()
if not company:
raise server_error.InternalError("invalid company", company=company_id)
raise bad_request.InvalidId("invalid company", company=company_id)
queue_objs = Queue.objects(company=company_id, id__in=queues).only("id")
if len(queue_objs) < len(queues):
@@ -95,9 +101,10 @@ class WorkerBLL:
register_timeout=timeout,
last_activity_time=now,
tags=tags,
system_tags=system_tags,
)
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
self._save_worker_data(entry)
return entry
@@ -109,15 +116,20 @@ class WorkerBLL:
:param worker: worker ID
:raise bad_request.WorkerNotRegistered: the worker was not previously registered
"""
with TimingContext("redis", "workers_unregister"):
res = self.redis.delete(
company_id, self._get_worker_key(company_id, user_id, worker)
)
if not res:
res = self.redis.delete(
company_id, self._get_worker_key(company_id, user_id, worker)
)
if not res and not config.get("apiserver.workers.auto_unregister", False):
raise bad_request.WorkerNotRegistered(worker=worker)
def status_report(
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
self,
company_id: str,
user_id: str,
ip: str,
report: StatusReportRequest,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> None:
"""
Write worker status report
@@ -133,22 +145,23 @@ class WorkerBLL:
try:
entry.ip = ip
now = datetime.utcnow()
entry.last_activity_time = now
if tags is not None:
entry.tags = tags
if system_tags is not None:
entry.system_tags = system_tags
if report.machine_stats:
self._log_stats_to_es(
self.log_stats_to_es(
company_id=company_id,
company_name=entry.company.name,
worker=report.worker,
worker_id=report.worker,
timestamp=report.timestamp,
task=report.task,
machine_stats=report.machine_stats,
)
now = datetime.utcnow()
entry.last_activity_time = now
entry.queue = report.queue
if report.queues:
@@ -165,6 +178,7 @@ class WorkerBLL:
last_worker_report=now,
last_update=now,
last_change=now,
last_changed_by=user_id,
)
# modify(new=True, ...) returns the modified object
task = Task.objects(**query).modify(new=True, **update)
@@ -176,7 +190,9 @@ class WorkerBLL:
if task.project:
project = Project.objects(id=task.project).only("name").first()
if project:
entry.project = IdNameEntry(id=project.id, name=project.name)
entry.project = IdNameEntry(
id=project.id, name=project.name
)
entry.last_report_time = now
except APIError:
@@ -188,8 +204,41 @@ class WorkerBLL:
finally:
self._save_worker(entry)
def get_count(
self,
company_id: str,
last_seen: Optional[int] = None,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
):
if not last_seen:
return len(
self._get_keys(
company_id,
user_tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
)
return len(
self.get_all(
company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
)
def get_all(
self, company_id: str, last_seen: Optional[int] = None
self,
company_id: str,
last_seen: Optional[int] = None,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerEntry]:
"""
Get all the company workers that were active during the last_seen period
@@ -198,7 +247,12 @@ class WorkerBLL:
:return:
"""
try:
workers = self._get(company_id)
workers = self._get(
company_id,
user_tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
except Exception as e:
raise server_error.DataError("failed loading worker entries", err=e.args[0])
@@ -213,15 +267,23 @@ class WorkerBLL:
return workers
def get_all_with_projection(
self, company_id: str, last_seen: int
self,
company_id: str,
last_seen: int,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerResponseEntry]:
helpers = list(
map(
WorkerConversionHelper.from_worker_entry,
self.get_all(company_id=company_id, last_seen=last_seen),
helpers = [
WorkerConversionHelper.from_worker_entry(entry)
for entry in self.get_all(
company_id=company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
)
]
task_ids = set(filter(None, (helper.task_id for helper in helpers)))
all_queues = set(
@@ -240,14 +302,12 @@ class WorkerBLL:
}
},
]
queues_info = {
res["_id"]: res for res in Queue.objects.aggregate(projection)
}
queues_info = {res["_id"]: res for res in Queue.aggregate(projection)}
task_ids = task_ids.union(
filter(
None,
(
safe_get(info, "next_entry/task")
nested_get(info, ("next_entry", "task"))
for info in queues_info.values()
),
)
@@ -258,7 +318,7 @@ class WorkerBLL:
tasks_info = {
task.id: task
for task in Task.objects(id__in=task_ids).only(
"name", "started", "last_iteration"
"name", "started", "last_iteration", "active_duration"
)
}
@@ -271,7 +331,7 @@ class WorkerBLL:
continue
entry.name = info.get("name", None)
entry.num_tasks = info.get("num_entries", 0)
task_id = safe_get(info, "next_entry/task")
task_id = nested_get(info, ("next_entry", "task"))
if task_id:
task = tasks_info.get(task_id, None)
entry.next_task = IdNameEntry(
@@ -281,13 +341,9 @@ class WorkerBLL:
for helper in helpers:
worker = helper.worker
if helper.task_id:
task = tasks_info.get(helper.task_id, None)
task: Task = tasks_info.get(helper.task_id, None)
if task:
worker.task.running_time = (
int((datetime.utcnow() - task.started).total_seconds() * 1000)
if task.started
else 0
)
worker.task.running_time = (task.active_duration or 0) * 1000
worker.task.last_iteration = task.last_iteration
update_queue_entries(worker.queue)
@@ -314,8 +370,7 @@ class WorkerBLL:
"""
key = self._get_worker_key(company_id, user_id, worker)
with TimingContext("redis", "get_worker"):
data = self.redis.get(key)
data = self.redis.get(key)
if data:
try:
@@ -342,42 +397,163 @@ class WorkerBLL:
raise bad_request.InvalidWorkerId(worker=worker)
@staticmethod
def _get_tagged_workers_key(company: str, tags_field: str, tag: str) -> str:
"""Build redis key from company, user and worker_id"""
return f"workers.{tags_field}_{company}_{tag}"
@staticmethod
def _get_all_workers_key(company: str) -> str:
"""Build redis key from company, user and worker_id"""
return f"workers_{company}"
def _save_worker_data(self, entry: WorkerEntry):
self.redis.setex(
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
)
company_id = entry.company.id
expiration = int(time()) + entry.register_timeout
worker_item = {entry.key: expiration}
self.redis.zadd(self._get_all_workers_key(company_id), worker_item)
for tags, tags_field in (
(entry.tags, "tags"),
(entry.system_tags, "systemtags"),
):
for tag in tags:
name = self._get_tagged_workers_key(company_id, tags_field, tag)
self.redis.zadd(name, worker_item)
def _save_worker(self, entry: WorkerEntry) -> None:
"""Save worker entry in Redis"""
try:
self.redis.setex(
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
)
self._save_worker_data(entry)
except Exception:
msg = "Failed saving worker entry"
log.exception(msg)
def _get_keys(
self,
company: str,
user: str = "*",
user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[bytes]:
if not (user_tags or system_tags):
match = self._get_worker_key(company, user, worker_pattern or "*")
return list(self.redis.scan_iter(match))
def filter_by_user_and_pattern(in_keys: Set[bytes]) -> Set[bytes]:
if user != "*":
user_bytes = user.encode()
in_keys = {k for k in in_keys if user_bytes in k}
if worker_pattern:
worker_pattern_bytes = (
f"{worker_pattern.translate(self._key_regex_trans)}$".encode()
)
regex = re.compile(worker_pattern_bytes)
in_keys = {k for k in in_keys if regex.search(k)}
return in_keys
worker_keys = set()
for tags, tags_field in (
(user_tags, "tags"),
(system_tags, "systemtags"),
):
if not tags:
continue
timestamp = int(time())
include, exclude = partition(tags, key=lambda x: x[0] != "-")
if include:
tagged_workers = set()
for tag in include:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
tagged_workers.update(self.redis.zrange(tagged_workers_key, 0, -1))
tagged_workers = filter_by_user_and_pattern(tagged_workers)
worker_keys = (
worker_keys.intersection(tagged_workers)
if worker_keys
else tagged_workers
)
if not worker_keys:
return []
if exclude:
if not worker_keys:
all_workers_key = self._get_all_workers_key(company)
self.redis.zremrangebyscore(all_workers_key, min=0, max=timestamp)
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
worker_keys = filter_by_user_and_pattern(worker_keys)
if not worker_keys:
return []
for tag in exclude:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag[1:]
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
worker_keys.difference_update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
if not worker_keys:
return []
return list(worker_keys)
def _get(
self, company: str, user: str = "*", worker_id: str = "*"
self,
company: str,
user: str = "*",
user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerEntry]:
"""Get worker entries matching the company and user, worker patterns"""
match = self._get_worker_key(company, user, worker_id)
with TimingContext("redis", "workers_get_all"):
res = self.redis.scan_iter(match)
return [WorkerEntry.from_json(self.redis.get(r)) for r in res]
entries = []
for keys in chunked_iter(
self._get_keys(
company,
user=user,
user_tags=user_tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
),
1000,
):
data = self.redis.mget(keys)
if data:
entries.extend(WorkerEntry.from_json(d) for d in data if d)
return entries
@staticmethod
def _get_es_index_suffix():
"""Get the index name suffix for storing current month data"""
return datetime.utcnow().strftime("%Y-%m")
def _log_stats_to_es(
def log_stats_to_es(
self,
company_id: str,
company_name: str,
worker: str,
worker_id: str,
timestamp: int,
task: str,
machine_stats: MachineStats,
) -> bool:
) -> int:
"""
Actually writing the worker statistics to Elastic
:return: True if successful, False otherwise
:return: The amount of logged documents
"""
es_index = (
f"{self._stats.worker_stats_prefix_for_company(company_id)}"
@@ -389,8 +565,7 @@ class WorkerBLL:
_index=es_index,
_source=dict(
timestamp=timestamp,
worker=worker,
company=company_name,
worker=worker_id,
task=task,
category=category,
metric=metric,
@@ -415,7 +590,7 @@ class WorkerBLL:
es_res = elasticsearch.helpers.bulk(self.es_client, actions)
added, errors = es_res[:2]
return (added == len(actions)) and not errors
return added
@attr.s(auto_attribs=True)

View File

@@ -8,19 +8,20 @@ from apiserver.apimodels.workers import AggregationType, GetStatsRequest, StatIt
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
class WorkerStats:
min_chart_interval = config.get("services.workers.min_chart_interval_sec", 40)
def __init__(self, es):
self.es = es
@staticmethod
def worker_stats_prefix_for_company(company_id: str) -> str:
"""Returns the es index prefix for the company"""
return f"worker_stats_{company_id}_"
return f"worker_stats_{company_id.lower()}_"
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
return self.es.search(
@@ -72,9 +73,13 @@ class WorkerStats:
Buckets with no metrics are not returned
Note: all the statistics are retrieved as one ES query
"""
if request.from_date >= request.to_date:
from_date = request.from_date
to_date = request.to_date
if from_date >= to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
interval = max(request.interval, self.min_chart_interval)
def get_dates_agg() -> dict:
es_to_agg_types = (
("avg", AggregationType.avg.value),
@@ -86,8 +91,11 @@ class WorkerStats:
"dates": {
"date_histogram": {
"field": "timestamp",
"fixed_interval": f"{request.interval}s",
"min_doc_count": 1,
"fixed_interval": f"{interval}s",
"extended_bounds": {
"min": int(from_date) * 1000,
"max": int(to_date) * 1000,
}
},
"aggs": {
agg_type: {es_agg: {"field": "value"}}
@@ -119,26 +127,26 @@ class WorkerStats:
}
query_terms = [
QueryBuilder.dates_range(request.from_date, request.to_date),
QueryBuilder.dates_range(from_date, to_date),
QueryBuilder.terms("metric", {item.key for item in request.items}),
]
if request.worker_ids:
query_terms.append(QueryBuilder.terms("worker", request.worker_ids))
es_req["query"] = {"bool": {"must": query_terms}}
with translate_errors_context(), TimingContext("es", "get_worker_stats"):
with translate_errors_context():
data = self._search_company_stats(company_id, es_req)
return self._extract_results(data, request.items, request.split_by_variant)
cutoff_date = (to_date - 0.9 * interval) * 1000 # do not return the point for the incomplete last interval
return self._extract_results(data, request.items, request.split_by_variant, cutoff_date)
@staticmethod
def _extract_results(
data: dict, request_items: Sequence[StatItem], split_by_variant: bool
data: dict, request_items: Sequence[StatItem], split_by_variant: bool, cutoff_date
) -> dict:
"""
Clean results returned from elastic search (remove "aggregations", "buckets" etc.),
leave only aggregation types requested by the user and return a clean dictionary
and return a "clean" dictionary of
:param data: aggregation data retrieved from ES
:param request_items: aggs types requested by the user
:param split_by_variant: if False then aggregate by metric type, otherwise metric type + variant
@@ -156,7 +164,7 @@ class WorkerStats:
return {
"date": date["key"],
"count": date["doc_count"],
**{agg: date[agg]["value"] for agg in aggs_per_metric[metric_key]},
**{agg: date[agg]["value"] or 0.0 for agg in aggs_per_metric[metric_key]},
}
def extract_metric_results(
@@ -165,7 +173,7 @@ class WorkerStats:
return [
extract_date_stats(date, metric_key)
for date in metric_or_variant["dates"]["buckets"]
if date["doc_count"]
if date["key"] <= cutoff_date
]
def extract_variant_results(metric: dict) -> dict:
@@ -204,6 +212,7 @@ class WorkerStats:
"""
if from_date >= to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
interval = max(interval, self.min_chart_interval)
must = [QueryBuilder.dates_range(from_date, to_date)]
if active_only:
@@ -216,6 +225,10 @@ class WorkerStats:
"date_histogram": {
"field": "timestamp",
"fixed_interval": f"{interval}s",
"extended_bounds": {
"min": int(from_date) * 1000,
"max": int(to_date) * 1000,
}
},
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
}
@@ -223,9 +236,7 @@ class WorkerStats:
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext(
"es", "get_worker_activity_report"
):
with translate_errors_context():
data = self._search_company_stats(company_id, es_req)
if "aggregations" not in data:

View File

@@ -6,9 +6,10 @@ from functools import reduce
from os import getenv
from os.path import expandvars
from pathlib import Path
from typing import List, Any, TypeVar
from typing import List, Any, TypeVar, Sequence, Set
from pyhocon import ConfigTree, ConfigFactory
from boltons.iterutils import first
from pyhocon import ConfigTree, ConfigFactory, ConfigValues
from pyparsing import (
ParseFatalException,
ParseException,
@@ -18,8 +19,8 @@ from pyparsing import (
from apiserver.utilities import json
EXTRA_CONFIG_PATHS = ("/opt/trains/config",)
EXTRA_CONFIG_PATH_OVERRIDE_VAR = "TRAINS_CONFIG_DIR"
EXTRA_CONFIG_PATHS = ("/opt/trains/config", "/opt/clearml/config")
DEFAULT_PREFIXES = ("clearml", "trains")
EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ";"
@@ -30,7 +31,11 @@ class BasicConfig:
default_config_dir = "default"
def __init__(
self, folder: str = None, verbose: bool = True, prefix: str = "trains"
self,
folder: str = None,
verbose: bool = True,
prefix: Sequence[str] = DEFAULT_PREFIXES,
exclude_files_from_base_folder: Sequence[str] = None,
):
folder = (
Path(folder)
@@ -40,9 +45,22 @@ class BasicConfig:
if not folder.is_dir():
raise ValueError("Invalid configuration folder")
self.exclude_files_from_base_folder = (
set(exclude_files_from_base_folder)
if exclude_files_from_base_folder
else set()
)
self.verbose = verbose
self.prefix = prefix
self.extra_config_values_env_key_prefix = f"{self.prefix.upper()}__"
self.extra_config_path_override_var = [
f"{p.upper()}_CONFIG_DIR" for p in prefix
]
self.prefix = prefix[0]
self.extra_config_values_env_key_prefix = [
f"{p.upper()}{self.extra_config_values_env_key_sep}"
for p in reversed(prefix)
]
self._paths = [folder, *self._get_paths()]
self._config = self._reload()
@@ -67,30 +85,32 @@ class BasicConfig:
def logger(self, name: str) -> logging.Logger:
if Path(name).is_file():
name = Path(name).stem
if name == "__init__" and Path(name).parent.stem:
name = Path(name).parent.stem
path = ".".join((self.prefix, name))
return logging.getLogger(path)
def _read_extra_env_config_values(self) -> ConfigTree:
""" Loads extra configuration from environment-injected values """
"""Loads extra configuration from environment-injected values"""
result = ConfigTree()
prefix = self.extra_config_values_env_key_prefix
keys = sorted(k for k in os.environ if k.startswith(prefix))
for key in keys:
path = (
key[len(prefix) :]
.replace(self.extra_config_values_env_key_sep, ".")
.lower()
)
result = ConfigTree.merge_configs(
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
)
for prefix in self.extra_config_values_env_key_prefix:
keys = sorted(k for k in os.environ if k.startswith(prefix))
for key in keys:
path = (
key[len(prefix) :]
.replace(self.extra_config_values_env_key_sep, ".")
.lower()
)
result = self._merge_configs(
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
)
return result
def _get_paths(self) -> List[Path]:
default_paths = EXTRA_CONFIG_PATH_SEP.join(EXTRA_CONFIG_PATHS)
value = getenv(EXTRA_CONFIG_PATH_OVERRIDE_VAR, default_paths)
value = first(map(getenv, self.extra_config_path_override_var), default_paths)
paths = [
Path(expandvars(v)).expanduser() for v in value.split(EXTRA_CONFIG_PATH_SEP)
@@ -100,7 +120,7 @@ class BasicConfig:
invalid = [path for path in paths if not path.is_dir()]
if invalid:
print(
f"WARNING: Invalid paths in {EXTRA_CONFIG_PATH_OVERRIDE_VAR} env var: {' '.join(map(str, invalid))}"
f"WARNING: Invalid paths in {self.extra_config_path_override_var} env var: {' '.join(map(str, invalid))}"
)
return [path for path in paths if path.is_dir()]
@@ -111,17 +131,57 @@ class BasicConfig:
def _reload(self) -> ConfigTree:
extra_config_values = self._read_extra_env_config_values()
configs = [self._read_recursive(path) for path in self._paths]
configs = [
self._read_recursive(
path,
exclude_files=(
self.exclude_files_from_base_folder if idx == 0 else None
),
)
for idx, path in enumerate(self._paths)
]
return reduce(
lambda last, config: ConfigTree.merge_configs(
last, config, copy_trees=True
),
lambda last, config: self._merge_configs(last, config, copy_trees=True),
configs + [extra_config_values],
ConfigTree(),
)
def _read_recursive(self, conf_root) -> ConfigTree:
@classmethod
def _merge_configs(cls, a, b, copy_trees=False, override_prefix="-"):
"""Based on pyhocon.ConfigTree.merge_configs, with dict override support using a `-` key prefix"""
for key, value in b.items():
override = key.startswith(override_prefix)
if override:
key = key[len(override_prefix) :]
# if key is in both a and b and both values are dictionary then merge it otherwise override it
if (
not override
and key in a
and isinstance(a[key], ConfigTree)
and isinstance(b[key], ConfigTree)
):
if copy_trees:
a[key] = a[key].copy()
cls._merge_configs(a[key], b[key], copy_trees=copy_trees)
else:
if isinstance(value, ConfigValues):
value.parent = a
value.key = key
if key in a:
value.overriden_value = a[key]
a[key] = value
if a.root:
if b.root:
a.history[key] = a.history.get(key, []) + b.history.get(
key, [value]
)
else:
a.history[key] = a.history.get(key, []) + [value]
return a
def _read_recursive(self, conf_root, exclude_files: Set[str]) -> ConfigTree:
conf = ConfigTree()
if not conf_root:
@@ -139,6 +199,8 @@ class BasicConfig:
print(f"Loading config from {conf_root}")
for file in conf_root.rglob("*.conf"):
if exclude_files and file.name in exclude_files:
continue
key = ".".join(file.relative_to(conf_root).with_suffix("").parts)
conf.put(key, self._read_single_file(file))

View File

@@ -3,7 +3,7 @@
debug: false # Debug mode
pretty_json: false # prettify json response
return_stack: true # return stack trace on error
log_calls: true # Log API Calls
return_stack_to_caller: true # top-level control on whether to return stack trace in an API response
# if 'return_stack' is true and error contains a status code, return stack trace only for these status codes
# valid values are:
@@ -41,10 +41,6 @@
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
# but not declared in a data model
strict: false
aggregate {
allow_disk_use: true
}
}
elastic {
@@ -62,6 +58,9 @@
# verify user tokens
verify_user_tokens: false
# If set then users that were created from secure credentials or fixed user settings and are no longer in these settings will be deleted on startup
delete_missing_autocreated_users: true
# max token expiration timeout in seconds (1 year)
max_expiration_sec: 31536000
@@ -69,19 +68,27 @@
default_expiration_sec: 2592000
# cookie containing auth token, for requests arriving from a web-browser
session_auth_cookie_name: "trains_token_basic"
session_auth_cookie_name: "clearml_token_basic"
# cookie configuration for authorization cookies generated by auth.login
cookies {
httponly: true # allow only http to access the cookies (no JS etc)
secure: false # not using HTTPS
domain: null # Limit to localhost is not supported
samesite: Lax
max_age: 99999999999
}
# provide a cookie domain override per company
# cookies_domain_override {
# <company-id>: <domain>
# }
# # A list of fixed users
# # Note: password may be bcrypt-hashed (generate using `python -c 'import bcrypt; print(bcrypt.hashpw("password", bcrypt.gensalt()))'`)
# fixed_users {
# enabled: true
# pass_hashed: false
# users: [
# {
# username: "john"
@@ -105,9 +112,15 @@
workers {
# Auto-register unknown workers on status reports and other calls
auto_register: true
# Assume unknow workers have unregistered (i.e. do not raise unregistered error)
auto_unregister: true
# Timeout in seconds on task status update. If exceeded
# then task can be stopped without communicating to the worker
task_update_timeout: 600
# Timeout in seconds for worker registration (or status report). If a worker did not report for this long,
# it is discarded from the server's table
default_timeout: 600
}
check_for_updates {
@@ -116,9 +129,9 @@
# Check for updates every 24 hours
check_interval_sec: 86400
url: "https://updates.trains.allegro.ai/updates"
url: "https://updates.clear.ml/updates"
component_name: "trains-server"
component_name: "clearml-server"
# GET request timeout
request_timeout_sec: 3.0
@@ -128,7 +141,7 @@
# Note: statistics are sent ONLY if the user has actively opted-in
supported: true
url: "https://updates.trains.allegro.ai/stats"
url: "https://updates.clear.ml/stats"
report_interval_hours: 24
agent_relevant_threshold_days: 30
@@ -137,4 +150,11 @@
max_backoff_sec: 5
}
getting_started_info {
"agentName": "clearml",
"configure": "clearml-init",
"install": "pip install clearml",
"packageName": "clearml"
}
}

View File

@@ -1,9 +1,10 @@
fileserver = "http://localhost:8081"
elastic {
events {
hosts: [{host: "127.0.0.1", port: 9200}]
hosts: [{host: "127.0.0.1", port: 9200, scheme: http}]
args {
timeout: 60
dead_timeout: 10
max_retries: 3
retry_on_timeout: true
}
@@ -11,10 +12,9 @@ elastic {
}
workers {
hosts: [{host:"127.0.0.1", port:9200}]
hosts: [{host:"127.0.0.1", port:9200, scheme: http}]
args {
timeout: 60
dead_timeout: 10
max_retries: 3
retry_on_timeout: true
}

View File

@@ -16,7 +16,7 @@
backupCount: 3
maxBytes: 10240000,
class: "logging.handlers.RotatingFileHandler",
filename: "/var/log/trains/apiserver.log"
filename: "/var/log/clearml/apiserver.log"
}
}
root {

View File

@@ -1,13 +1,13 @@
{
http {
session_secret {
apiserver: "Gx*gB-L2U8!Naqzd#8=7A4&+=In4H(da424H33ZTDQRGF6=FWw"
apiserver: "V8gcW3EneNDcNfO7G_TSUsWe7uLozyacc9_I33o7bxUo8rCN31VLRg"
}
}
auth {
# token sign secret
token_secret: "7E1ua3xP9GT2(cIQOfhjp+gwN6spBeCAmN-XuugYle00I=Wc+u"
token_secret: "Rq8FW84sSqVgq7WvBB_4EzNl9y8z8IGiDXX3C345_a5AZfcwZcwCIA"
}
credentials {
@@ -15,19 +15,30 @@
apiserver {
role: "system"
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
user_secret: "gaOfhDX2-bpkeI7-cwEcaMuGijxaG2UG3jbIvg4DxmVGF0LNI7rgvCb1-ne38IlBo1w"
}
fileserver {
role: "system"
user_key: "GSQWPEKSKNKF354LC9V6BHXKTYFD5I"
user_secret: "tuBXcGQBECsEhcNiK2kiWi750z9r8Z85XrQ9V0c24huTuCb2xf2X1nKG"
}
webserver {
role: "system"
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
user_secret: "XhkH6a6ds9JBnM_MrahYyYdO-wS2bqFSm8gl-V0UZXH26Ydd6Eyi28TeBEoSr6Z3Bes"
revoke_in_fixed_mode: true
}
services_agent {
role: "admin"
user_key: ""
user_secret: ""
}
tests {
role: "user"
display_name: "Default User"
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
user_secret: "LPEJbGJ6bK4tujQcmrD3i1dbMBDdwUwelVa-LG0K0FFmY9bzH_H0Sw"
revoke_in_fixed_mode: true
}
}
}

View File

@@ -0,0 +1,9 @@
max_page_size: 500
# expiration time in seconds for the redis scroll states in get_many family of apis
scroll_state_expiration_seconds: 600
allow_disk_use {
sort: true
aggregate: true
}

View File

@@ -0,0 +1,13 @@
# if set to true then on task delete/reset external file urls for known storage types are scheduled for async delete
# otherwise they are returned to a client for the client side delete
enabled: true
max_retries: 3
retry_timeout_sec: 60
fileserver {
# fileserver url prefixes. Evaluated in the order of priority
# Can be in the form <schema>://host:port/path or /path
url_prefixes: ["https://files.community-master.hosted.allegro.ai/"]
timeout_sec: 300
token_expiration_sec: 600
}

View File

@@ -12,11 +12,28 @@ events_retrieval {
# should not exceed the amount of concurrent connections set in the ES driver
max_metrics_concurrency: 4
# If set then max_metrics_count and max_variants_count are calculated dynamically on user data
dynamic_metrics_count: true
# The percentage from the ES aggs limit (10000) to use for the max_metrics and max_variants calculation
dynamic_metrics_count_threshold: 80
# the max amount of metrics to aggregate on
max_metrics_count: 100
# the max amount of variants to aggregate on
max_variants_count: 100
debug_images {
# Allow to return the debug images for the variants with uninitialized valid iterations border
allow_uninitialized_variants: true
}
max_raw_scalars_size: 200000
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
multi_plots_batch_size: 1000
}
# if set then plot str will be checked for the valid json on plot add
@@ -24,4 +41,7 @@ events_retrieval {
validate_plot_str: false
# If not 0 then the plots equal or greater to the size will be stored compressed in the DB
plot_compression_threshold: 100000
plot_compression_threshold: 100000
# async events delete threshold
max_async_deleted_events_per_sec: 1000

View File

@@ -0,0 +1,4 @@
metadata_values {
# cache ttl sec
cache_ttl_sec: 86400
}

View File

@@ -1,3 +1,9 @@
tags_cache {
expiration_seconds: 3600
}
download {
redis_timeout_sec: 300
batch_size: 500
max_download_items: 50000
max_project_name_length: 60
}

View File

@@ -10,4 +10,9 @@ featured {
# default featured index for public projects not specified in the order
public_default: 9999
}
sub_projects {
# the max sub project depth
max_depth: 10
}

View File

@@ -0,0 +1,8 @@
{
metrics_before_from_date: 3600
# interval in seconds to update queue metrics. Put 0 to disable
metrics_refresh_interval_sec: 300
# the queues with these tags will not be returned from get_all/get_all_ex unless id or name specified
# or search_hidden is set
hidden_tags: [k8s-glue]
}

View File

@@ -0,0 +1,7 @@
default_container_timeout_sec: 600
# Auto-register unknown serving containers on status reports and other calls
container_auto_register: true
# Assume unknow serving containers have unregistered (i.e. do not raise unregistered error)
container_auto_unregister: true
# The minimal sampling interval for serving model monitor chars
min_chart_interval_sec: 40

View File

@@ -0,0 +1,54 @@
aws {
s3 {
# S3 credentials, used for read/write access by various SDK elements
# default, used for any bucket not specified below
key: ""
secret: ""
region: ""
use_credentials_chain: false
# Additional ExtraArgs passed to boto3 when uploading files. Can also be set per-bucket under "credentials".
extra_args: {}
credentials: [
# specifies key/secret credentials to use when handling s3 urls (read or write)
# {
# bucket: "my-bucket-name"
# key: "my-access-key"
# secret: "my-secret-key"
# },
{
# This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
host: "localhost:9000"
key: "minioadmin"
secret: "minioadmin"
# region: my-server
multipart: false
secure: false
}
]
}
}
google.storage {
# Default project and credentials file
# Will be used when no bucket configuration is found
// project: "clearml"
// credentials_json: "/path/to/credentials.json"
//
// # Specific credentials per bucket and sub directory
// credentials = [
// {
// bucket: "my-bucket"
// subdir: "path/in/bucket" # Not required
// project: "clearml"
// credentials_json: "/path/to/credentials.json"
// },
// ]
}
azure.storage {
# containers: [
# {
# account_name: "clearml"
# account_key: "secret"
# # container_name:
# }
# ]
}

View File

@@ -9,3 +9,20 @@ non_responsive_tasks_watchdog {
}
multi_task_histogram_limit: 100
hyperparam_values {
# max allowed outdate time for the cashed result
cache_allowed_outdate_sec: 60
# cache ttl sec
cache_ttl_sec: 86400
}
# the maximum amount of unique last metrics/variants combinations
# for which the last values are stored in a task
max_last_metrics: 2000
# if set then call to tasks.delete/cleanup does not wait for ES events deletion
async_events_delete: true
# do not use async_delete if the deleted task has amount of events lower than this threshold
async_events_delete_threshold: 100000

View File

@@ -0,0 +1,5 @@
default_worker_timeout_sec: 600
default_cluster_timeout_sec: 600
# The minimal sampling interval for resource dashboard and worker activity charts
min_chart_interval_sec: 40

View File

@@ -2,6 +2,8 @@ from functools import lru_cache
from os import getenv
from pathlib import Path
from boltons.iterutils import first
from apiserver.config_repo import config
from apiserver.version import __version__
@@ -9,7 +11,9 @@ root = Path(__file__).parent.parent
def _get(prop_name, env_suffix=None, default=""):
value = getenv(f"TRAINS_SERVER_{env_suffix or prop_name}")
suffix = env_suffix or prop_name
keys = [f"{p}_SERVER_{suffix}" for p in ("CLEARML", "TRAINS")]
value = first(map(getenv, keys))
if value:
return value

View File

@@ -17,11 +17,21 @@ log = config.logger("database")
strict = config.get("apiserver.mongo.strict", True)
OVERRIDE_HOST_ENV_KEY = (
"CLEARML_MONGODB_SERVICE_HOST",
"TRAINS_MONGODB_SERVICE_HOST",
"MONGODB_SERVICE_HOST",
"MONGODB_SERVICE_SERVICE_HOST",
)
OVERRIDE_PORT_ENV_KEY = ("TRAINS_MONGODB_SERVICE_PORT", "MONGODB_SERVICE_PORT")
OVERRIDE_PORT_ENV_KEY = (
"CLEARML_MONGODB_SERVICE_PORT",
"TRAINS_MONGODB_SERVICE_PORT",
"MONGODB_SERVICE_PORT",
)
OVERRIDE_CONNECTION_STRING_ENV_KEY = "CLEARML_MONGODB_SERVICE_CONNECTION_STRING"
OVERRIDE_USERNAME_ENV_KEY = "CLEARML_MONGODB_SERVICE_USERNAME"
OVERRIDE_PASSWORD_ENV_KEY = "CLEARML_MONGODB_SERVICE_PASSWORD"
OVERRIDE_QUERY_ENV_KEY = "CLEARML_MONGODB_SERVICE_QUERY"
class DatabaseEntry(models.Base):
@@ -32,45 +42,74 @@ class DatabaseEntry(models.Base):
class DatabaseFactory:
_entries = []
@classmethod
def _create_db_entry(cls, alias: str, settings: dict) -> DatabaseEntry:
return DatabaseEntry(alias=alias, **settings)
@classmethod
def initialize(cls):
db_entries = config.get("hosts.mongo", {})
missing = []
log.info("Initializing database connections")
override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY)
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
if override_hostname:
log.info(f"Using override mongodb host {override_hostname}")
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
if override_port:
log.info(f"Using override mongodb port {override_port}")
override_username = getenv(OVERRIDE_USERNAME_ENV_KEY)
override_password = getenv(OVERRIDE_PASSWORD_ENV_KEY)
override_query = getenv(OVERRIDE_QUERY_ENV_KEY)
if override_connection_string:
log.info(f"Using override mongodb connection string template {override_connection_string}")
else:
if override_hostname:
log.info(f"Using override mongodb host {override_hostname}")
if override_port:
log.info(f"Using override mongodb port {override_port}")
if override_username:
log.info(f"Using override mongodb username {override_username}")
if override_password:
log.info(f"Using override mongodb password ******")
if override_query:
log.info(f"Using override mongodb query {override_query}")
for key, alias in get_items(Database).items():
if key not in db_entries:
missing.append(key)
continue
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
if override_hostname:
entry.host = furl(entry.host).set(host=override_hostname).url
if override_port:
entry.host = furl(entry.host).set(port=override_port).url
if override_connection_string:
con_str = furl(override_connection_string).add(path=key).url
log.info(f"Using override mongodb connection string for {alias}: {con_str}")
entry.host = con_str
else:
if override_hostname:
entry.host = furl(entry.host).set(host=override_hostname).url
if override_port:
entry.host = furl(entry.host).set(port=override_port).url
if override_username:
entry.host = furl(entry.host).set(username=override_username).url
if override_password:
entry.host = furl(entry.host).set(password=override_password).url
if override_query:
entry.host = furl(entry.host).set(query=override_query).url
try:
entry.validate()
log.info(
"Registering connection to %(alias)s (%(host)s)" % entry.to_struct()
)
register_connection(alias=alias, host=entry.host)
register_connection(**entry.to_struct())
cls._entries.append(entry)
except ValidationError as ex:
raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0]))
if missing:
raise ValueError("Missing database configuration for %s" % ", ".join(missing))
raise ValueError(
"Missing database configuration for %s" % ", ".join(missing)
)
@classmethod
def get_entries(cls):
@@ -91,7 +130,7 @@ class DatabaseFactory:
# reconnection from work so workaround this
# get_connection(entry.alias, reconnect=True)
disconnect(entry.alias)
register_connection(alias=entry.alias, host=entry.host)
register_connection(**entry.to_struct())
get_connection(entry.alias)

View File

@@ -5,7 +5,7 @@ from textwrap import shorten
import dpath
from dpath.exceptions import InvalidKeyName
from elasticsearch import ElasticsearchException
from elastic_transport import TransportError, ApiError
from elasticsearch.helpers import BulkIndexError
from jsonmodels.errors import ValidationError as JsonschemaValidationError
from mongoengine.errors import (
@@ -16,7 +16,7 @@ from mongoengine.errors import (
LookUpError,
InvalidQueryError,
)
from pymongo.errors import PyMongoError, NotMasterError
from pymongo.errors import PyMongoError, NotPrimaryError
from apiserver.apierrors import errors
@@ -166,7 +166,10 @@ class MongoEngineErrorsHandler(object):
@classmethod
@throws_default_error(errors.server_error.InternalError)
def invalid_query_error(cls, e, message, **_):
pass
if e.args:
inner = e.args[0]
if isinstance(inner, LookUpError):
cls.lookup_error(inner, message)
@contextmanager
@@ -195,7 +198,7 @@ def translate_errors_context(message=None, **kwargs):
MongoEngineErrorsHandler.invalid_query_error(e, message, **kwargs)
except PyMongoError as e:
raise errors.server_error.InternalError(message, err=str(e))
except NotMasterError as e:
except NotPrimaryError as e:
raise errors.server_error.InternalError(message, err=str(e))
except MakeGetAllQueryError as e:
raise errors.bad_request.ValidationError(e.error, field=e.field)
@@ -207,9 +210,9 @@ def translate_errors_context(message=None, **kwargs):
raise errors.bad_request.ValidationError(e.args[0])
except BulkIndexError as e:
ElasticErrorsHandler.bulk_error(e, message, **kwargs)
except ElasticsearchException as e:
except (TransportError, ApiError) as e:
raise errors.server_error.DataError(e, message, **kwargs)
except InvalidKeyName:
raise errors.server_error.DataError("invalid empty key encountered in data")
except Exception as ex:
except Exception:
raise

View File

@@ -176,6 +176,13 @@ class SafeMapField(MapField, DictValidationMixin):
self.error("Empty keys are not allowed in a MapField")
class NullableStringField(StringField):
def validate(self, value):
if value is None:
return
super(NullableStringField, self).validate(value)
class SafeDictField(DictField, DictValidationMixin):
def validate(self, value):
self._safe_validate(value)

View File

@@ -60,3 +60,4 @@ def validate_id(cls, company, **kwargs):
class EntityVisibility(Enum):
active = "active"
archived = "archived"
hidden = "hidden"

View File

@@ -4,6 +4,7 @@ from mongoengine import (
EmbeddedDocumentListField,
EmailField,
DateTimeField,
BooleanField,
)
from apiserver.database import Database, strict
@@ -48,7 +49,9 @@ class Credentials(EmbeddedDocument):
meta = {"strict": False}
key = StringField(required=True)
secret = StringField(required=True)
label = StringField()
last_used = DateTimeField()
last_used_from = StringField()
class User(DbModelMixin, AuthDocument):
@@ -74,3 +77,6 @@ class User(DbModelMixin, AuthDocument):
email = EmailField(unique=True, sparse=True)
""" Email uniquely identifying the user """
autocreated = BooleanField(default=False)
""" Set to true if the user was auto created based on config settings"""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,44 @@
from typing import Sequence, Type
from mongoengine import EmbeddedDocument, StringField, Document
from pymongo import UpdateOne
from pymongo.collection import Collection
from apiserver.database.model.base import ProperDictMixin
class MetadataItem(EmbeddedDocument, ProperDictMixin):
key = StringField(required=True)
type = StringField(required=True)
value = StringField(required=True)
def metadata_add_or_update(cls: Type[Document], _id: str, items: Sequence[dict]) -> int:
collection: Collection = cls._get_collection()
res = collection.update_one(
filter={"_id": _id},
update={
"$set": {f"metadata.$[elem{idx}]": item for idx, item in enumerate(items)}
},
array_filters=[
{f"elem{idx}.key": item["key"]} for idx, item in enumerate(items)
],
upsert=False,
)
if len(items) == 1 and res.modified_count == 1:
return res.modified_count
requests = [
UpdateOne(
filter={"_id": _id, "metadata.key": {"$ne": item["key"]}},
update={"$push": {"metadata": item}},
)
for item in items
]
res = collection.bulk_write(requests)
return 1 if res.modified_count else 0
def metadata_delete(cls: Type[Document], _id: str, keys: Sequence[str]) -> int:
return cls.objects(id=_id).update_one(pull__metadata__key__in=keys)

View File

@@ -1,17 +1,34 @@
from mongoengine import Document, StringField, DateTimeField, BooleanField
from mongoengine import (
StringField,
DateTimeField,
BooleanField,
EmbeddedDocumentField,
IntField,
ListField,
)
from apiserver.database import Database, strict
from apiserver.database.fields import StrippedStringField, SafeDictField, SafeSortedListField
from apiserver.database.model import DbModelMixin
from apiserver.database.fields import (
StrippedStringField,
SafeDictField,
SafeSortedListField,
SafeMapField,
)
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import GetMixin
from apiserver.database.model.metadata import MetadataItem
from apiserver.database.model.model_labels import ModelLabels
from apiserver.database.model.company import Company
from apiserver.database.model.project import Project
from apiserver.database.model.task.metrics import MetricEvent
from apiserver.database.model.task.task import Task
from apiserver.database.model.user import User
class Model(DbModelMixin, Document):
class Model(AttributedDocument):
_field_collation_overrides = {
"metadata.": AttributedDocument._numeric_locale,
"last_metrics.": AttributedDocument._numeric_locale,
}
meta = {
"db_alias": Database.backend,
"strict": strict,
@@ -19,9 +36,19 @@ class Model(DbModelMixin, Document):
"parent",
"project",
"task",
("company", "framework"),
"last_update",
("company", "last_update"),
("company", "name"),
("company", "uri"),
# distinct queries support
("company", "tags"),
("company", "system_tags"),
("company", "project", "tags"),
("company", "project", "system_tags"),
("company", "user"),
("company", "project", "user"),
("company", "framework"),
("company", "project", "framework"),
{
"name": "%s.model.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
@@ -50,14 +77,15 @@ class Model(DbModelMixin, Document):
"project",
"task",
"parent",
"metadata.*",
),
range_fields=("created", "last_metrics.*", "last_iteration"),
datetime_fields=("last_update", "last_change"),
)
id = StringField(primary_key=True)
name = StrippedStringField(user_set_allowed=True, min_length=3)
parent = StringField(reference_field="Model", required=False)
user = StringField(required=True, reference_field=User)
company = StringField(required=True, reference_field=Company)
project = StringField(reference_field=Project, user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True)
task = StringField(reference_field=Task)
@@ -69,7 +97,19 @@ class Model(DbModelMixin, Document):
design = SafeDictField()
labels = ModelLabels()
ready = BooleanField(required=True)
last_update = DateTimeField()
last_change = DateTimeField()
last_changed_by = StringField()
ui_cache = SafeDictField(
default=dict, user_set_allowed=True, exclude_by_default=True
)
company_origin = StringField(exclude_by_default=True)
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)
last_iteration = IntField(default=0)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
def get_index_company(self) -> str:
return self.company or self.company_origin or ""

View File

@@ -1,4 +1,4 @@
from mongoengine import StringField, DateTimeField, IntField
from mongoengine import StringField, DateTimeField, IntField, ListField
from apiserver.database import Database, strict
from apiserver.database.fields import StrippedStringField, SafeSortedListField
@@ -9,15 +9,19 @@ from apiserver.database.model.base import GetMixin
class Project(AttributedDocument):
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "description"),
list_fields=("tags", "system_tags", "id"),
pattern_fields=("name", "basename", "description"),
list_fields=("tags", "system_tags", "id", "parent", "path"),
range_fields=("last_update",),
)
meta = {
"db_alias": Database.backend,
"strict": strict,
"indexes": [
"parent",
"path",
("company", "name"),
("company", "basename"),
{
"name": "%s.project.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$description"],
@@ -34,7 +38,8 @@ class Project(AttributedDocument):
min_length=3,
sparse=True,
)
description = StringField(required=True)
basename = StrippedStringField(required=True)
description = StringField()
created = DateTimeField(required=True)
tags = SafeSortedListField(StringField(required=True))
system_tags = SafeSortedListField(StringField(required=True))
@@ -44,3 +49,5 @@ class Project(AttributedDocument):
logo_url = StringField()
logo_blob = StringField(exclude_by_default=True)
company_origin = StringField(exclude_by_default=True)
parent = StringField(reference_field="Project")
path = ListField(StringField(required=True), exclude_by_default=True)

View File

@@ -4,34 +4,43 @@ from mongoengine import (
StringField,
DateTimeField,
EmbeddedDocumentListField,
EmbeddedDocumentField,
)
from apiserver.database import Database, strict
from apiserver.database.fields import StrippedStringField, SafeSortedListField
from apiserver.database.model import DbModelMixin
from apiserver.database.fields import (
StrippedStringField,
SafeSortedListField,
SafeMapField,
)
from apiserver.database.model import DbModelMixin, AttributedDocument
from apiserver.database.model.base import ProperDictMixin, GetMixin
from apiserver.database.model.company import Company
from apiserver.database.model.metadata import MetadataItem
from apiserver.database.model.task.task import Task
class Entry(EmbeddedDocument, ProperDictMixin):
""" Entry representing a task waiting in the queue """
task = StringField(required=True, reference_field=Task)
''' Task ID '''
""" Task ID """
added = DateTimeField(required=True)
''' Added to the queue '''
""" Added to the queue """
class Queue(DbModelMixin, Document):
_field_collation_overrides = {
"metadata.": AttributedDocument._numeric_locale,
}
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name",),
list_fields=("tags", "system_tags", "id"),
pattern_fields=("name",), list_fields=("tags", "system_tags", "id", "metadata.*"),
)
meta = {
'db_alias': Database.backend,
'strict': strict,
"db_alias": Database.backend,
"strict": strict,
}
id = StringField(primary_key=True)
@@ -40,7 +49,12 @@ class Queue(DbModelMixin, Document):
)
company = StringField(required=True, reference_field=Company)
created = DateTimeField(required=True)
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
tags = SafeSortedListField(
StringField(required=True), default=list, user_set_allowed=True
)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
entries = EmbeddedDocumentListField(Entry, default=list)
last_update = DateTimeField()
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)

View File

@@ -0,0 +1,76 @@
from mongoengine import (
Document,
EmbeddedDocument,
StringField,
DateTimeField,
EmbeddedDocumentListField,
EmbeddedDocumentField,
BooleanField,
)
from apiserver.database import Database, strict
from apiserver.database.model import DbModelMixin
from apiserver.database.model.base import ProperDictMixin
class AWSBucketSettings(EmbeddedDocument, ProperDictMixin):
bucket = StringField()
subdir = StringField()
host = StringField()
key = StringField()
secret = StringField()
token = StringField()
multipart = BooleanField()
acl = StringField()
secure = BooleanField()
region = StringField()
verify = BooleanField()
use_credentials_chain = BooleanField()
class AWSSettings(EmbeddedDocument, DbModelMixin):
key = StringField()
secret = StringField()
region = StringField()
token = StringField()
use_credentials_chain = BooleanField()
buckets = EmbeddedDocumentListField(AWSBucketSettings)
class GoogleBucketSettings(EmbeddedDocument, ProperDictMixin):
bucket = StringField()
subdir = StringField()
project = StringField()
credentials_json = StringField()
class GoogleStorageSettings(EmbeddedDocument, DbModelMixin):
project = StringField()
credentials_json = StringField()
buckets = EmbeddedDocumentListField(GoogleBucketSettings)
class AzureStorageContainerSettings(EmbeddedDocument, ProperDictMixin):
account_name = StringField(required=True)
account_key = StringField(required=True)
container_name = StringField()
class AzureStorageSettings(EmbeddedDocument, DbModelMixin):
containers = EmbeddedDocumentListField(AzureStorageContainerSettings)
class StorageSettings(DbModelMixin, Document):
meta = {
"db_alias": Database.backend,
"strict": strict,
"indexes": [
"company"
],
}
id = StringField(primary_key=True)
company = StringField(required=True, unique=True)
last_update = DateTimeField()
aws: AWSSettings = EmbeddedDocumentField(AWSSettings)
google: GoogleStorageSettings = EmbeddedDocumentField(GoogleStorageSettings)
azure: AzureStorageSettings = EmbeddedDocumentField(AzureStorageSettings)

View File

@@ -4,6 +4,8 @@ from mongoengine import (
DynamicField,
LongField,
EmbeddedDocumentField,
IntField,
FloatField,
)
from apiserver.database.fields import SafeMapField
@@ -19,7 +21,13 @@ class MetricEvent(EmbeddedDocument):
variant = StringField(required=True)
value = DynamicField(required=True)
min_value = DynamicField() # for backwards compatibility reasons
min_value_iteration = IntField()
max_value = DynamicField() # for backwards compatibility reasons
max_value_iteration = IntField()
first_value = FloatField()
first_value_iteration = IntField()
count = IntField()
mean_value = FloatField()
class EventStats(EmbeddedDocument):

View File

@@ -11,6 +11,5 @@ class Result(object):
class Output(EmbeddedDocument):
destination = StrippedStringField()
model = StringField(reference_field='Model')
error = StringField(user_set_allowed=True)
result = StringField(choices=get_options(Result))

View File

@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Sequence
from mongoengine import (
StringField,
@@ -17,6 +17,9 @@ from apiserver.database.fields import (
SafeDictField,
UnionField,
SafeSortedListField,
EmbeddedDocumentListField,
NullableStringField,
NoneType,
)
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import ProperDictMixin, GetMixin
@@ -79,13 +82,17 @@ DEFAULT_ARTIFACT_MODE = ArtifactModes.output
class Artifact(EmbeddedDocument):
key = StringField(required=True)
type = StringField(required=True)
mode = StringField(choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE)
mode = StringField(
choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE
)
uri = StringField()
hash = StringField()
content_size = LongField()
timestamp = LongField()
type_data = EmbeddedDocumentField(ArtifactTypeData)
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
display_data = SafeSortedListField(
ListField(UnionField((int, float, str, NoneType)))
)
class ParamsItem(EmbeddedDocument, ProperDictMixin):
@@ -103,17 +110,37 @@ class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
description = StringField()
class TaskModelTypes:
input = "input"
output = "output"
TaskModelNames = {
TaskModelTypes.input: "Input Model",
TaskModelTypes.output: "Output Model",
}
class ModelItem(EmbeddedDocument, ProperDictMixin):
name = StringField(required=True)
model = StringField(required=True, reference_field="Model")
updated = DateTimeField()
class Models(EmbeddedDocument, ProperDictMixin):
input: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
output: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
class Execution(EmbeddedDocument, ProperDictMixin):
meta = {"strict": strict}
test_split = IntField(default=0)
parameters = SafeDictField(default=dict)
model = StringField(reference_field="Model")
model_desc = SafeMapField(StringField(default=""))
model_labels = ModelLabels()
framework = StringField()
artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact))
docker_cmd = StringField()
queue = StringField()
queue = StringField(reference_field="Queue")
""" Queue ID where task was queued """
@@ -125,6 +152,7 @@ class TaskType(object):
application = "application"
monitor = "monitor"
controller = "controller"
report = "report"
optimizer = "optimizer"
service = "service"
qc = "qc"
@@ -135,12 +163,10 @@ external_task_types = set(get_options(TaskType))
class Task(AttributedDocument):
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {
"execution.parameters.": _numeric_locale,
"last_metrics.": _numeric_locale,
"hyperparams.": _numeric_locale,
"configuration.": _numeric_locale,
"execution.parameters.": AttributedDocument._numeric_locale,
"last_metrics.": AttributedDocument._numeric_locale,
"hyperparams.": AttributedDocument._numeric_locale,
}
meta = {
@@ -153,21 +179,39 @@ class Task(AttributedDocument):
"active_duration",
"parent",
"project",
"last_update",
"status_changed",
"models.input.model",
("company", "name"),
("company", "user"),
("company", "status", "type"),
("company", "system_tags", "last_update"),
("company", "last_update", "system_tags"),
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
{
"fields": ["company", "project"],
"collation": AttributedDocument._numeric_locale,
},
# distinct queries support
("company", "tags"),
("company", "system_tags"),
("company", "project", "tags"),
("company", "project", "system_tags"),
("company", "user"),
("company", "project", "user"),
("company", "parent"),
("company", "project", "parent"),
("company", "type"),
("company", "project", "type"),
{
"name": "%s.task.main_text_index" % Database.backend,
"fields": [
"$name",
"$id",
"$comment",
"$execution.model",
"$output.model",
"$report",
"$models.input.model",
"$models.output.model",
"$script.repository",
"$script.entry_point",
],
@@ -176,8 +220,9 @@ class Task(AttributedDocument):
"name": 10,
"id": 10,
"comment": 10,
"execution.model": 2,
"output.model": 2,
"report": 10,
"models.output.model": 2,
"models.input.model": 2,
"script.repository": 1,
"script.entry_point": 1,
},
@@ -185,9 +230,23 @@ class Task(AttributedDocument):
],
}
get_all_query_options = GetMixin.QueryParameterOptions(
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project", "parent"),
datetime_fields=("status_changed",),
pattern_fields=("name", "comment"),
list_fields=(
"id",
"user",
"tags",
"system_tags",
"type",
"status",
"project",
"parent",
"hyperparams.*",
"execution.queue",
"models.input.model",
),
range_fields=("created", "started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update", "last_change"),
pattern_fields=("name", "comment", "report"),
fields=("runtime.*",),
)
id = StringField(primary_key=True)
@@ -198,9 +257,11 @@ class Task(AttributedDocument):
type = StringField(required=True, choices=get_options(TaskType))
status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
status_reason = StringField()
status_message = StringField()
status_message = StringField(user_set_allowed=True)
status_changed = DateTimeField()
comment = StringField(user_set_allowed=True)
report = StringField()
report_assets = ListField(StringField())
created = DateTimeField(required=True, user_set_allowed=True)
started = DateTimeField()
completed = DateTimeField()
@@ -219,13 +280,19 @@ class Task(AttributedDocument):
last_change = DateTimeField()
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
company_origin = StringField(exclude_by_default=True)
duration = IntField() # task duration in seconds
duration = IntField() # obsolete, do not use
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
runtime = SafeDictField(default=dict)
docker_init_script = StringField()
models: Models = EmbeddedDocumentField(Models, default=Models)
container = SafeMapField(field=NullableStringField())
enqueue_status = StringField(
choices=get_options(TaskStatus), exclude_by_default=True
)
last_changed_by = StringField()
def get_index_company(self) -> str:
"""

View File

@@ -0,0 +1,52 @@
from enum import Enum
from mongoengine import StringField, DateTimeField, IntField, EnumField
from apiserver.database import Database, strict
from apiserver.database.model import AttributedDocument
class StorageType(str, Enum):
fileserver = "fileserver"
s3 = "s3"
azure = "azure"
gs = "gs"
unknown = "unknown"
class FileType(str, Enum):
file = "file"
folder = "folder"
class DeletionStatus(str, Enum):
created = "created"
retrying = "retrying"
failed = "failed"
class UrlToDelete(AttributedDocument):
_field_collation_overrides = {
"url": AttributedDocument._numeric_locale,
}
meta = {
"db_alias": Database.backend,
"strict": strict,
"indexes": [
("company", "user", "task"),
("company", "storage_type", "url"),
("status", "retry_count", "storage_type"),
],
}
id = StringField(primary_key=True)
url = StringField(required=True, unique_with="company")
task = StringField(required=True)
created = DateTimeField(required=True)
storage_type = EnumField(StorageType, default=StorageType.unknown)
type = EnumField(FileType, default=FileType.file)
retry_count = IntField(default=0)
last_failure_time = DateTimeField()
last_failure_reason = StringField()
status = EnumField(DeletionStatus, default=DeletionStatus.created)

Some files were not shown because too many files have changed in this diff Show More