Compare commits

259 Commits

Author SHA1 Message Date
clearml
d998b46cb2 Update image URLs 2025-06-04 12:01:32 +03:00
clearml
afcccffab3 Version bump to v2.1.0 2025-06-04 12:00:28 +03:00
clearml
cfe0b4fa55 API version bump to 2.32 2025-06-04 12:00:23 +03:00
clearml
0a02b7ad63 Small refactor 2025-06-04 11:59:39 +03:00
clearml
34a9f29465 Bump setuptools version to fix a known security issue 2025-06-04 11:58:09 +03:00
clearml
16d0955ae1 Remove example S3 credentials 2025-06-04 11:57:40 +03:00
clearml
c3f927d6c1 Make sure stack trace is not returned by default 2025-06-04 11:53:05 +03:00
clearml
06a7aa3126 Support skipping mongodb version check on startup 2025-06-04 11:52:34 +03:00
clearml
98122690df Fix empty object might be returned to the client for Google JSON credentials 2025-06-04 11:51:41 +03:00
clearml
f3c67ac3fd Support chart series per single resource in workers.get_stats 2025-06-04 11:51:00 +03:00
clearml
1983b22157 Report last_update field is now set when changing report name or tags 2025-06-04 11:49:48 +03:00
clearml
8a3fcacf5f Fixed requirements vulnerabilities 2025-06-04 11:49:15 +03:00
clearml
c0183e4302 Fix CSV export vulnerability by escaping cell text if it matches a macro or formula prefix 2025-06-04 11:44:24 +03:00
clearml
a7e340212f Add data_tool export improvements including 'company' flag, increased batch size for performance, date-time to log strings, more logs, an option to create a separate zip file per root project, an option to translate urls during tool export 2025-06-04 11:43:31 +03:00
clearml
bf00441146 Update gunicorn constraints due to CVE-2024-6827 2025-03-22 23:00:50 +02:00
pollfly
473aeb6ce9 Update README links (#275) 2025-03-02 12:07:10 +02:00
clearml
c3d305e0e2 Update GitHub repo 2025-01-16 12:04:43 +02:00
clearml
2976ce69cc Add docker-compose files for upgrade 2025-01-01 10:54:27 +02:00
clearml
8de039ee35 Version bump to v2.0.0 2024-12-31 22:09:37 +02:00
clearml
dbebdf2885 Fix project statistics calculation when both include and exclude child task filters present 2024-12-31 22:09:26 +02:00
clearml
e94d7fcfa9 Fix old events deletion for re-imported tasks 2024-12-31 22:08:59 +02:00
clearml
fa3727c5fc Upgrade MongoDB, ElasticSearch and Redis versions 2024-12-31 22:07:43 +02:00
clearml
b7795b3e2e Do not drop service request num to 0 on the chart for the periods when serving model was not reporting 2024-12-31 22:06:05 +02:00
clearml
8c29ebaece Add x_axis_label support in scalar iter charts 2024-12-31 22:05:30 +02:00
clearml
478f6b531b Fix crash when mongodb host is overridden with cluster settings 2024-12-31 22:04:10 +02:00
clearml
893ba48eda Add support for queue display name 2024-12-18 17:47:51 +02:00
clearml
1b76f36dcd Add support for multiple cookie domains 2024-12-18 17:47:21 +02:00
clearml
1299ebfcf3 Version bump to v1.17.0 2024-12-05 22:38:25 +02:00
clearml
8c4932c7eb Model files are now deleted from the fileserver on models.delete call 2024-12-05 22:38:06 +02:00
clearml
e48e64a82f Do not throw internal error on invalid file paths 2024-12-05 22:37:15 +02:00
clearml
046a142f36 Do not return the last incomplete interval for worker stats chart 2024-12-05 22:36:33 +02:00
clearml
207b9e4746 Allow all users to access storage APIs 2024-12-05 22:35:16 +02:00
clearml
605fccdef1 Update ElasticSearch version 2024-12-05 22:34:23 +02:00
clearml
8b8d8d6e6f Change model input_size field to string 2024-12-05 22:33:52 +02:00
clearml
97b9bbc4a9 Return created_in_version property in users.get_current_user 2024-12-05 22:32:28 +02:00
clearml
ed60a27d1a Add mem used charts and cpu/gpu counts to model endpoints instance details
For the num of requests serving charts always take the max value from the interval
2024-12-05 22:31:45 +02:00
clearml
17fcaba2cb Add internal script to fix fileserver URLs in mongodb 2024-12-05 22:30:03 +02:00
clearml
83dbf0fcb8 Add age_sec field to loading serving models
Return serving instance charts sorted by instance name
2024-12-05 22:27:52 +02:00
clearml
a3b303fa28 Add support for OneOfEmbeddedField 2024-12-05 22:27:20 +02:00
clearml
543c579a2e Do not allow creating a project that has a name or part of the path matching the existing public project 2024-12-05 22:26:57 +02:00
clearml
41b003f328 Add an error for trying to duplicate a public project 2024-12-05 22:26:05 +02:00
clearml
606bf2c4be Fix mongodb connection when overridden connection string contains connection options 2024-12-05 22:25:35 +02:00
clearml
57ce9446b1 Add _any_/_all_ queries support for datetime fields 2024-12-05 22:25:08 +02:00
clearml
073cc96fb8 Optimize tasks.move 2024-12-05 22:24:40 +02:00
clearml
77e7fb5c13 Add reference field to serving models 2024-12-05 22:24:18 +02:00
clearml
0b61ec2a56 Workers statistics now return 0s for the periods where the worker did not report 2024-12-05 22:23:52 +02:00
clearml
7506a13fe8 Quote all non numeric fields in csv files 2024-12-05 22:23:22 +02:00
clearml
9dfb4b882a Fix tasks/models.edit_tags do not update the task/model last_changed time 2024-12-05 22:22:49 +02:00
clearml
2eee909364 Export csv files fixed for projects containing semicolon in their names 2024-12-05 22:22:12 +02:00
clearml
3bcbc38c4c Add storage service support 2024-12-05 22:21:12 +02:00
clearml
eb755be001 Add model endpoints support 2024-12-05 22:20:11 +02:00
clearml
9997dcc977 Sync API version 2024-12-05 22:18:27 +02:00
clearml
ee9f45ea61 Optimize MongoDB indices usage for large dbs 2024-12-05 22:17:13 +02:00
clearml
a1956cdd83 When removing a task from a queue change the task state only if the task does not think that it is enqueued in some other place 2024-12-05 22:16:14 +02:00
clearml
4b93f1f508 Add queues.clear_queue
Add new parameter 'update_task_status' to queues.remove_task
2024-12-05 22:15:43 +02:00
clearml
2752c4df54 Fixed schema for users.get_current_user 2024-12-05 22:14:37 +02:00
clearml
2332b8589b Update the task execution queue in queues.add_task 2024-12-05 22:14:03 +02:00
clearml
f94cda4e9d Fix user migration 2024-12-05 19:13:55 +02:00
clearml
a84e1ec0d6 Update licenses 2024-12-05 19:13:49 +02:00
clearml
4223fe73d1 Single task/model delete waits for events deletion in order to mitigate too many ES open scrolls due to repeated calls 2024-12-05 19:13:06 +02:00
clearml
f9577f9faa add update_execution_queue parameter to tasks.enqueue 2024-12-05 19:12:26 +02:00
clearml
58b748ddf3 Merge pipeline parameters with original task hyperparameters 2024-12-05 19:11:36 +02:00
clearml
fa41e14625 Allow enqueueing enqueued tasks 2024-12-05 19:10:34 +02:00
clearml
4df5687ecd Do not replace S3 links in data_tool export by default 2024-12-05 19:09:21 +02:00
clearml
9a69c21504 Fix model update for a deleted task 2024-12-05 19:08:26 +02:00
clearml
39c36527e2 Make sure that a task retrieved from a queue is not in aborted status 2024-12-05 19:07:55 +02:00
clearml
f59ef65fa6 Update API version to v2.31 2024-12-05 19:07:34 +02:00
clearml
8f942f0da2 Data tool can now export project trees not starting from the root 2024-12-05 19:06:56 +02:00
clearml
7b5679fd70 Optimize events deletion in tasks.delete_many/reset_many and models.delete_many operations 2024-12-05 19:06:25 +02:00
clearml
5a5f02cead Fix user credentials reset on the apiserver restart 2024-12-05 19:05:45 +02:00
clearml
cfcad6300a Add created to the range fields for tasks 2024-12-05 19:05:29 +02:00
clearml
fd46f3c6f3 Display only one debug image per iteration/metric and variant 2024-12-05 19:03:36 +02:00
clearml
e86b7fd24e Support for first and mean value for task last scalar metrics 2024-12-05 19:02:48 +02:00
clearml
50593f69f8 Allow enqueueing failed tasks 2024-12-05 18:57:06 +02:00
clearml
ba928854e0 MongoDB upgrade to v5.0 2024-12-05 18:54:23 +02:00
allegroai
83a0485518 Fix user credentials reset on apiserver restart 2024-07-17 11:22:52 +03:00
allegroai
f3491cc9b9 Update README 2024-07-07 13:28:40 +03:00
allegroai
7558426bc6 Fix max upload size limit 2024-06-26 11:21:53 +03:00
allegroai
ce01e37c66 Refactor docker compose files: remove legacy, add services agent initialization in Linux 2024-06-26 10:53:43 +03:00
allegroai
92b42d66b7 Remove default credentials and reset existing credentials if none were provided 2024-06-26 10:52:42 +03:00
allegroai
f7d36bea4f Use an auth token in async_urls_delete when contacting the fileserver 2024-06-20 18:00:19 +03:00
allegroai
f1c876089b Add worker_pattern parameter to workers.get_all and get_count endpoints 2024-06-20 17:59:28 +03:00
allegroai
dd0ecb712d Added fileserver.upload.max_upload_size_mb setting 2024-06-20 17:58:33 +03:00
allegroai
fcfc1e8998 Support a more granular distributed lock wait 2024-06-20 17:57:54 +03:00
allegroai
9c210bb4fa Fix fixed users creation/removal 2024-06-20 17:57:23 +03:00
allegroai
14547155cb Delete pipeline steps in pipelines.delete_runs 2024-06-20 17:55:52 +03:00
allegroai
3f34f83a91 Version bump to 1.16.0
API version bump to 2.30
Add missing endpoints to schema
2024-06-20 17:55:17 +03:00
allegroai
da3941e6f2 Upgrade pymongo dependency 2024-06-20 17:53:15 +03:00
allegroai
2e19a18ee4 Support automatic handling of pipeline steps if a pipeline controller task ID was passed to one of the tasks endpoints 2024-06-20 17:52:46 +03:00
allegroai
cdc668e3c8 Fileserver authorization is enabled by default 2024-06-20 17:50:02 +03:00
allegroai
7c9889605a Add token authorization to fileserver 2024-06-20 17:48:54 +03:00
allegroai
5456ee4ebf Data tool export projects by name now includes subprojects + option for exporting all projects added 2024-06-20 17:48:18 +03:00
allegroai
562cb77003 Support getting and clearing task logs using specific metrics 2024-06-20 17:47:39 +03:00
allegroai
91df2bb3b7 Use better token generation for the secret key 2024-06-20 17:46:23 +03:00
allegroai
cb9812caee Do not return any mongodb instructions as a result of task update operations 2024-06-20 17:44:17 +03:00
allegroai
0496582d96 Ensure min interval on workers history charts so that we do not get "saw like" chart due to the missing points in the intervals 2024-06-20 17:43:52 +03:00
allegroai
beff19e104 Fix do not return full file path on errors from the fileserver 2024-06-20 17:43:19 +03:00
pollfly
639b3d59a4 Update docstrings (#246)
Edit description so they can be rendered using MDX
2024-06-20 17:00:31 +03:00
allegroai
c0d687e2ef Fix missing git in Dockerfile for building webapp 2024-03-28 17:50:35 +02:00
allegroai
9c95c63ce0 Version bump to v1.15.0 2024-03-24 11:25:05 +02:00
allegroai
73179f53c2 Use latest patch versions for ES and Mongo 2024-03-24 11:24:51 +02:00
allegroai
ddc8a76279 Set API version to v2.29 2024-03-18 16:02:45 +02:00
allegroai
ac7ea0d477 Allow filtering task models.input.model field by array of ids 2024-03-18 16:01:45 +02:00
allegroai
3544ed19f8 Use latest patch versions for mongodb and ES 2024-03-18 15:59:15 +02:00
allegroai
5e68f053a0 Add widgets link in nginx configuration 2024-03-18 15:58:19 +02:00
allegroai
7bd5fdad59 Update webserver build: allow using external configuration from a file or from environment variables 2024-03-18 15:57:19 +02:00
allegroai
484c72aa0c Upgrade to Debian bookworm 2024-03-18 15:56:18 +02:00
allegroai
2027afbed5 Added missing ES index template for scalar events 2024-03-18 15:54:38 +02:00
allegroai
7d649f1964 Support controlling config value inheritance from the base folder 2024-03-18 15:53:58 +02:00
allegroai
8d237b3cae Upgrade Redis to v6.2 2024-03-18 15:53:07 +02:00
allegroai
e8ee6ce72e Code cleanup 2024-03-18 15:52:22 +02:00
allegroai
5749ff0454 Add security headers to webserver 2024-03-18 15:50:40 +02:00
allegroai
5189adf4f1 Improve handling of fixed users 2024-03-18 15:49:42 +02:00
allegroai
92a4e56c1f Add support for cookies extensions 2024-03-18 15:46:07 +02:00
allegroai
33528870ae Request cookies processing enhanced for more flexibility 2024-03-18 15:45:09 +02:00
allegroai
85f5b8b6f6 Fix last metrics for task are updated for events reported without variants 2024-03-18 15:44:28 +02:00
allegroai
6112910768 Make sure that legacy templates are deleted and empty db check is done on the new templates 2024-03-18 15:40:13 +02:00
allegroai
d3013ac285 Invalidate token on user logoff 2024-03-18 15:38:44 +02:00
allegroai
88abf28287 Add ElasticSearch 8.x support 2024-03-18 15:37:44 +02:00
allegroai
6a1fc04d1e Set cookies SameSite value to Lax 2024-02-13 16:18:21 +02:00
allegroai
ee8eb03698 Fix crash when importing events for public company tasks 2024-02-13 16:17:52 +02:00
allegroai
5799baae45 Make sure that APIs that aggregate task/model data from projects can be called for the root project 2024-02-13 16:17:33 +02:00
allegroai
801e536c5e Fix tasks.started to correctly handle null values in the started field 2024-02-13 16:17:02 +02:00
allegroai
6e484ea8f4 Fix missing region parameter when deleting files from minio server 2024-02-13 16:16:24 +02:00
allegroai
a47e65d974 Add input parameters check to multiple APIs 2024-02-13 16:15:55 +02:00
allegroai
702b6dc9c8 Version bump to v1.14.0 2024-01-10 15:31:11 +02:00
allegroai
db15f235e4 Make sure files downloaded from the apiserver are not cached by browsers 2024-01-10 15:31:01 +02:00
allegroai
8c347f8fa9 Fix include and exclude filters not processing "no tags" condition 2024-01-10 15:26:55 +02:00
allegroai
768c3d80ff Remove callback_url_prefix and state parameters from login.supported_modes, which no longer returns urls 2024-01-10 15:26:22 +02:00
allegroai
a5c3ef6385 Fix query filter so that the default operator between different query operations on the same parameter is AND instead of OR 2024-01-10 15:24:37 +02:00
allegroai
11b7a384af Set API version 2.28 2024-01-10 15:23:54 +02:00
allegroai
9a70ade4a6 Support task models with missing model field in data_tool import 2024-01-10 15:18:58 +02:00
allegroai
91ce140901 Add "queue watched" indication to pipelines.start_pipeline 2024-01-10 15:15:43 +02:00
allegroai
49084a9c49 Optimize task statistics for projects dashboard and statistics reporter 2024-01-10 15:13:25 +02:00
allegroai
8a99eb6812 Fix model_metrics parameter name in get_multi_task_metrics schema 2024-01-10 15:12:56 +02:00
allegroai
811ab2bf4f Support exporting users with data tool 2024-01-10 15:12:07 +02:00
allegroai
3752db122b Add events.get_multi_task_metrics 2024-01-10 15:11:27 +02:00
allegroai
439911b84c Upgrade werkzeug and flask dependencies 2024-01-10 15:10:46 +02:00
allegroai
262a301e28 Check for dictionary type for some model and task fields 2024-01-10 15:10:41 +02:00
allegroai
a604451b01 Refactor check for tasks write permission 2024-01-10 15:08:20 +02:00
allegroai
88a7773621 Allow filtering on event metrics in multi-task endpoints get_task_single_value_metrics, multi_task_scalar_metrics_iter_histogram and get_multi_task_plots 2024-01-10 15:07:46 +02:00
allegroai
35c4061992 Support filtering by task or model ids in projects.get_unique_metric_variants 2024-01-10 15:06:21 +02:00
allegroai
4684fd5b74 Version bump to v1.13.0 2023-11-17 09:49:26 +02:00
allegroai
e08123fcc0 Fix workers.activity_report should return 0s for the time when no workers reported 2023-11-17 09:49:18 +02:00
allegroai
e713e876eb Upgrade urllib3 requirement 2023-11-17 09:48:19 +02:00
allegroai
c2cc788319 Added supported API versions doc 2023-11-17 09:47:44 +02:00
allegroai
da8315d0db Allow queries on the list of execution queue ids in tasks.get_all/get_all_ex 2023-11-17 09:47:19 +02:00
allegroai
4ac6f88278 Optimize Workers retrieval
Store worker statistics under worker id and not internal redis key
Fix unit tests
2023-11-17 09:46:44 +02:00
allegroai
a7865ccbec Turn on async task events deletion in case there are more than 100_000 events 2023-11-17 09:45:55 +02:00
allegroai
ec14f327c6 Optimize endpoints that do not require authorization by not validating JWT token 2023-11-17 09:45:22 +02:00
allegroai
a03b24d6b6 Add log info on caller IP if token validation fails 2023-11-17 09:43:59 +02:00
allegroai
cb71ef8e47 Fix missing scroll_id in events.get_scalar_metric_data 2023-11-17 09:43:11 +02:00
allegroai
8678fbc995 Fix properly unset Task fields on task reset 2023-11-17 09:42:39 +02:00
allegroai
58df8f201a Update API to 2.27 2023-11-17 09:40:34 +02:00
allegroai
f4bf16c156 Fix schema for swagger compatibility 2023-11-17 09:39:52 +02:00
allegroai
942f996237 Fix async_delete cannot be configured using configuration files 2023-11-17 09:39:22 +02:00
allegroai
c1e7f8f9c1 Optimize deletion of projects with many tasks 2023-11-17 09:38:32 +02:00
allegroai
274c487b37 Add update_tags api to tasks and models 2023-11-17 09:37:25 +02:00
allegroai
cc0129a800 Add filters parameter for passing user defined list filters for all get_all_ex apis 2023-11-17 09:36:58 +02:00
allegroai
388dd1b01f Fix regression issue with archive tasks display 2023-11-17 09:35:55 +02:00
allegroai
d62ecb5e6e Add last_change and last_change_by DB Model 2023-11-17 09:35:22 +02:00
allegroai
6d507616b3 Add pattern parameter to projects.get_hyperparam_values 2023-11-17 09:34:13 +02:00
allegroai
d0252a6dd9 Make sure that hyperparam/configuration/metadata keys that contain only empty space are rejected 2023-11-17 09:32:22 +02:00
allegroai
2263e7cc1e Fix regression with archive tasks display 2023-07-31 14:16:08 +03:00
allegroai
81b93e6811 Updated dependency - dnspython is a required dependency of pymongo as of pymongo v4.3 (https://pymongo.readthedocs.io/en/stable/changelog.html#changes-in-version-4-3-4-3-2) 2023-07-27 11:49:40 +03:00
allegroai
491e83d0f1 Version bump to v1.12.0 2023-07-26 18:56:04 +03:00
allegroai
f84cc0a2cb Remove 10 metrics limit in multi-task plot comparison 2023-07-26 18:55:49 +03:00
allegroai
6c5f966ed4 Add new_status field to tasks.dequeue and dequeue_many endpoints 2023-07-26 18:55:05 +03:00
allegroai
4eff657810 Fix debug images not returned for tasks in new db 2023-07-26 18:54:19 +03:00
allegroai
74acaa31df Add explicit refresh interval to ES mappings
Fix queue tests
2023-07-26 18:54:02 +03:00
allegroai
21ed8559bf Fix worker keys not returned in queues.get_all_ex 2023-07-26 18:51:20 +03:00
allegroai
3927604648 Add task names to events.get_single_value_metrics endpoint response 2023-07-26 18:50:53 +03:00
allegroai
f7dcbd96ec Fix deleting model events
Add delete_external_artifacts parameter to projects.delete endpoint
2023-07-26 18:49:54 +03:00
allegroai
5950b81f0b Fix child tasks count for top level pipeline and dataset projects 2023-07-26 18:49:12 +03:00
allegroai
1e51e2e221 Allow projection of more than 500 items 2023-07-26 18:46:58 +03:00
allegroai
4c98b87554 Fix issues with new dependencies 2023-07-26 18:46:28 +03:00
allegroai
c196043d2a Add max_download_items to users.get_current_user endpoint response 2023-07-26 18:45:42 +03:00
allegroai
752020c66a Update API version to 2.26 2023-07-26 18:44:20 +03:00
allegroai
6885d07462 Write UTF-8 BOM into csv download file 2023-07-26 18:43:38 +03:00
allegroai
00552da1b0 Requests context is not needed any more 2023-07-26 18:43:09 +03:00
allegroai
eebe2eeffc Update requirements 2023-07-26 18:42:26 +03:00
allegroai
bc2fe28bdd Add field_mappings to organizations download endpoints 2023-07-26 18:39:41 +03:00
allegroai
ed86750b24 Add scalar field type to jsonmodels 2023-07-26 18:39:06 +03:00
allegroai
6df69afb25 Support "__$or" condition on projects children filtering 2023-07-26 18:38:41 +03:00
allegroai
3f22423c3f Support paging in projects.get_model_metadata_values and get_hyperparam_values endpoints 2023-07-26 18:38:11 +03:00
allegroai
3ad636c468 Exported csv file name now contains the project name (including non-ascii names) 2023-07-26 18:37:20 +03:00
allegroai
5c80336aa9 Project delete and validate_delete now analyses and presents info for datasets and pipelines 2023-07-26 18:36:45 +03:00
allegroai
5cd59ea6e3 Fix csv export handling "," in fields 2023-07-26 18:35:31 +03:00
allegroai
5d3ba4fa73 Fix events.get_multitask_plots to retrieve last iterations per each task metric separately 2023-07-26 18:34:30 +03:00
allegroai
42556c8dbb Pipelines children query now looks for pipeline projects and not tasks 2023-07-26 18:33:35 +03:00
allegroai
dbe1c6f00f Allow configuring multi-plots batch size 2023-07-26 18:33:10 +03:00
allegroai
a17485b1bd Allow dequeueing a deleted task 2023-07-26 18:32:32 +03:00
allegroai
a2b9fed92d Make sure that scroll parameters are ignored when downloading tasks 2023-07-26 18:31:56 +03:00
allegroai
ff34da3c88 Add organization.download_for_get_all endpoint 2023-07-26 18:31:20 +03:00
allegroai
5239755066 Support include_subprojects flag in reports.get_all_ex endpoint 2023-07-26 18:30:34 +03:00
allegroai
8061dfedbb Fix NewListBucketsHelper backwards compatibility 2023-07-26 18:27:51 +03:00
allegroai
011164ce9b Support __$and condition for excluded terms in get_all_ex endpoints list filters 2023-07-26 18:26:49 +03:00
allegroai
8135cf5258 Add include_subprojects to tasks/models.get_all endpoints
Fix escaping metadata for tasks, models and queues
2023-07-26 18:24:49 +03:00
allegroai
a83a932e84 Add pipelines.delete_runs endpoint 2023-07-26 18:23:05 +03:00
allegroai
db021f2863 Add workers.get_count endpoint 2023-07-26 18:21:52 +03:00
allegroai
1b650b1689 Add projects.get_user_names endpoint 2023-07-26 18:21:16 +03:00
allegroai
14d18a7aba Remove obsolete duration field 2023-07-26 18:19:41 +03:00
Olivier Girardot
a7ed46979f Fix handling of the subpaths with nginx templating (#204)
Co-authored-by: ogirardot <olivier.girardot@malt.com>
2023-07-02 16:12:29 +03:00
allegroai
452f606889 Version bump to v1.11 2023-05-25 19:40:07 +03:00
allegroai
fc47ccbf09 Add default services agent user 2023-05-25 19:39:53 +03:00
allegroai
0206811342 Improve empty database check during startup 2023-05-25 19:39:17 +03:00
allegroai
a3ac1049a3 Update ClearML SDK dependency 2023-05-25 19:38:48 +03:00
allegroai
8488f63a3a Add fileserver URL prefixes for async deletion 2023-05-25 19:38:07 +03:00
allegroai
9206a7c57d Schedule external file URLs for deletion on models deletion 2023-05-25 19:36:28 +03:00
allegroai
0c37ced2a1 Fix model Id handling when deleting models for tasks 2023-05-25 19:35:18 +03:00
allegroai
b22f26129e Update requirements 2023-05-25 19:34:19 +03:00
allegroai
d8b998ebd8 Bump API version to 2.25 2023-05-25 19:33:37 +03:00
allegroai
741fa84b52 Fix projects own_tasks does not take task state filter into account 2023-05-25 19:32:52 +03:00
allegroai
d9579891c8 Return only reports from the .reports projects in reports.get_all_ex 2023-05-25 19:31:05 +03:00
allegroai
900414d0de Add option to echo ping payload 2023-05-25 19:30:13 +03:00
allegroai
5449b332d2 Support reports from the root project in reports.get_all_ex 2023-05-25 19:29:46 +03:00
allegroai
875f4b9536 Fix task dequeue changing the status of un-queued/running tasks 2023-05-25 19:28:49 +03:00
allegroai
95b8f22899 Add CLEARML_FILES_HOST to async_delete in windows 2023-05-25 19:27:40 +03:00
allegroai
4058fb9ce5 Migrate to python 3.9 bullseye docker images
Update Mongo driver version
2023-05-25 19:27:14 +03:00
allegroai
cf8e847ed3 Switch to new redis version 2023-05-25 19:22:39 +03:00
allegroai
755cc803d9 Add remove_from_all_queues parameter to tasks.dequeue/dequeue_many endpoints 2023-05-25 19:22:10 +03:00
allegroai
3729afe014 Optimize queues.get_next_task to retrieve required task fields only 2023-05-25 19:21:24 +03:00
allegroai
dff2ed34e8 Support receiving mixed events for both locked and unlocked tasks and models events.add_batch 2023-05-25 19:20:35 +03:00
allegroai
de9651d761 Allow mixing Model and task events in the same events batch 2023-05-25 19:19:45 +03:00
allegroai
818496236b Support filtering by children tags in projects.get_all_ex 2023-05-25 19:19:10 +03:00
allegroai
e99817b28b Task reports can now return single value metrics 2023-05-25 19:18:24 +03:00
allegroai
58465fbc17 Model events are fully supported 2023-05-25 19:17:40 +03:00
allegroai
2e4e060a82 Task move forward/backwards in queue is now atomic 2023-05-25 19:16:33 +03:00
allegroai
5c5d9b6434 Fix numeric hyperparam values are not sorted lexicographically with descending sort order 2023-05-25 19:15:59 +03:00
allegroai
4291ad682a Support filtering by task name in projects.get_task_parent 2023-05-25 19:15:26 +03:00
allegroai
4c22757002 Fix task that is not in queue but has 'queued' status can't be dequeued 2023-05-25 19:14:25 +03:00
allegroai
6e777e80b8 Cleaned up unit tests 2023-05-25 19:13:10 +03:00
allegroai
c8e4d9eeac Fix Dockerfile uses deprecated base image 2023-04-18 10:50:13 +03:00
dependabot[bot]
b51aa5c29b Bump redis from 3.5.3 to 4.4.4 in /apiserver (#190)
Bumps [redis](https://github.com/redis/redis-py) from 3.5.3 to 4.4.4.
- [Release notes](https://github.com/redis/redis-py/releases)
- [Changelog](https://github.com/redis/redis-py/blob/master/CHANGES)
- [Commits](https://github.com/redis/redis-py/compare/3.5.3...v4.4.4)

---
updated-dependencies:
- dependency-name: redis
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-01 08:59:02 +03:00
allegroai
e7c9daa42b Fix get_task_events to correctly use last_iters for model events 2023-03-28 16:45:44 +03:00
allegroai
7357654249 Version bump to v1.10 2023-03-23 19:17:00 +02:00
allegroai
a6f671b46a Fix typo 2023-03-23 19:16:38 +02:00
allegroai
17a8b440bd Fix only last event of each type is stored per model (all should be stored) 2023-03-23 19:16:30 +02:00
allegroai
eb2b9cbd9a Fix project count for datasets and pipelines 2023-03-23 19:15:42 +02:00
allegroai
797e503e67 Update ES version 2023-03-23 19:14:33 +02:00
allegroai
30cfdac8f2 Fix project preview completed_tasks_24h should not count tasks that are marked as failed or running 2023-03-23 19:13:52 +02:00
allegroai
24bb87aaee Turn on mongo sorting using disk usage by default for sorting in *.get_all* apis 2023-03-23 19:12:52 +02:00
allegroai
dd49ba180a Improve statistics on projects children 2023-03-23 19:11:45 +02:00
allegroai
bda903d0d8 Set API version to 2.24 2023-03-23 19:11:13 +02:00
allegroai
9739eb2d5a Add report_assets field to report tasks 2023-03-23 19:09:03 +02:00
allegroai
cfbb37238f Add default workers timeout to the server's configuration 2023-03-23 19:08:11 +02:00
allegroai
6664c6237e Support querying by children_type in projects.get_all_ex 2023-03-23 19:07:42 +02:00
allegroai
74200a24bd Add filtering on child projects in projects.get_all_ex 2023-03-23 19:06:49 +02:00
john-zielke-snkeos
2fb9288a6c Add env switch to disable nginx ipv6 bind (#165) 2023-03-13 16:05:43 +02:00
shyallegro
5d014d81af Fix #184 and update docker build to include widgets (#185) 2023-03-07 11:26:12 +02:00
allegroai
3a2675abe1 Version bump to v1.9.2 2023-01-24 16:11:21 +02:00
allegroai
f0d68b1ce9 Make sure model label values are integer 2023-01-24 16:11:12 +02:00
allegroai
15db9cdaef Allow updating comments on published reports 2023-01-24 14:40:32 +02:00
Mal Miller
a45d47f5d7 Fix default value of CLEARML_AGENT_UPDATE_VERSION for agent-services (#114) 2023-01-03 13:45:52 +02:00
allegroai
b1a50c1370 Version bump to v1.9.1 2023-01-03 12:16:07 +02:00
allegroai
22a2a02760 Allow renaming published reports 2023-01-03 12:15:44 +02:00
allegroai
ab798e4170 Allow updating tags on published reports 2023-01-03 12:15:02 +02:00
allegroai
f09ac672d2 Add pipeline test 2023-01-03 12:14:12 +02:00
allegroai
2149b76f63 Fix crash when starting pipeline 2023-01-03 12:13:48 +02:00
199 changed files with 12223 additions and 3476 deletions

View File

@@ -1,7 +1,7 @@
Server Side Public License
VERSION 1, OCTOBER 16, 2018
Copyright © 2019 allegro.ai, Inc.
Copyright © 2025 ClearML Inc.
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.

View File

@@ -6,43 +6,16 @@
</br>Experiment Manager, ML-Ops and Data-Management**
[![GitHub license](https://img.shields.io/badge/license-SSPL-green.svg)](https://img.shields.io/badge/license-SSPL-green.svg)
[![Python versions](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
[![GitHub version](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/allegroai)](https://artifacthub.io/packages/search?repo=allegroai)
[![Python versions](https://img.shields.io/badge/python-3.9-blue.svg)](https://img.shields.io/badge/python-3.9-blue.svg)
[![GitHub version](https://img.shields.io/github/release-pre/clearml/clearml-server.svg)](https://img.shields.io/github/release-pre/clearml/clearml-server.svg)
[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/clearml)](https://artifacthub.io/packages/search?repo=clearml)
</div>
---
<div align="center">
**Note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31**
</div>
According to [ElasticSearch's latest report](https://discuss.elastic.co/t/apache-log4j2-remote-code-execution-rce-vulnerability-cve-2021-44228-esa-2021-31/291476),
supported versions of Elasticsearch (6.8.9+, 7.8+) used with recent versions of the JDK (JDK9+) **are not susceptible to either remote code execution or information leakage**
due to Elasticsearch's usage of the Java Security Manager.
**As the latest version of ClearML Server uses Elasticsearch 7.10+ with JDK15, it is not affected by these vulnerabilities.**
As a precaution, we've upgraded the ES version to 7.16.2 and added the mitigation recommended by ElasticSearch to our latest [docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
While previous Elasticsearch versions (5.6.11+, 6.4.0+ and 7.0.0+) used by older ClearML Server versions are only susceptible to the information leakage vulnerability
(which in any case **does not permit access to data within the Elasticsearch cluster**),
we still recommend upgrading to the latest version of ClearML Server. Alternatively, you can apply the mitigation as implemented in our latest
[docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
**Update 15 December**: A further vulnerability (CVE-2021-45046) was disclosed on December 14th.
ElasticSearch's guidance for Elasticsearch remains unchanged by this new vulnerability, thus **not affecting ClearML Server**.
**Update 22 December**: To keep with ElasticSearch's recommendations, we've upgraded the ES version to the newly released 7.16.2
---
## ClearML Server
#### *Formerly known as Trains Server*
The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/allegroai/clearml).
The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/clearml/clearml).
It allows multiple users to collaborate and manage their experiments.
**ClearML** offers a [free hosted service](https://app.clear.ml/), which is maintained by **ClearML** and open to anyone.
In order to host your own server, you will need to launch the **ClearML Server** and point **ClearML** to it.
@@ -124,9 +97,9 @@ In order to set up the **ClearML** client to work with your **ClearML Server**:
it will be inferred from the http/s scheme.
After launching the **ClearML Server** and configuring the **ClearML** client to use the **ClearML Server**,
you can [use](https://github.com/allegroai/clearml) **ClearML** in your experiments and view them in your **ClearML Server** web server,
you can [use](https://github.com/clearml/clearml) **ClearML** in your experiments and view them in your **ClearML Server** web server,
for example http://localhost:8080.
For more information about the ClearML client, see [**ClearML**](https://github.com/allegroai/clearml).
For more information about the ClearML client, see [**ClearML**](https://github.com/clearml/clearml).
## ClearML-Agent Services <a name="services"></a>
@@ -143,7 +116,7 @@ increased data transparency)
ClearML-Agent Services container will spin **any** task enqueued into the dedicated `services` queue.
Every task launched by ClearML-Agent Services will be registered as a new node in the system,
providing tracking and transparency capabilities.
You can also run the ClearML-Agent Services manually, see details in [ClearML-agent services mode](https://github.com/allegroai/clearml-agent#clearml-agent-services-mode-)
You can also run the ClearML-Agent Services manually, see details in [ClearML-agent services mode](https://github.com/clearml/clearml-agent#clearml-agent-services-mode-)
**Note**: It is the user's responsibility to make sure the proper tasks are pushed into the `services` queue.
Do not enqueue training / inference tasks into the `services` queue, as it will put unnecessary load on the server.
@@ -166,7 +139,7 @@ To restart the **ClearML Server**, you must first stop the containers, and then
## Upgrading <a name="upgrade"></a>
**ClearML Server** releases are also reflected in the [docker compose configuration file](https://github.com/allegroai/trains-server/blob/master/docker/docker-compose.yml).
**ClearML Server** releases are also reflected in the [docker compose configuration file](https://github.com/clearml/clearml-server/blob/master/docker/docker-compose.yml).
We strongly encourage you to keep your **ClearML Server** up to date, by keeping up with the current release.
**Note**: The following upgrade instructions use the Linux OS as an example.
@@ -199,7 +172,7 @@ To upgrade your existing **ClearML Server** deployment:
1. Download the latest `docker-compose.yml` file.
```bash
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker/docker-compose.yml -o docker-compose.yml
curl https://raw.githubusercontent.com/clearml/clearml-server/master/docker/docker-compose.yml -o docker-compose.yml
```
1. Configure the ClearML-Agent Services (not supported on Windows installation).
@@ -227,7 +200,7 @@ To upgrade your existing **ClearML Server** deployment:
If you have any questions, see the ClearML [FAQ](https://clear.ml/docs/latest/docs/faq), or
tag your questions on [Stack Overflow](https://stackoverflow.com/questions/tagged/clearml) with the '**clearml**' tag.
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/clearml-server/issues).
For feature requests or bug reports, please use [GitHub issues](https://github.com/clearml/clearml-server/issues).
Additionally, you can always find us at *clearml@allegro.ai*

View File

@@ -1,7 +1,7 @@
Server Side Public License
VERSION 1, OCTOBER 16, 2018
Copyright © 2019 allegro.ai, Inc.
Copyright © 2025 ClearML Inc.
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.

View File

@@ -27,7 +27,7 @@
24: ["not_public_object", "object is not public"]
# Auth / Login
75: ["invalid_access_key", "access key not found for user"]
75: ["invalid_access_key", "access key not found"]
# Tasks
100: ["task_error", "general task error"]
@@ -53,6 +53,9 @@
# Reports
150: ["operation_supported_on_reports_only", "passed task is not report"]
# Pipelines
160: ["cannot_remove_all_runs", "at least one pipeline run should be left"]
# Models
200: ["model_error", "general task error"]
201: ["invalid_model_id", "invalid model id"]
@@ -73,12 +76,15 @@
402: ["project_has_tasks", "project has associated tasks"]
403: ["project_not_found", "project not found"]
405: ["project_has_models", "project has associated models"]
406: ["project_has_datasets", "project has associated non-empty datasets"]
407: ["invalid_project_name", "invalid project name"]
408: ["cannot_update_project_location", "Cannot update project location. Use projects.move instead"]
409: ["project_path_exceeds_max", "Project path exceed the maximum allowed depth"]
410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"]
411: ["project_cannot_be_moved_under_itself", "Project can not be moved under itself in the projects hierarchy"]
412: ["project_cannot_be_merged_into_its_child", "Project can not be merged into its own child"]
413: ["project_has_pipelines", "project has associated pipelines with active controllers"]
414: ["public_project_exists", "Cannot create project. Public project with the same name already exists"]
# Queues
701: ["invalid_queue_id", "invalid queue id"]
@@ -101,6 +107,11 @@
1004: ["worker_not_registered", "worker is not registered"]
1005: ["worker_stats_not_found", "worker stats not found"]
# Serving
1050: ["invalid_container_id", "invalid container id"]
1051: ["container_not_registered", "container is not registered"]
1052: ["no_containers_for_url", "no container instances found for service url"]
1104: ["invalid_scroll_id", "Invalid scroll id"]
}

View File

@@ -1,10 +1,11 @@
from enum import Enum
from typing import Union, Type, Iterable
from numbers import Number
from typing import Union, Type, Iterable, Mapping
import jsonmodels.errors
import six
from jsonmodels import fields
from jsonmodels.fields import _LazyType, NotSet
from jsonmodels.fields import _LazyType, NotSet, EmbeddedField
from jsonmodels.models import Base as ModelBase
from jsonmodels.validators import Enum as EnumValidator
from mongoengine.base import BaseDocument
@@ -40,6 +41,34 @@ def make_default(field_cls, default_value):
return _FieldWithDefault
class OneOfEmbeddedField(EmbeddedField):
    """Embedded field whose model class is selected by a discriminator property.

    The field is declared with a mapping from discriminator values to model
    classes. When the incoming value is a dict that contains the discriminator
    property, the matching model class is instantiated from that dict.
    Any other value is delegated to ``EmbeddedField.parse_value``.
    """

    def __init__(
        self,
        *args,
        discriminator_property: str,
        discriminator_mapping: Mapping[str, type],
        **kwargs,
    ):
        """
        :param discriminator_property: name of the key in the incoming dict
            whose value selects the model class
        :param discriminator_mapping: maps discriminator values to model classes
        :param args: passed through to ``EmbeddedField`` after the model types
        :param kwargs: passed through to ``EmbeddedField``
        """
        self.discriminator_property = discriminator_property
        self.discriminator_mapping = discriminator_mapping
        # The same class may be mapped from several discriminator values,
        # so deduplicate before handing the allowed types to the base field
        model_types = tuple(set(self.discriminator_mapping.values()))
        super().__init__(model_types, *args, **kwargs)

    def parse_value(self, value):
        """Parse value to proper model type.

        :raises jsonmodels.errors.ValidationError: if the discriminator value
            is present but does not match any mapped model class
        """
        if not isinstance(value, dict) or self.discriminator_property not in value:
            return super().parse_value(value)

        property_value = value.get(self.discriminator_property)
        try:
            embed_type = self.discriminator_mapping.get(property_value)
        except TypeError:
            # An unhashable discriminator (e.g. a list or a dict sent by the
            # client) cannot be looked up; report it as a validation error
            # instead of leaking a TypeError
            embed_type = None

        if not embed_type:
            raise jsonmodels.errors.ValidationError(
                f"Could not find type matching discriminator property value: {property_value}"
            )

        return embed_type(**value)
class ListField(fields.ListField):
def __init__(self, items_types=None, *args, default=NotSet, **kwargs):
if default is not NotSet and callable(default):
@@ -61,6 +90,22 @@ class ListField(fields.ListField):
item.validate()
class ScalarField(fields.BaseField):
"""Scalar field: the allowed value types are str, int, float and bool."""
types = (str, int, float, bool)
class SafeStringField(fields.StringField):
"""String field that can also accept numbers as input"""
def parse_value(self, value):
# Any numbers.Number instance is converted with str() before the regular
# StringField parsing. Note that bool is also a Number, so True/False
# are converted to "True"/"False".
if isinstance(value, Number):
value = str(value)
return super().parse_value(value)
class DictField(fields.BaseField):
types = (dict,)
@@ -108,9 +153,7 @@ class DictField(fields.BaseField):
if len(self.value_types) != 1:
tpl = 'Cannot decide which type to choose from "{types}".'
raise jsonmodels.errors.ValidationError(
tpl.format(
types=', '.join([t.__name__ for t in self.value_types])
)
tpl.format(types=", ".join([t.__name__ for t in self.value_types]))
)
return self.value_types[0](**value)
@@ -172,7 +215,7 @@ class EnumField(fields.StringField):
*args,
required=False,
default=None,
**kwargs
**kwargs,
):
choices = list(map(self.parse_value, values_or_type))
validator_cls = EnumValidator if required else NullableEnumValidator
@@ -195,7 +238,7 @@ class ActualEnumField(fields.StringField):
validators=None,
required=False,
default=None,
**kwargs
**kwargs,
):
self.__enum = enum_class
self.types = (enum_class,)
@@ -208,7 +251,7 @@ class ActualEnumField(fields.StringField):
*args,
required=required,
validators=validators,
**kwargs
**kwargs,
)
def parse_value(self, value):

View File

@@ -13,6 +13,14 @@ from apiserver.config_repo import config
from apiserver.utilities.stringenum import StringEnum
class TaskRequest(Base):
task: str = StringField(required=True)
class ModelRequest(Base):
model: str = StringField(required=True)
class HistogramRequestBase(Base):
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
@@ -29,6 +37,11 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
model_events: bool = BoolField(default=False)
class GetMetricsAndVariantsRequest(Base):
task: str = StringField(required=True)
model_events: bool = BoolField(default=False)
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
tasks: Sequence[str] = ListField(
items_types=str,
@@ -41,6 +54,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
)
],
)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
model_events: bool = BoolField(default=False)
@@ -50,6 +64,12 @@ class TaskMetric(Base):
variants: Sequence[str] = ListField(items_types=str)
class LegacyMetricEventsRequest(TaskRequest):
iters: int = IntField(default=1, validators=validators.Min(1))
scroll_id: str = StringField()
model_events: bool = BoolField(default=False)
class MetricEventsRequest(Base):
metrics: Sequence[TaskMetric] = ListField(
items_types=TaskMetric, validators=[Length(minimum_value=1)]
@@ -58,7 +78,14 @@ class MetricEventsRequest(Base):
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
model_events: bool = BoolField()
model_events: bool = BoolField(default=False)
class VectorMetricsIterHistogramRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
variant: str = StringField(required=True)
model_events: bool = BoolField(default=False)
class GetVariantSampleRequest(Base):
@@ -109,11 +136,17 @@ class TaskEventsRequest(TaskEventsRequestBase):
model_events: bool = BoolField(default=False)
class LegacyLogEventsRequest(TaskEventsRequestBase):
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.desc)
scroll_id: str = StringField()
class LogEventsRequest(TaskEventsRequestBase):
batch_size: int = IntField(default=5000)
navigate_earlier: bool = BoolField(default=True)
from_timestamp: Optional[int] = IntField()
order: Optional[str] = ActualEnumField(LogOrderEnum)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
@@ -148,13 +181,30 @@ class MultiTasksRequestBase(Base):
class SingleValueMetricsRequest(MultiTasksRequestBase):
pass
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class TaskMetricsRequest(MultiTasksRequestBase):
event_type: EventType = ActualEnumField(EventType, required=True)
class MultiTaskMetricsRequest(MultiTasksRequestBase):
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
class LegacyMultiTaskEventsRequest(MultiTasksRequestBase):
iters: int = IntField(default=1, validators=validators.Min(1))
scroll_id: str = StringField()
class MultiTaskPlotsRequest(MultiTasksRequestBase):
iters: int = IntField(default=1)
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
last_iters_per_task_metric: bool = BoolField(default=True)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class TaskPlotsRequest(Base):
task: str = StringField(required=True)
iters: int = IntField(default=1)
@@ -164,6 +214,14 @@ class TaskPlotsRequest(Base):
model_events: bool = BoolField(default=False)
class GetScalarMetricDataRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
model_events: bool = BoolField(default=False)
class ClearScrollRequest(Base):
scroll_id: str = StringField()
@@ -172,3 +230,5 @@ class ClearTaskLogRequest(Base):
task: str = StringField(required=True)
threshold_sec = IntField()
allow_locked = BoolField(default=False)
exclude_metrics = ListField(items_types=[str])
include_metrics = ListField(items_types=[str])

View File

@@ -5,8 +5,9 @@ from apiserver.apimodels import DictField, callable_default
class GetSupportedModesRequest(Base):
state = StringField(help_text="ASCII base64 encoded application state")
callback_url_prefix = StringField()
pass
# state = StringField(help_text="ASCII base64 encoded application state")
# callback_url_prefix = StringField()
class BasicGuestMode(Base):

View File

@@ -42,12 +42,29 @@ class ModelRequest(models.Base):
model = fields.StringField(required=True)
class TaskRequest(models.Base):
task = fields.StringField(required=True)
class UpdateForTaskRequest(TaskRequest):
uri = fields.StringField()
iteration = fields.IntField()
override_model_id = fields.StringField()
class UpdateModelRequest(ModelRequest):
task = fields.StringField()
iteration = fields.IntField()
class DeleteModelRequest(ModelRequest):
force = fields.BoolField(default=False)
delete_external_artifacts = fields.BoolField(default=True)
class ModelsDeleteManyRequest(BatchRequest):
force = fields.BoolField(default=False)
delete_external_artifacts = fields.BoolField(default=True)
class PublishModelRequest(ModelRequest):

View File

@@ -1,6 +1,11 @@
from jsonmodels import fields, models
from enum import auto
from typing import Sequence
from apiserver.apimodels import DictField
from jsonmodels import fields, models
from jsonmodels.validators import Length
from apiserver.apimodels import DictField, ActualEnumField, ScalarField
from apiserver.utilities.stringenum import StringEnum
class Filter(models.Base):
@@ -23,3 +28,35 @@ class EntitiesCountRequest(models.Base):
active_users = fields.ListField(str)
search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)
class EntityType(StringEnum):
task = auto()
model = auto()
class ValueMapping(models.Base):
key = ScalarField(nullable=True)
value = ScalarField(nullable=True)
class FieldMapping(models.Base):
field = fields.StringField(required=True)
name = fields.StringField()
values: Sequence[ValueMapping] = fields.ListField(items_types=[ValueMapping])
class PrepareDownloadForGetAllRequest(models.Base):
entity_type = ActualEnumField(EntityType)
allow_public = fields.BoolField(default=True)
search_hidden = fields.BoolField(default=False)
only_fields = fields.ListField(
items_types=[str], validators=[Length(1)], required=True
)
field_mappings: Sequence[FieldMapping] = fields.ListField(
items_types=[FieldMapping], validators=[Length(1)], required=True
)
class DownloadForGetAllRequest(models.Base):
prepare_id = fields.StringField(required=True)

View File

@@ -1,4 +1,5 @@
from jsonmodels import models, fields
from jsonmodels.validators import Length
from apiserver.apimodels import ListField
@@ -8,12 +9,13 @@ class Arg(models.Base):
value = fields.StringField(required=True)
class DeleteRunsRequest(models.Base):
project = fields.StringField(required=True)
ids = ListField([str], required=True, validators=[Length(1)])
class StartPipelineRequest(models.Base):
task = fields.StringField(required=True)
queue = fields.StringField(required=True)
args = ListField(Arg)
class StartPipelineResponse(models.Base):
pipeline = fields.StringField(required=True)
enqueued = fields.BoolField(required=True)
verify_watched_queue = fields.BoolField(default=False)

View File

@@ -1,8 +1,11 @@
from enum import Enum, auto
from jsonmodels import models, fields
from apiserver.apimodels import ListField, ActualEnumField, DictField
from apiserver.apimodels.organization import TagsRequest
from apiserver.database.model import EntityVisibility
from apiserver.utilities.stringenum import StringEnum
class ProjectRequest(models.Base):
@@ -20,6 +23,7 @@ class MoveRequest(ProjectRequest):
class DeleteRequest(ProjectRequest):
force = fields.BoolField(default=False)
delete_contents = fields.BoolField(default=False)
delete_external_artifacts = fields.BoolField(default=True)
class ProjectOrNoneRequest(models.Base):
@@ -27,6 +31,11 @@ class ProjectOrNoneRequest(models.Base):
include_subprojects = fields.BoolField(default=True)
class GetUniqueMetricsRequest(ProjectOrNoneRequest):
model_metrics = fields.BoolField(default=False)
ids = fields.ListField(str)
class GetParamsRequest(ProjectOrNoneRequest):
page = fields.IntField(default=0)
page_size = fields.IntField(default=500)
@@ -37,23 +46,44 @@ class ProjectTagsRequest(TagsRequest):
class MultiProjectRequest(models.Base):
projects = fields.ListField(str)
projects = fields.ListField(items_types=[str, type(None)])
include_subprojects = fields.BoolField(default=True)
class ProjectTaskParentsRequest(MultiProjectRequest):
tasks_state = ActualEnumField(EntityVisibility)
task_name = fields.StringField()
class ProjectHyperparamValuesRequest(MultiProjectRequest):
class EntityTypeEnum(StringEnum):
task = auto()
model = auto()
class ProjectUserNamesRequest(MultiProjectRequest):
entity = ActualEnumField(EntityTypeEnum, default=EntityTypeEnum.task)
class MultiProjectPagedRequest(MultiProjectRequest):
allow_public = fields.BoolField(default=True)
page = fields.IntField(default=0)
page_size = fields.IntField(default=500)
class ProjectHyperparamValuesRequest(MultiProjectPagedRequest):
section = fields.StringField(required=True)
name = fields.StringField(required=True)
allow_public = fields.BoolField(default=True)
pattern = fields.StringField()
class ProjectModelMetadataValuesRequest(MultiProjectRequest):
class ProjectModelMetadataValuesRequest(MultiProjectPagedRequest):
key = fields.StringField(required=True)
allow_public = fields.BoolField(default=True)
class ProjectChildrenType(Enum):
pipeline = "pipeline"
report = "report"
dataset = "dataset"
class ProjectsGetRequest(models.Base):
@@ -68,3 +98,6 @@ class ProjectsGetRequest(models.Base):
shallow_search = fields.BoolField(default=False)
search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)
children_type = ActualEnumField(ProjectChildrenType)
children_tags = fields.ListField(str)
children_tags_filter = DictField()

View File

@@ -17,6 +17,7 @@ class GetDefaultResp(Base):
class CreateRequest(Base):
name = StringField(required=True)
display_name = StringField()
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = DictField(value_types=[MetadataItem])
@@ -47,6 +48,7 @@ class DeleteRequest(QueueRequest):
class UpdateRequest(QueueRequest):
name = StringField()
display_name = StringField()
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = DictField(value_types=[MetadataItem])
@@ -56,6 +58,14 @@ class TaskRequest(QueueRequest):
task = StringField(required=True)
class RemoveTaskRequest(TaskRequest):
update_task_status = BoolField(default=False)
class AddTaskRequest(TaskRequest):
update_execution_queue = BoolField(default=True)
class MoveTaskRequest(TaskRequest):
count = IntField(default=1)

View File

@@ -14,6 +14,7 @@ class UpdateReportRequest(Base):
tags = ListField(items_types=[str])
comment = StringField()
report = StringField()
report_assets = ListField(items_types=[str])
class CreateReportRequest(Base):
@@ -22,6 +23,7 @@ class CreateReportRequest(Base):
comment = StringField()
report = StringField()
project = StringField()
report_assets = ListField(items_types=[str])
class PublishReportRequest(Base):
@@ -55,15 +57,27 @@ class EventsRequest(Base):
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class PlotEventsRequest(EventsRequest):
last_iters_per_task_metric: bool = BoolField(default=True)
class ScalarMetricsIterHistogram(HistogramRequestBase):
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class SingleValueMetrics(Base):
pass
class GetTasksDataRequest(Base):
debug_images: EventsRequest = EmbeddedField(EventsRequest)
plots: EventsRequest = EmbeddedField(EventsRequest)
scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(ScalarMetricsIterHistogram)
plots: PlotEventsRequest = EmbeddedField(PlotEventsRequest)
scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(
ScalarMetricsIterHistogram
)
single_value_metrics: SingleValueMetrics = EmbeddedField(SingleValueMetrics)
allow_public = BoolField(default=True)
model_events: bool = BoolField(default=False)
class GetAllRequest(Base):

View File

@@ -6,6 +6,10 @@ class ReportStatsOptionRequest(Base):
enabled = BoolField(default=None, nullable=True)
class GetConfigRequest(Base):
path = StringField()
class ReportStatsOptionResponse(Base):
supported = BoolField(default=True)
enabled = BoolField()

View File

@@ -0,0 +1,104 @@
from enum import Enum
from typing import Sequence
from jsonmodels.models import Base
from jsonmodels.fields import (
StringField,
EmbeddedField,
DateTimeField,
IntField,
FloatField,
BoolField,
)
from jsonmodels import validators
from jsonmodels.validators import Min
from apiserver.apimodels import ListField, JsonSerializableMixin, SafeStringField
from apiserver.apimodels import ActualEnumField
from apiserver.config_repo import config
from .workers import MachineStats
class ReferenceItem(Base):
"""A typed reference; used as the item type of ServingModel.reference."""
# type must be one of: app_id, app_instance, model, task, url
type = StringField(
required=True,
validators=validators.Enum("app_id", "app_instance", "model", "task", "url"),
)
value = StringField(required=True)
class ServingModel(Base):
"""
Descriptive fields shared by RegisterRequest and StatusReportRequest.
Only container_id, endpoint_name and model_name are required.
"""
container_id = StringField(required=True)
endpoint_name = StringField(required=True)
endpoint_url = StringField() # can be not existing yet at registration time
model_name = StringField(required=True)
model_source = StringField()
model_version = StringField()
preprocess_artifact = StringField()
input_type = StringField()
# SafeStringField: numeric input is accepted and converted to a string
input_size = SafeStringField()
tags = ListField(str)
system_tags = ListField(str)
reference: Sequence[ReferenceItem] = ListField(ReferenceItem)
class RegisterRequest(ServingModel):
"""ServingModel fields plus a registration timeout."""
# The default is read from the "services.serving.default_container_timeout_sec"
# configuration key (falling back to 10 * 60); values below 1 are rejected
timeout = IntField(
default=int(
config.get("services.serving.default_container_timeout_sec", 10 * 60)
),
validators=[Min(1)],
)
""" registration timeout in seconds (default is 10min) """
class UnregisterRequest(Base):
"""Request identifying a container by its (required) container_id."""
container_id = StringField(required=True)
class StatusReportRequest(ServingModel):
"""ServingModel fields plus optional runtime statistics and machine stats."""
uptime_sec = IntField()
requests_num = IntField()
# NOTE(review): presumably requests per minute - confirm against the reporter
requests_min = FloatField()
latency_ms = IntField()
machine_stats: MachineStats = EmbeddedField(MachineStats)
class ServingContainerEntry(StatusReportRequest, JsonSerializableMixin):
"""
Status report fields extended with bookkeeping data: entry key, company,
ip, registration time/timeout and last activity time.
Mixes in JsonSerializableMixin.
"""
key = StringField(required=True)
company_id = StringField(required=True)
ip = StringField()
register_time = DateTimeField(required=True)
register_timeout = IntField(required=True)
last_activity_time = DateTimeField(required=True)
class GetEndpointDetailsRequest(Base):
"""Request identifying an endpoint by its (required) endpoint_url."""
endpoint_url = StringField(required=True)
class MetricType(Enum):
"""
Metric names accepted by GetEndpointMetricsHistoryRequest.metric_type.
Each member's value is equal to its name.
"""
requests = "requests"
requests_min = "requests_min"
latency_ms = "latency_ms"
cpu_count = "cpu_count"
gpu_count = "gpu_count"
cpu_util = "cpu_util"
gpu_util = "gpu_util"
ram_total = "ram_total"
ram_used = "ram_used"
ram_free = "ram_free"
gpu_ram_total = "gpu_ram_total"
gpu_ram_used = "gpu_ram_used"
gpu_ram_free = "gpu_ram_free"
network_rx = "network_rx"
network_tx = "network_tx"
class GetEndpointMetricsHistoryRequest(Base):
"""
Request for the history of a single metric of an endpoint.
from_date, to_date, interval and endpoint_url are required.
"""
# NOTE(review): from_date/to_date look like epoch timestamps and interval
# like a bucket size in seconds - confirm the units against the caller
from_date = FloatField(required=True, validators=Min(0))
to_date = FloatField(required=True, validators=Min(0))
interval = IntField(required=True, validators=Min(1))
endpoint_url = StringField(required=True)
# defaults to the "requests" metric when not specified
metric_type = ActualEnumField(MetricType, default=MetricType.requests)
instance_charts = BoolField(default=True)

View File

@@ -0,0 +1,60 @@
from jsonmodels.fields import StringField, BoolField, ListField, EmbeddedField
from jsonmodels.models import Base
from jsonmodels.validators import Enum
class AWSBucketSettings(Base):
"""
Settings of a single bucket; item type of AWSSettings.buckets.
All fields are optional. multipart, secure and verify default to True,
use_credentials_chain defaults to False.
"""
bucket = StringField()
subdir = StringField()
host = StringField()
key = StringField()
secret = StringField()
token = StringField()
multipart = BoolField(default=True)
acl = StringField()
secure = BoolField(default=True)
region = StringField()
verify = BoolField(default=True)
use_credentials_chain = BoolField(default=False)
class AWSSettings(Base):
"""
AWS section of SetSettingsRequest: top-level credentials fields plus
a list of AWSBucketSettings.
"""
key = StringField()
secret = StringField()
region = StringField()
token = StringField()
use_credentials_chain = BoolField(default=False)
buckets = ListField(items_types=[AWSBucketSettings])
class GoogleBucketSettings(Base):
"""Settings of a single bucket; item type of GoogleSettings.buckets."""
bucket = StringField()
subdir = StringField()
project = StringField()
# NOTE(review): presumably the credentials JSON document passed as a string - confirm
credentials_json = StringField()
class GoogleSettings(Base):
"""
Google section of SetSettingsRequest: top-level project/credentials fields
plus a list of GoogleBucketSettings.
"""
project = StringField()
credentials_json = StringField()
buckets = ListField(items_types=[GoogleBucketSettings])
class AzureContainerSettings(Base):
"""Settings of a single container; item type of AzureSettings.containers."""
account_name = StringField()
account_key = StringField()
container_name = StringField()
class AzureSettings(Base):
"""Azure section of SetSettingsRequest: a list of AzureContainerSettings."""
containers = ListField(items_types=[AzureContainerSettings])
class SetSettingsRequest(Base):
"""Request carrying optional aws, google and azure settings sections."""
aws = EmbeddedField(AWSSettings)
google = EmbeddedField(GoogleSettings)
azure = EmbeddedField(AzureSettings)
class ResetSettingsRequest(Base):
"""Request listing the settings sections to reset."""
# each key must be one of: "aws", "google", "azure"
keys = ListField([str], item_validators=[Enum("aws", "google", "azure")])

View File

@@ -96,10 +96,20 @@ class UpdateRequest(TaskUpdateRequest):
status_message = StringField(default="")
class DequeueRequest(UpdateRequest):
remove_from_all_queues = BoolField(default=False)
new_status = StringField()
class StopRequest(UpdateRequest):
include_pipeline_steps = BoolField(default=False)
class EnqueueRequest(UpdateRequest):
queue = StringField()
queue_name = StringField()
verify_watched_queue = BoolField(default=False)
update_execution_queue = BoolField(default=True)
class DeleteRequest(UpdateRequest):
@@ -107,6 +117,7 @@ class DeleteRequest(UpdateRequest):
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
delete_external_artifacts = BoolField(default=True)
include_pipeline_steps = BoolField(default=False)
class SetRequirementsRequest(TaskRequest):
@@ -259,6 +270,7 @@ class DeleteConfigurationRequest(TaskUpdateRequest):
class ArchiveRequest(MultiTaskRequest):
status_reason = StringField(default="")
status_message = StringField(default="")
include_pipeline_steps = BoolField(default=False)
class ArchiveResponse(models.Base):
@@ -270,8 +282,22 @@ class TaskBatchRequest(BatchRequest):
status_message = StringField(default="")
class ArchiveManyRequest(TaskBatchRequest):
include_pipeline_steps = BoolField(default=False)
class UnarchiveManyRequest(TaskBatchRequest):
include_pipeline_steps = BoolField(default=False)
class StopManyRequest(TaskBatchRequest):
force = BoolField(default=False)
include_pipeline_steps = BoolField(default=False)
class DequeueManyRequest(TaskBatchRequest):
remove_from_all_queues = BoolField(default=False)
new_status = StringField()
class EnqueueManyRequest(TaskBatchRequest):
@@ -287,6 +313,7 @@ class DeleteManyRequest(TaskBatchRequest):
delete_output_models = BoolField(default=True)
force = BoolField(default=False)
delete_external_artifacts = BoolField(default=True)
include_pipeline_steps = BoolField(default=False)
class ResetManyRequest(TaskBatchRequest):
@@ -323,3 +350,8 @@ class DeleteModelsRequest(TaskRequest):
class GetAllReq(models.Base):
allow_public = BoolField(default=True)
search_hidden = BoolField(default=False)
class UpdateTagsRequest(BatchRequest):
add_tags = ListField([str])
remove_tags = ListField([str])

View File

@@ -4,6 +4,10 @@ from jsonmodels.models import Base
from apiserver.apimodels import DictField
class UserRequest(Base):
user = StringField(required=True)
class CreateRequest(Base):
id = StringField(required=True)
name = StringField(required=True)

View File

@@ -12,9 +12,8 @@ from jsonmodels.fields import (
)
from jsonmodels.models import Base
from apiserver.apimodels import make_default, ListField, EnumField, JsonSerializableMixin
DEFAULT_TIMEOUT = 10 * 60
from apiserver.apimodels import ListField, EnumField, JsonSerializableMixin, ActualEnumField
from apiserver.config_repo import config
class WorkerRequest(Base):
@@ -24,9 +23,10 @@ class WorkerRequest(Base):
class RegisterRequest(WorkerRequest):
timeout = make_default(
IntField, DEFAULT_TIMEOUT
)() # registration timeout in seconds (default is 10min)
timeout = IntField(
default=int(config.get("services.workers.default_worker_timeout_sec", 10 * 60))
)
""" registration timeout in seconds (default is 10min) """
queues = ListField(six.string_types) # list of queues this worker listens to
@@ -86,6 +86,7 @@ class CurrentTaskEntry(IdNameEntry):
class QueueEntry(IdNameEntry):
display_name = StringField()
next_task = EmbeddedField(IdNameEntry)
num_tasks = IntField()
@@ -100,12 +101,17 @@ class GetAllRequest(Base):
last_seen = IntField(default=3600)
tags = ListField(str)
system_tags = ListField(str)
worker_pattern = StringField()
class GetAllResponse(Base):
workers = ListField(WorkerResponseEntry)
class GetCountRequest(GetAllRequest):
last_seen = IntField(default=0)
class StatsBase(Base):
worker_ids = ListField(str)
@@ -124,7 +130,7 @@ class AggregationType(Enum):
class StatItem(Base):
key = StringField(required=True)
aggregation = EnumField(AggregationType, default=AggregationType.avg)
aggregation = ActualEnumField(AggregationType, default=AggregationType.avg)
class GetStatsRequest(StatsReportBase):
@@ -132,17 +138,24 @@ class GetStatsRequest(StatsReportBase):
StatItem, required=True, validators=validators.Length(minimum_value=1)
)
split_by_variant = BoolField(default=False)
split_by_resource = BoolField(default=False)
class MetricResourceSeries(Base):
name = StringField()
values = ListField(float)
class AggregationStats(Base):
aggregation = EnumField(AggregationType)
dates = ListField(int)
values = ListField(float)
resource_series = ListField(MetricResourceSeries)
class MetricStats(Base):
metric = StringField()
variant = StringField()
dates = ListField(int)
stats = ListField(AggregationStats)

View File

@@ -5,7 +5,6 @@ import zlib
from collections import defaultdict
from contextlib import closing
from datetime import datetime
from operator import attrgetter
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
import elasticsearch
@@ -24,12 +23,15 @@ from apiserver.bll.event.event_common import (
get_metric_variants_condition,
uncompress_plot,
get_max_metric_and_variant_counts,
PlotFields,
)
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIterator
from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
from apiserver.bll.model import ModelBLL
from apiserver.bll.task.utils import get_many_tasks_for_writing
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.database.model.model import Model
@@ -39,32 +41,28 @@ from apiserver.bll.event.event_metrics import EventMetrics
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.database.model.task.task import TaskStatus
from apiserver.redis_manager import redman
from apiserver.tools import safe_get
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.json import loads
# noinspection PyTypeChecker
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
EVENT_TYPES: Set[str] = set(et.value for et in EventType if et != EventType.all)
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
MAX_LONG = 2 ** 63 - 1
MIN_LONG = -(2 ** 63)
MAX_LONG = 2**63 - 1
MIN_LONG = -(2**63)
log = config.logger(__file__)
class PlotFields:
valid_plot = "valid_plot"
plot_len = "plot_len"
plot_str = "plot_str"
plot_data = "plot_data"
source_urls = "source_urls"
async_task_events_delete = config.get("services.tasks.async_events_delete", False)
async_delete_threshold = config.get(
"services.tasks.async_events_delete_threshold", 100_000
)
class EventBLL(object):
id_fields = ("task", "iter", "metric", "variant", "key")
event_id_fields = ("task", "iter", "metric", "variant", "key")
empty_scroll = "FFFF"
img_source_regex = re.compile(
r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]",
@@ -102,47 +100,96 @@ class EventBLL(object):
return self._metrics
@staticmethod
def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
"""Verify that task exists and can be updated"""
if not task_ids:
def _get_valid_entities(
company_id, ids: Mapping[str, bool], identity: Identity, model=False
) -> Set:
"""Verify that task or model exists and can be updated"""
if not ids:
return set()
with translate_errors_context():
query = Q(id__in=task_ids, company=company_id)
if not allow_locked_tasks:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
res = Task.objects(query).only("id")
return {r.id for r in res}
allow_locked = {id_ for id_, allowed in ids.items() if allowed}
not_locked = {id_ for id_, allowed in ids.items() if not allowed}
res = set()
allow_locked_q = Q()
not_locked_q = (
Q(ready__ne=True) if model else Q(status__nin=LOCKED_TASK_STATUSES)
)
for requested_ids, locked_q in (
(allow_locked, allow_locked_q),
(not_locked, not_locked_q),
):
if not requested_ids:
continue
@staticmethod
def _get_valid_models(company_id, model_ids: Set, allow_locked_models=False) -> Set:
"""Verify that task exists and can be updated"""
if not model_ids:
return set()
query = Q(id__in=requested_ids) & locked_q
if model:
ids = Model.objects(query & Q(company=company_id)).scalar("id")
else:
ids = {
t.id
for t in get_many_tasks_for_writing(
company_id=company_id,
identity=identity,
query=query,
only=("id",),
throw_on_forbidden=False,
)
}
with translate_errors_context():
query = Q(id__in=model_ids, company=company_id)
if not allow_locked_models:
query &= Q(ready__ne=True)
res = Model.objects(query).only("id")
return {r.id for r in res}
res.update(ids)
return res
def add_events(
self, company_id, events, worker, allow_locked=False
self,
company_id: str,
identity: Identity,
events: Sequence[dict],
worker: str,
) -> Tuple[int, int, dict]:
model_events = events[0].get("model_event", False)
user_id = identity.user
task_ids = {}
model_ids = {}
for event in events:
if event.get("model_event", model_events) != model_events:
if event.get("model_event", False):
model = event.pop("model", None)
if model is not None:
event["task"] = model
entity_ids = model_ids
else:
event["model_event"] = False
entity_ids = task_ids
id_ = event.get("task")
allow_locked = event.pop("allow_locked", False)
if not id_:
continue
allowed_for_entity = entity_ids.get(id_)
if allowed_for_entity is None:
entity_ids[id_] = allow_locked
elif allowed_for_entity != allow_locked:
raise errors.bad_request.ValidationError(
"Inconsistent model_event setting in the passed events"
)
if event.pop("allow_locked", allow_locked) != allow_locked:
raise errors.bad_request.ValidationError(
"Inconsistent allow_locked setting in the passed events"
f"Inconsistent allow_locked setting in the passed events for {id_}"
)
found_in_both = set(task_ids).intersection(set(model_ids))
if found_in_both:
raise errors.bad_request.ValidationError(
"Inconsistent model_event setting in the passed events",
tasks=found_in_both,
)
valid_models = self._get_valid_entities(
company_id, ids=model_ids, identity=identity, model=True
)
valid_tasks = self._get_valid_entities(
company_id, ids=task_ids, identity=identity
)
actions: List[dict] = []
task_or_model_ids = set()
used_task_ids = set()
used_model_ids = set()
task_iteration = defaultdict(lambda: 0)
task_last_scalar_events = nested_dict(
3, dict
@@ -152,30 +199,10 @@ class EventBLL(object):
) # task_id -> metric_hash -> event_type -> MetricEvent
errors_per_type = defaultdict(int)
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
if model_events:
for event in events:
model = event.pop("model", None)
if model is not None:
event["task"] = model
valid_entities = self._get_valid_models(
company_id,
model_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_models=allow_locked,
)
entity_name = "model"
else:
valid_entities = self._get_valid_tasks(
company_id,
task_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_tasks=allow_locked,
)
entity_name = "task"
for event in events:
x_axis_label = event.pop("x_axis_label", None)
# remove spaces from event type
event_type = event.get("type")
if event_type is None:
@@ -187,7 +214,8 @@ class EventBLL(object):
errors_per_type[f"Invalid event type {event_type}"] += 1
continue
if model_events and event_type == EventType.task_log.value:
model_event = event["model_event"]
if model_event and event_type == EventType.task_log.value:
errors_per_type[f"Task log events are not supported for models"] += 1
continue
@@ -196,8 +224,12 @@ class EventBLL(object):
errors_per_type["Event must have a 'task' field"] += 1
continue
if task_or_model_id not in valid_entities:
errors_per_type[f"Invalid {entity_name} id {task_or_model_id}"] += 1
if (model_event and task_or_model_id not in valid_models) or (
not model_event and task_or_model_id not in valid_tasks
):
errors_per_type[
f"Invalid {'model' if model_event else 'task'} id {task_or_model_id}"
] += 1
continue
event["type"] = event_type
@@ -219,13 +251,10 @@ class EventBLL(object):
# force iter to be a long int
iter = event.get("iter")
if iter is not None:
if model_events:
iter = 0
else:
iter = int(iter)
if iter > MAX_LONG or iter < MIN_LONG:
errors_per_type[invalid_iteration_error] += 1
continue
iter = int(iter)
if iter > MAX_LONG or iter < MIN_LONG:
errors_per_type[invalid_iteration_error] += 1
continue
event["iter"] = iter
# used to have "values" to indicate array. no need anymore
@@ -235,7 +264,6 @@ class EventBLL(object):
event["metric"] = event.get("metric") or ""
event["variant"] = event.get("variant") or ""
event["model_event"] = model_events
index_name = get_index_name(company_id, event_type)
es_action = {
@@ -244,31 +272,34 @@ class EventBLL(object):
"_source": event,
}
# for "log" events, don't assing custom _id - whatever is sent, is written (not overwritten)
# for "log" events, don't assign custom _id - whatever is sent, is written (not overwritten)
if event_type != EventType.task_log.value:
es_action["_id"] = self._get_event_id(event)
else:
es_action["_id"] = dbutils.id()
task_or_model_ids.add(task_or_model_id)
if (
iter is not None
and not model_events
and event.get("metric") not in self._skip_iteration_for_metric
):
task_iteration[task_or_model_id] = max(
iter, task_iteration[task_or_model_id]
)
if not model_events:
if model_event:
used_model_ids.add(task_or_model_id)
else:
used_task_ids.add(task_or_model_id)
self._update_last_metric_events_for_task(
last_events=task_last_events[task_or_model_id], event=event,
last_events=task_last_events[task_or_model_id],
event=event,
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_or_model_id],
event=event,
x_axis_label=x_axis_label,
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_or_model_id],
event=event,
)
actions.append(es_action)
@@ -291,6 +322,7 @@ class EventBLL(object):
if actions:
chunk_size = 500
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
# noinspection PyTypeChecker
with closing(
elasticsearch.helpers.streaming_bulk(
self.es,
@@ -306,38 +338,47 @@ class EventBLL(object):
else:
errors_per_type["Error when indexing events batch"] += 1
if not model_events:
remaining_tasks = set()
now = datetime.utcnow()
for task_or_model_id in task_or_model_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those whose events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_or_model_id,
now=now,
iter_max=task_iteration.get(task_or_model_id),
last_scalar_events=task_last_scalar_events.get(
task_or_model_id
),
last_events=task_last_events.get(task_or_model_id),
)
now = datetime.utcnow()
for model_id in used_model_ids:
ModelBLL.update_statistics(
company_id=company_id,
user_id=user_id,
model_id=model_id,
last_update=now,
last_iteration_max=task_iteration.get(model_id),
last_scalar_events=task_last_scalar_events.get(model_id),
)
remaining_tasks = set()
for task_id in used_task_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those whose events were successful
updated = self._update_task(
company_id=company_id,
user_id=user_id,
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
)
if not updated:
remaining_tasks.add(task_id)
continue
if not updated:
remaining_tasks.add(task_or_model_id)
continue
if remaining_tasks:
TaskBLL.set_last_update(
remaining_tasks, company_id, last_update=now
)
if remaining_tasks:
TaskBLL.set_last_update(
remaining_tasks,
company_id=company_id,
user_id=user_id,
last_update=now,
)
# this is for backwards compatibility with streaming bulk throwing exception on those
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
if invalid_iterations_count:
raise BulkIndexError(
f"{invalid_iterations_count} document(s) failed to index.",
[invalid_iteration_error],
[{"_index": invalid_iteration_error}],
)
if not added:
@@ -393,7 +434,7 @@ class EventBLL(object):
return False
return True
def _update_last_scalar_events_for_task(self, last_events, event):
def _update_last_scalar_events_for_task(self, last_events, event, x_axis_label=None):
"""
Update last_events structure with the provided event details if this event is more
recent than the currently stored event for its metric/variant combination.
@@ -401,47 +442,47 @@ class EventBLL(object):
last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
key conflicts due to invalid characters and/or long field names.
"""
metric = event.get("metric")
variant = event.get("variant")
if not (metric and variant):
value = event.get("value")
if value is None:
return
metric = event.get("metric") or ""
variant = event.get("variant") or ""
metric_hash = dbutils.hash_field_name(metric)
variant_hash = dbutils.hash_field_name(variant)
last_event = last_events[metric_hash][variant_hash]
last_event["metric"] = metric
last_event["variant"] = variant
last_event["count"] = last_event.get("count", 0) + 1
last_event["total"] = last_event.get("total", 0) + value
event_iter = event.get("iter", 0)
event_timestamp = event.get("timestamp", 0)
value = event.get("value")
if value is not None and (
(event_iter, event_timestamp)
>= (
last_event.get("iter", event_iter),
last_event.get("timestamp", event_timestamp),
)
if (event_iter, event_timestamp) >= (
last_event.get("iter", event_iter),
last_event.get("timestamp", event_timestamp),
):
event_data = {
k: event[k]
for k in ("value", "metric", "variant", "iter", "timestamp")
if k in event
}
last_event_min_value = last_event.get("min_value", value)
last_event_min_value_iter = last_event.get("min_value_iter", event_iter)
if value < last_event_min_value:
event_data["min_value"] = value
event_data["min_value_iter"] = event_iter
else:
event_data["min_value"] = last_event_min_value
event_data["min_value_iter"] = last_event_min_value_iter
last_event_max_value = last_event.get("max_value", value)
last_event_max_value_iter = last_event.get("max_value_iter", event_iter)
if value > last_event_max_value:
event_data["max_value"] = value
event_data["max_value_iter"] = event_iter
else:
event_data["max_value"] = last_event_max_value
event_data["max_value_iter"] = last_event_max_value_iter
last_events[metric_hash][variant_hash] = event_data
last_event["value"] = value
last_event["iter"] = event_iter
last_event["timestamp"] = event_timestamp
if x_axis_label is not None:
last_event["x_axis_label"] = x_axis_label
first_value_iter = last_event.get("first_value_iter")
if first_value_iter is None or event_iter < first_value_iter:
last_event["first_value"] = value
last_event["first_value_iter"] = event_iter
last_event_min_value = last_event.get("min_value")
if last_event_min_value is None or value < last_event_min_value:
last_event["min_value"] = value
last_event["min_value_iter"] = event_iter
last_event_max_value = last_event.get("max_value")
if last_event_max_value is None or value > last_event_max_value:
last_event["max_value"] = value
last_event["max_value_iter"] = event_iter
def _update_last_metric_events_for_task(self, last_events, event):
"""
@@ -449,9 +490,9 @@ class EventBLL(object):
recent than the currently stored event for its metric/event_type combination.
last_events contains [metric_name -> event_type -> event]
"""
metric = event.get("metric")
metric = event.get("metric") or ""
event_type = event.get("type")
if not (metric and event_type):
if not event_type:
return
timestamp = last_events[metric][event_type].get("timestamp", None)
@@ -460,9 +501,10 @@ class EventBLL(object):
def _update_task(
self,
company_id,
task_id,
now,
company_id: str,
user_id: str,
task_id: str,
now: datetime,
iter_max=None,
last_scalar_events=None,
last_events=None,
@@ -478,8 +520,9 @@ class EventBLL(object):
return False
return TaskBLL.update_statistics(
task_id,
company_id,
task_id=task_id,
company_id=company_id,
user_id=user_id,
last_update=now,
last_iteration_max=iter_max,
last_scalar_events=last_scalar_events,
@@ -487,7 +530,9 @@ class EventBLL(object):
)
def _get_event_id(self, event):
id_values = (str(event[field]) for field in self.id_fields if field in event)
id_values = (
str(event[field]) for field in self.event_id_fields if field in event
)
return hashlib.md5("-".join(id_values).encode()).hexdigest()
def scroll_task_events(
@@ -559,11 +604,10 @@ class EventBLL(object):
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type,
)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
max_variants = int(max_variants // last_iterations_per_plot)
@@ -589,7 +633,7 @@ class EventBLL(object):
"events": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": last_iterations_per_plot
"size": last_iterations_per_plot,
}
}
},
@@ -600,11 +644,7 @@ class EventBLL(object):
}
with translate_errors_context():
es_response = search_company_events(
body=es_req,
ignore=404,
**search_args,
)
es_response = search_company_events(body=es_req, ignore=404, **search_args)
aggs_result = es_response.get("aggregations")
if not aggs_result:
@@ -617,17 +657,17 @@ class EventBLL(object):
for hit in variants_bucket["events"]["hits"]["hits"]
]
self.uncompress_plots(events)
return TaskEventsResult(
events=events, total_events=len(events)
)
return TaskEventsResult(events=events, total_events=len(events))
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
"""
Return events and next scroll id from the scrolled query
Release the scroll once it is exhausted
"""
total_events = safe_get(es_res, "hits/total/value", default=0)
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
total_events = nested_get(es_res, ("hits", "total", "value"), default=0)
events = [
doc["_source"] for doc in nested_get(es_res, ("hits", "hits"), default=[])
]
next_scroll_id = es_res.get("_scroll_id")
if next_scroll_id and not events:
self.clear_scroll(next_scroll_id)
@@ -636,9 +676,11 @@ class EventBLL(object):
return events, total_events, next_scroll_id
def get_debug_image_urls(
self, company_id: str, task_id: str, after_key: dict = None
self, company_id: str, task_ids: Sequence[str], after_key: dict = None
) -> Tuple[Sequence[str], Optional[dict]]:
if check_empty_data(self.es, company_id, EventType.metrics_image):
if not task_ids or check_empty_data(
self.es, company_id, EventType.metrics_image
):
return [], None
es_req = {
@@ -654,7 +696,10 @@ class EventBLL(object):
},
"query": {
"bool": {
"must": [{"term": {"task": task_id}}, {"exists": {"field": "url"}}]
"must": [
{"terms": {"task": task_ids}},
{"exists": {"field": "url"}},
]
}
},
}
@@ -672,9 +717,13 @@ class EventBLL(object):
return [bucket["key"]["url"] for bucket in res["buckets"]], res.get("after_key")
def get_plot_image_urls(
self, company_id: str, task_id: str, scroll_id: Optional[str]
self, company_id: str, task_ids: Sequence[str], scroll_id: Optional[str]
) -> Tuple[Sequence[dict], Optional[str]]:
if scroll_id == self.empty_scroll:
if (
scroll_id == self.empty_scroll
or not task_ids
or check_empty_data(self.es, company_id, EventType.metrics_plot)
):
return [], None
if scroll_id:
@@ -689,7 +738,7 @@ class EventBLL(object):
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"terms": {"task": task_ids}},
{"exists": {"field": PlotFields.source_urls}},
]
}
@@ -717,7 +766,7 @@ class EventBLL(object):
size=500,
scroll_id=None,
no_scroll=False,
model_events=False,
last_iters_per_task_metric=False,
) -> TaskEventsResult:
if scroll_id == self.empty_scroll:
return TaskEventsResult()
@@ -735,39 +784,56 @@ class EventBLL(object):
if not company_ids:
return TaskEventsResult()
task_ids = (
[task_id]
if isinstance(task_id, str)
else task_id
)
task_ids = [task_id] if isinstance(task_id, str) else task_id
must = []
if metrics:
must.append(get_metric_variants_condition(metrics))
if last_iter_count is None or model_events:
if last_iter_count is None:
must.append({"terms": {"task": task_ids}})
else:
tasks_iters = self.get_last_iters(
company_id=company_ids,
event_type=event_type,
task_id=task_ids,
iters=last_iter_count,
metrics=metrics,
)
should = [
{
"bool": {
"must": [
{"term": {"task": task}},
{"terms": {"iter": last_iters}},
]
if last_iters_per_task_metric:
task_metric_iters = self.get_last_iters_per_metric(
company_id=company_ids,
event_type=event_type,
task_id=task_ids,
iters=last_iter_count,
metrics=metrics,
)
should = [
{
"bool": {
"must": [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"terms": {"iter": last_iters}},
]
}
}
}
for task, last_iters in tasks_iters.items()
if last_iters
]
for (task, metric), last_iters in task_metric_iters.items()
if last_iters
]
else:
tasks_iters = self.get_last_iters(
company_id=company_ids,
event_type=event_type,
task_id=task_ids,
iters=last_iter_count,
metrics=metrics,
)
should = [
{
"bool": {
"must": [
{"term": {"task": task}},
{"terms": {"iter": last_iters}},
]
}
}
for task, last_iters in tasks_iters.items()
if last_iters
]
if not should:
return TaskEventsResult()
must.append({"bool": {"should": should}})
@@ -808,7 +874,8 @@ class EventBLL(object):
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
es_req = {
"size": 0,
@@ -862,8 +929,10 @@ class EventBLL(object):
}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
max_variants = int(max_variants // 2)
es_req = {
"size": 0,
"query": query,
@@ -965,13 +1034,77 @@ class EventBLL(object):
return iterations, vectors
def get_last_iters_per_metric(
    self,
    company_id: Union[str, Sequence[str]],
    event_type: EventType,
    task_id: Union[str, Sequence[str]],
    iters: int,
    metrics: MetricVariants = None,
) -> Mapping[Tuple[str, str], Sequence]:
    """
    Return the last reported iterations for every (task, metric) pair.

    :param company_id: a single company id or a sequence of company ids.
        Companies for which check_empty_data reports no data for the
        event type are skipped
    :param event_type: the type of the events to inspect
    :param task_id: a single task id or a sequence of task ids
    :param iters: the maximal amount of iterations to return per (task, metric)
    :param metrics: optional metric/variant filter applied to the events
    :return: mapping of (task id, metric) to the iteration numbers found for
        that pair, requested from ES in descending order. An empty mapping is
        returned when there is nothing to query or nothing was aggregated
    """
    task_ids = [task_id] if isinstance(task_id, str) else task_id
    if not task_ids or iters < 1:
        # nothing to look for; this also protects the bucket size
        # calculation below from dividing by zero
        return {}

    company_ids = [company_id] if isinstance(company_id, str) else company_id
    company_ids = [
        c_id
        for c_id in set(company_ids)
        if not check_empty_data(self.es, c_id, event_type)
    ]
    if not company_ids:
        return {}

    must = [{"terms": {"task": task_ids}}]
    if metrics:
        must.append(get_metric_variants_condition(metrics))

    max_tasks = min(len(task_ids), 1000)
    # split a budget of 10,000 buckets between tasks, metrics and iterations
    # (presumably to stay within the ES buckets limit - TODO confirm).
    # The terms aggregation size must be positive, so never go below 1
    max_metrics = max(1, 10_000 // (max_tasks * iters))
    es_req: dict = {
        "size": 0,
        "aggs": {
            "tasks": {
                "terms": {"field": "task", "size": max_tasks},
                "aggs": {
                    "metrics": {
                        "terms": {"field": "metric", "size": max_metrics},
                        "aggs": {
                            "iters": {
                                "terms": {
                                    "field": "iter",
                                    "size": iters,
                                    "order": {"_key": "desc"},
                                }
                            }
                        },
                    }
                },
            }
        },
        "query": {"bool": {"must": must}},
    }

    with translate_errors_context():
        es_res = search_company_events(
            self.es,
            company_id=company_ids,
            event_type=event_type,
            body=es_req,
        )

    if "aggregations" not in es_res:
        return {}

    return {
        (tb["key"], mb["key"]): [ib["key"] for ib in mb["iters"]["buckets"]]
        for tb in es_res["aggregations"]["tasks"]["buckets"]
        for mb in tb["metrics"]["buckets"]
    }
def get_last_iters(
self,
company_id: Union[str, Sequence[str]],
event_type: EventType,
task_id: Union[str, Sequence[str]],
iters: int,
metrics: MetricVariants = None
metrics: MetricVariants = None,
) -> Mapping[str, Sequence]:
company_ids = [company_id] if isinstance(company_id, str) else company_id
company_ids = [
@@ -987,11 +1120,12 @@ class EventBLL(object):
if metrics:
must.append(get_metric_variants_condition(metrics))
max_tasks = min(len(task_ids), 1000)
es_req: dict = {
"size": 0,
"aggs": {
"tasks": {
"terms": {"field": "task"},
"terms": {"field": "task", "size": max_tasks},
"aggs": {
"iters": {
"terms": {
@@ -1008,7 +1142,10 @@ class EventBLL(object):
with translate_errors_context():
es_res = search_company_events(
self.es, company_id=company_ids, event_type=event_type, body=es_req,
self.es,
company_id=company_ids,
event_type=event_type,
body=es_req,
)
if "aggregations" not in es_res:
@@ -1019,34 +1156,6 @@ class EventBLL(object):
for tb in es_res["aggregations"]["tasks"]["buckets"]
}
@staticmethod
def _validate_model_state(
company_id: str, model_id: str, allow_locked: bool = False
):
extra_msg = None
query = Q(id=model_id, company=company_id)
if not allow_locked:
query &= Q(ready__ne=True)
extra_msg = "or model published"
res = Model.objects(query).only("id").first()
if not res:
raise errors.bad_request.InvalidModelId(
extra_msg, company=company_id, id=model_id
)
@staticmethod
def _validate_task_state(company_id: str, task_id: str, allow_locked: bool = False):
extra_msg = None
query = Q(id=task_id, company=company_id)
if not allow_locked:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
extra_msg = "or task published"
res = Task.objects(query).only("id").first()
if not res:
raise errors.bad_request.InvalidTaskId(
extra_msg, company=company_id, id=task_id
)
@staticmethod
def _get_events_deletion_params(async_delete: bool) -> dict:
if async_delete:
@@ -1060,40 +1169,52 @@ class EventBLL(object):
return {"refresh": True}
def delete_task_events(
self, company_id, task_id, allow_locked=False, model=False, async_delete=False,
self,
company_id,
task_ids: Union[str, Sequence[str]],
wait_for_delete: bool,
model=False,
):
if model:
self._validate_model_state(
company_id=company_id, model_id=task_id, allow_locked=allow_locked,
)
else:
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
)
es_req = {"query": {"term": {"task": task_id}}}
"""
Delete task events. No check is done for tasks write access
so it should be checked by the calling code
"""
if isinstance(task_ids, str):
task_ids = [task_ids]
deleted = 0
with translate_errors_context():
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.all,
body=es_req,
**self._get_events_deletion_params(async_delete),
)
async_delete = async_task_events_delete and not wait_for_delete
if async_delete and len(task_ids) < 100:
total = self.events_iterator.count_task_events(
event_type=EventType.all,
company_id=company_id,
task_ids=task_ids,
)
if total <= async_delete_threshold:
async_delete = False
for tasks in chunked_iter(task_ids, 100):
es_req = {"query": {"terms": {"task": tasks}}}
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.all,
body=es_req,
**self._get_events_deletion_params(async_delete),
)
if not async_delete:
deleted += es_res.get("deleted", 0)
if not async_delete:
return es_res.get("deleted", 0)
return deleted
def clear_task_log(
self,
company_id: str,
task_id: str,
allow_locked: bool = False,
threshold_sec: int = None,
include_metrics: Sequence[str] = None,
exclude_metrics: Sequence[str] = None,
):
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
)
if check_empty_data(
self.es, company_id=company_id, event_type=EventType.task_log
):
@@ -1114,8 +1235,16 @@ class EventBLL(object):
}
)
sort = {"timestamp": {"order": "desc"}}
if include_metrics:
must.append({"terms": {"metric": include_metrics}})
more_conditions = {}
if exclude_metrics:
more_conditions = {"must_not": [{"terms": {"metric": exclude_metrics}}]}
es_req = {
"query": {"bool": {"must": must}},
"query": {"bool": {"must": must, **more_conditions}},
**({"sort": sort} if sort else {}),
}
es_res = delete_company_events(
@@ -1127,30 +1256,6 @@ class EventBLL(object):
)
return es_res.get("deleted", 0)
def delete_multi_task_events(
self, company_id: str, task_ids: Sequence[str], async_delete=False
):
"""
Delete multiple task events. No check is done for tasks write access
so it should be checked by the calling code
"""
deleted = 0
with translate_errors_context():
for tasks in chunked_iter(task_ids, 100):
es_req = {"query": {"terms": {"task": tasks}}}
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.all,
body=es_req,
**self._get_events_deletion_params(async_delete),
)
if not async_delete:
deleted += es_res.get("deleted", 0)
if not async_delete:
return es_res.get("deleted", 0)
def clear_scroll(self, scroll_id: str):
if scroll_id == self.empty_scroll:
return

View File

@@ -9,7 +9,7 @@ from elasticsearch import Elasticsearch
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
class EventType(Enum):
@@ -64,7 +64,7 @@ def get_index_name(company_id: Union[str, Sequence[str]], event_type: str):
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
es_index = get_index_name(company_id, event_type.value)
if not es.indices.exists(es_index):
if not es.indices.exists(index=es_index):
return True
return False
@@ -123,8 +123,8 @@ def get_max_metric_and_variant_counts(
es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
)
metrics_count = safe_get(
es_res, "aggregations/metrics_count/value", max_metrics_count
metrics_count = nested_get(
es_res, ("aggregations", "metrics_count", "value"), max_metrics_count
)
if not metrics_count:
return max_metrics_count, max_variants_count

View File

@@ -21,9 +21,12 @@ from apiserver.bll.event.event_common import (
TaskCompanies,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.tools import safe_get
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.utilities.dicts import nested_get
log = config.logger(__file__)
@@ -42,6 +45,7 @@ class EventMetrics:
samples: int,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
model_events: bool = False,
) -> dict:
"""
Get scalar metric histogram per metric and variant
@@ -59,6 +63,7 @@ class EventMetrics:
samples=samples,
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
model_events=model_events,
)
def _get_scalar_average_per_iter_core(
@@ -70,6 +75,7 @@ class EventMetrics:
key: ScalarKey,
run_parallel: bool = True,
metric_variants: MetricVariants = None,
model_events: bool = False,
) -> dict:
intervals = self._get_task_metric_intervals(
company_id=company_id,
@@ -101,7 +107,22 @@ class EventMetrics:
)
ret = defaultdict(dict)
if not metrics:
return ret
last_metrics = {}
cls_ = Model if model_events else Task
task = cls_.objects(id=task_id).only("last_metrics").first()
if task and task.last_metrics:
for m_data in task.last_metrics.values():
for v_data in m_data.values():
last_metrics[(v_data.metric, v_data.variant)] = v_data
for metric_key, metric_values in metrics:
for variant_key, data in metric_values.items():
last_metrics_data = last_metrics.get((metric_key, variant_key))
if last_metrics_data and last_metrics_data.x_axis_label is not None:
data["x_axis_label"] = last_metrics_data.x_axis_label
ret[metric_key].update(metric_values)
return ret
@@ -112,6 +133,7 @@ class EventMetrics:
samples,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
model_events: bool = False,
):
"""
Compare scalar metrics for different tasks per metric and variant
@@ -135,6 +157,7 @@ class EventMetrics:
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
run_parallel=False,
model_events=model_events,
)
task_ids, company_ids = zip(
*(
@@ -161,8 +184,10 @@ class EventMetrics:
return res
def get_task_single_value_metrics(
self, companies: TaskCompanies
) -> Mapping[str, dict]:
self,
companies: TaskCompanies,
metric_variants: MetricVariants = None,
) -> Mapping[str, Sequence[dict]]:
"""
For the requested tasks return all the events delivered for the single iteration (-2**31)
"""
@@ -179,7 +204,13 @@ class EventMetrics:
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
task_events = list(
itertools.chain.from_iterable(
pool.map(self._get_task_single_value_metrics, companies.items())
pool.map(
partial(
self._get_task_single_value_metrics,
metric_variants=metric_variants,
),
companies.items(),
)
),
)
@@ -195,19 +226,19 @@ class EventMetrics:
}
def _get_task_single_value_metrics(
self, tasks: Tuple[str, Sequence[str]]
self, tasks: Tuple[str, Sequence[str]], metric_variants: MetricVariants = None
) -> Sequence[dict]:
company_id, task_ids = tasks
must = [
{"terms": {"task": task_ids}},
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"terms": {"task": task_ids}},
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
]
}
},
"query": {"bool": {"must": must}},
}
with translate_errors_context():
es_res = search_company_events(
@@ -280,7 +311,8 @@ class EventMetrics:
query = {"bool": {"must": must}}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
max_variants = int(max_variants // 2)
es_req = {
@@ -332,12 +364,12 @@ class EventMetrics:
total amount of intervals does not exceed the samples
Return the interval and resulting amount of intervals
"""
count = safe_get(data, "count/value", default=0)
count = nested_get(data, ("count", "value"), default=0)
if count < samples:
return metric, variant, 1, count
min_index = safe_get(data, "min_index/value", default=0)
max_index = safe_get(data, "max_index/value", default=min_index)
min_index = nested_get(data, ("min_index", "value"), default=0)
max_index = nested_get(data, ("max_index", "value"), default=min_index)
index_range = max_index - min_index + 1
interval = max(1, math.ceil(float(index_range) / samples))
max_samples = math.ceil(float(index_range) / interval)
@@ -366,7 +398,8 @@ class EventMetrics:
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
max_variants = int(max_variants // 2)
es_req = {
@@ -432,7 +465,9 @@ class EventMetrics:
@classmethod
def _get_task_metrics_query(
cls, task_id: str, metrics: Sequence[Tuple[str, str]],
cls,
task_id: str,
metrics: Sequence[Tuple[str, str]],
):
must = cls._task_conditions(task_id)
if metrics:
@@ -451,12 +486,96 @@ class EventMetrics:
return {"bool": {"must": must}}
def get_multi_task_metrics(self, companies: TaskCompanies, event_type: EventType) -> Mapping[str, list]:
    """
    For the requested tasks return the reported metrics and their variants

    The tasks are grouped per company and every company is queried on a
    worker thread. The per company results are merged into a single
    metric -> variants mapping. If only one company was queried then
    its result is returned as is.
    """
    per_company_ids = [
        (company_id, [task.id for task in company_tasks])
        for company_id, company_tasks in companies.items()
    ]
    fetch_company_metrics = partial(
        self._get_multi_task_metrics,
        event_type=event_type,
    )
    with ThreadPoolExecutor(EventSettings.max_workers) as executor:
        per_company: Sequence = list(
            executor.map(fetch_company_metrics, per_company_ids)
        )

    # single company: nothing to merge
    if len(per_company) == 1:
        return per_company[0]

    merged = defaultdict(set)
    for company_metrics in per_company:
        for metric, variants in company_metrics.items():
            merged[metric].update(variants)

    return {metric: list(variants) for metric, variants in merged.items()}
def _get_multi_task_metrics(
    self, company_tasks: Tuple[str, Sequence[str]], event_type: EventType
) -> Mapping[str, list]:
    """
    Return the metric -> variants mapping reported by the tasks of one company

    company_tasks is a (company id, task ids) pair. An empty mapping is
    returned when check_empty_data reports no data for the company and
    event type or when the search result carries no aggregations.
    """
    company_id, task_ids = company_tasks
    if check_empty_data(self.es, company_id, event_type):
        return {}

    search_args = {
        "es": self.es,
        "company_id": company_id,
        "event_type": event_type,
    }
    query = QueryBuilder.terms("task", task_ids)
    max_metrics, max_variants = get_max_metric_and_variant_counts(
        query=query,
        **search_args,
    )

    # variants are bucketed inside each metric bucket, both ordered by key
    variants_agg = {
        "terms": {
            "field": "variant",
            "size": max_variants,
            "order": {"_key": "asc"},
        },
    }
    metrics_agg = {
        "terms": {
            "field": "metric",
            "size": max_metrics,
            "order": {"_key": "asc"},
        },
        "aggs": {"variants": variants_agg},
    }
    es_req = {
        "size": 0,
        "query": query,
        "aggs": {"metrics": metrics_agg},
    }

    es_res = search_company_events(
        body=es_req,
        **search_args,
    )
    aggs_result = es_res.get("aggregations")
    if not aggs_result:
        return {}

    metric_variants = {}
    for metric_bucket in aggs_result["metrics"]["buckets"]:
        metric_variants[metric_bucket["key"]] = [
            variant_bucket["key"]
            for variant_bucket in metric_bucket["variants"]["buckets"]
        ]
    return metric_variants
def get_task_metrics(
self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence:
"""
For the requested tasks return all the metrics that
reported events of the requested types
For the requested tasks return reported metrics per task
"""
if check_empty_data(self.es, company_id, event_type):
return {}
@@ -495,5 +614,5 @@ class EventMetrics:
return [
metric["key"]
for metric in safe_get(es_res, "aggregations/metrics/buckets", default=[])
for metric in nested_get(es_res, ("aggregations", "metrics", "buckets"), default=[])
]

View File

@@ -64,13 +64,13 @@ class EventsIterator:
self,
event_type: EventType,
company_id: str,
task_id: str,
task_ids: Sequence[str],
metric_variants: MetricVariants = None,
) -> int:
if check_empty_data(self.es, company_id, event_type):
return 0
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
query, _ = self._get_initial_query_and_must(task_ids, metric_variants)
es_req = {
"query": query,
}
@@ -100,7 +100,7 @@ class EventsIterator:
For the last key-field value all the events are brought (even if the resulting size exceeds batch_size)
so that events with this value will not be lost between the calls.
"""
query, must = self._get_initial_query_and_must(task_id, metric_variants)
query, must = self._get_initial_query_and_must([task_id], metric_variants)
# retrieve the next batch of events
es_req = {
@@ -158,14 +158,14 @@ class EventsIterator:
@staticmethod
def _get_initial_query_and_must(
task_id: str, metric_variants: MetricVariants = None
task_ids: Sequence[str], metric_variants: MetricVariants = None
) -> Tuple[dict, list]:
if not metric_variants:
must = [{"term": {"task": task_id}}]
query = {"term": {"task": task_id}}
query = {"terms": {"task": task_ids}}
must = [query]
else:
must = [
{"term": {"task": task_id}},
{"terms": {"task": task_ids}},
get_metric_variants_condition(metric_variants),
]
query = {"bool": {"must": must}}

View File

@@ -183,7 +183,7 @@ class HistoryDebugImageIterator:
order = "desc" if navigate_earlier else "asc"
es_req = {
"size": 1,
"sort": [{"metric": order}, {"variant": order}],
"sort": [{"metric": order}, {"variant": order}, {"url": "desc"}],
"query": {"bool": {"must": must_conditions}},
}
@@ -242,7 +242,7 @@ class HistoryDebugImageIterator:
]
es_req = {
"size": 1,
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
"sort": [{"iter": order}, {"metric": order}, {"variant": order}, {"url": "desc"}],
"query": {"bool": {"must": must_conditions}},
}
es_res = search_company_events(
@@ -338,7 +338,7 @@ class HistoryDebugImageIterator:
es_req = {
"size": 1,
"sort": {"iter": "desc"},
"sort": [{"iter": "desc"}, {"url": "desc"}],
"query": {"bool": {"must": must_conditions}},
}

View File

@@ -6,7 +6,6 @@ from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping, Callable
import attr
import dpath
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
@@ -27,6 +26,7 @@ from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task
from apiserver.utilities.dicts import nested_get
class VariantState(Base):
@@ -86,7 +86,7 @@ class MetricEventsIterator:
task_id: company_id
for task_id, company_id in companies.items()
if not check_empty_data(
self.es, company_id=company_id, event_type=EventType.metrics_scalar
self.es, company_id=company_id, event_type=self.event_type
)
}
if not companies:
@@ -305,13 +305,13 @@ class MetricEventsIterator:
return [
MetricState(
metric=metric["key"],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
timestamp=nested_get(metric, ("last_event_timestamp", "value")),
variants=[
init_variant_state(variant)
for variant in dpath.get(metric, "variants/buckets")
for variant in nested_get(metric, ("variants", "buckets"))
],
)
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
for metric in nested_get(es_res, ("aggregations", "metrics", "buckets"))
]
@abc.abstractmethod
@@ -384,7 +384,8 @@ class MetricEventsIterator:
"aggs": {
"events": {
"top_hits": {
"sort": self._get_same_variant_events_order()
"sort": self._get_same_variant_events_order(),
"size": 1,
}
}
},
@@ -430,14 +431,14 @@ class MetricEventsIterator:
def get_iteration_events(it_: dict) -> Sequence:
return [
self._process_event(ev["_source"])
for m in dpath.get(it_, "metrics/buckets")
for v in dpath.get(m, "variants/buckets")
for ev in dpath.get(v, "events/hits/hits")
for m in nested_get(it_, ("metrics", "buckets"))
for v in nested_get(m, ("variants", "buckets"))
for ev in nested_get(v, ("events", "hits", "hits"))
if is_valid_event(ev["_source"])
]
iterations = []
for it in dpath.get(es_res, "aggregations/iters/buckets"):
for it in nested_get(es_res, ("aggregations", "iters", "buckets")):
events = get_iteration_events(it)
if events:
iterations.append({"iter": it["key"], "events": events})

View File

@@ -5,10 +5,11 @@ from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse
from apiserver.bll.task.utils import deleted_prefix
from apiserver.bll.task.utils import deleted_prefix, get_last_metric_updates
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.service_repo.auth import Identity
from .metadata import Metadata
@@ -28,11 +29,7 @@ class ModelBLL:
@staticmethod
def assert_exists(
company_id,
model_ids,
only=None,
allow_public=False,
return_models=True,
company_id, model_ids, only=None, allow_public=False, return_models=True,
) -> Optional[Sequence[Model]]:
model_ids = [model_ids] if isinstance(model_ids, str) else model_ids
ids = set(model_ids)
@@ -58,14 +55,15 @@ class ModelBLL:
cls,
model_id: str,
company_id: str,
user_id: str,
identity: Identity,
force_publish_task: bool = False,
publish_task_func: Callable[[str, str, str, bool], dict] = None,
publish_task_func: Callable[[str, str, Identity, bool], dict] = None,
) -> Tuple[int, ModelTaskPublishResponse]:
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
if model.ready:
raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
user_id = identity.user
published_task = None
if model.task and publish_task_func:
task = (
@@ -75,18 +73,25 @@ class ModelBLL:
)
if task and task.status != TaskStatus.published:
task_publish_res = publish_task_func(
model.task, company_id, user_id, force_publish_task
model.task, company_id, identity, force_publish_task
)
published_task = ModelTaskPublishResponse(
id=model.task, data=task_publish_res
)
updated = model.update(upsert=False, ready=True, last_update=datetime.utcnow())
now = datetime.utcnow()
updated = model.update(
upsert=False,
ready=True,
last_update=now,
last_change=now,
last_changed_by=user_id,
)
return updated, published_task
@classmethod
def delete_model(
cls, model_id: str, company_id: str, force: bool
cls, model_id: str, company_id: str, user_id: str, force: bool
) -> Tuple[int, Model]:
model = cls.get_company_model_by_id(
company_id=company_id,
@@ -112,49 +117,60 @@ class ModelBLL:
if model.task:
task = Task.objects(id=model.task).first()
if task and task.status == TaskStatus.published:
if not force:
raise errors.bad_request.ModelCreatingTaskExists(
"and published, use force=True to delete", task=model.task
)
if task.models.output and model_id in task.models.output:
now = datetime.utcnow()
if task:
now = datetime.utcnow()
if task.status == TaskStatus.published:
if not force:
raise errors.bad_request.ModelCreatingTaskExists(
"and published, use force=True to delete", task=model.task
)
Task._get_collection().update_one(
filter={"_id": model.task, "models.output.model": model_id},
update={
"$set": {
"models.output.$[elem].model": deleted_model_id,
"output.error": f"model deleted on {now.isoformat()}",
"last_change": now,
"last_changed_by": user_id,
},
"last_change": now,
},
array_filters=[{"elem.model": model_id}],
upsert=False,
)
else:
task.update(
pull__models__output__model=model_id,
set__last_change=now,
set__last_changed_by=user_id,
)
del_count = Model.objects(id=model_id, company=company_id).delete()
return del_count, model
@classmethod
def archive_model(cls, model_id: str, company_id: str):
def archive_model(cls, model_id: str, company_id: str, user_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
now = datetime.utcnow()
archived = Model.objects(company=company_id, id=model_id).update(
add_to_set__system_tags=EntityVisibility.archived.value,
last_update=datetime.utcnow(),
last_change=now,
last_changed_by=user_id,
)
return archived
@classmethod
def unarchive_model(cls, model_id: str, company_id: str):
def unarchive_model(cls, model_id: str, company_id: str, user_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
now = datetime.utcnow()
unarchived = Model.objects(company=company_id, id=model_id).update(
pull__system_tags=EntityVisibility.archived.value,
last_update=datetime.utcnow(),
last_change=now,
last_changed_by=user_id,
)
return unarchived
@@ -170,7 +186,7 @@ class ModelBLL:
[
{
"$match": {
"company": {"$in": [None, "", company]},
"company": {"$in": ["", company]},
"_id": {"$in": model_ids},
}
},
@@ -179,12 +195,43 @@ class ModelBLL:
"labels_count": {"$size": {"$objectToArray": "$labels"}}
}
},
{
"$project": {"labels_count": 1},
},
{"$project": {"labels_count": 1}},
]
)
return {
r.pop("_id"): r
for r in result
return {r.pop("_id"): r for r in result}
@staticmethod
def update_statistics(
    company_id: str,
    user_id: str,
    model_id: str,
    last_update: datetime = None,
    last_iteration_max: int = None,
    last_scalar_events: Dict[str, Dict[str, dict]] = None,
):
    """
    Update the change tracking fields and the last reported metrics of a model

    :param company_id: accepted for interface symmetry; the visible code
        selects the model by id only and does not use this value
    :param user_id: stored in the model 'last_changed_by' field
    :param model_id: id of the model to update
    :param last_update: stored in the model 'last_change' field.
        Defaults to the current UTC time
    :param last_iteration_max: if passed then sent as a 'max' update
        of the 'last_iteration' field
    :param last_scalar_events: if passed then handed to get_last_metric_updates
        which fills in the additional updates to apply
    :return: the result of the first update_one call
    """
    last_update = last_update or datetime.utcnow()
    updates = {
        # 'last_update' always gets the current time while
        # 'last_change' gets the (possibly caller provided) timestamp
        "last_update": datetime.utcnow(),
        "last_change": last_update,
        "last_changed_by": user_id,
    }
    if last_iteration_max is not None:
        updates.update(max__last_iteration=last_iteration_max)

    raw_updates = {}
    if last_scalar_events is not None:
        get_last_metric_updates(
            task_id=model_id,
            last_scalar_events=last_scalar_events,
            raw_updates=raw_updates,
            extra_updates=updates,
            model_events=True,
        )

    ret = Model.objects(id=model_id).update_one(**updates)
    # the raw '$set' update is issued separately and only if the first update matched
    if ret and raw_updates:
        Model.objects(id=model_id).update_one(__raw__=[{"$set": raw_updates}])

    return ret

View File

@@ -5,7 +5,6 @@ from mongoengine import Document
from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem
from apiserver.database.model.base import GetMixin
from apiserver.service_repo import APICall
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
@@ -87,13 +86,13 @@ class Metadata:
return paths
@classmethod
def escape_query_parameters(cls, call: APICall) -> dict:
if not call.data:
return call.data
def escape_query_parameters(cls, call_data: dict) -> dict:
if not call_data:
return call_data
keys = list(call.data)
keys = list(call_data)
call_data = {
safe_key: call.data[key]
safe_key: call_data[key]
for key, safe_key in zip(keys, Metadata.escape_paths(keys))
}

View File

@@ -1,8 +1,12 @@
from collections import defaultdict
from datetime import datetime
from enum import Enum
from typing import Sequence, Dict
from typing import Sequence, Dict, Type
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model.model import AttributedDocument
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
@@ -22,6 +26,56 @@ class OrgBLL:
self._task_tags = _TagsCache(Task, self.redis)
self._model_tags = _TagsCache(Model, self.redis)
def edit_entity_tags(
    self,
    company_id,
    user_id: str,
    entity_cls: Type[AttributedDocument],
    entity_ids: Sequence[str],
    add_tags: Sequence[str],
    remove_tags: Sequence[str],
) -> int:
    """
    Add and/or remove tags on the company tasks or models with the passed ids

    Every update also sets 'last_change' and 'last_changed_by' on the entities.
    If anything was updated then the entities projects get their time updated
    and the tags are updated for these projects through update_tags.

    :return: the sum of the update counts of the add and the remove operations
    :raise errors.bad_request.ValidationError: if entity_cls is not Task or Model,
        if no entity ids were passed or if neither add nor remove tags were passed
    """
    if entity_cls not in (Task, Model):
        raise errors.bad_request.ValidationError(
            "Tags editing can be called on tasks or models only"
        )
    if not entity_ids:
        raise errors.bad_request.ValidationError(
            "No entity ids provided for editing tags"
        )
    if not add_tags and not remove_tags:
        raise errors.bad_request.ValidationError(
            "Either add tags or remove tags should be provided"
        )

    # the same timestamp is used for both the add and the remove updates
    change_info = {
        "set__last_change": datetime.utcnow(),
        "set__last_changed_by": user_id,
    }

    def company_entities():
        return entity_cls.objects(company=company_id, id__in=entity_ids)

    updated = 0
    tag_operations = (
        ("add_to_set__tags", add_tags),
        ("pull_all__tags", remove_tags),
    )
    for operation, operation_tags in tag_operations:
        if operation_tags:
            updated += company_entities().update(
                **{operation: operation_tags}, **change_info,
            )

    if not updated:
        return 0

    projects = company_entities().distinct("project")
    update_project_time(project_ids=projects)
    tags_entity = Tags.Task if entity_cls is Task else Tags.Model
    self.update_tags(
        company_id,
        entity=tags_entity,
        projects=projects,
        tags=add_tags or remove_tags,
    )
    return updated
def get_tags(
self,
company_id: str,
@@ -50,10 +104,10 @@ class OrgBLL:
return ret
def update_tags(
self, company_id: str, entity: Tags, project: str, tags=None, system_tags=None,
self, company_id: str, entity: Tags, projects: Sequence[str], tags=None, system_tags=None,
):
tags_cache = self._get_tags_cache_for_entity(entity)
tags_cache.update_tags(company_id, project, tags, system_tags)
tags_cache.update_tags(company_id, projects, tags, system_tags)
def reset_tags(self, company_id: str, entity: Tags, projects: Sequence[str]):
tags_cache = self._get_tags_cache_for_entity(entity)

View File

@@ -6,7 +6,6 @@ from redis import Redis
from apiserver.config_repo import config
from apiserver.bll.project import project_ids_with_children
from apiserver.database.model import EntityVisibility
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
@@ -43,8 +42,8 @@ class _TagsCache:
query &= GetMixin.get_list_field_query(name, vals)
if project:
query &= Q(project__in=project_ids_with_children([project]))
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
# else:
# query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
return self.db_cls.objects(query).distinct(field)
@@ -107,7 +106,7 @@ class _TagsCache:
return ret
def update_tags(self, company_id: str, project: str, tags=None, system_tags=None):
def update_tags(self, company_id: str, projects: Sequence[str], tags=None, system_tags=None):
"""
Updates tags. If reset is set then both tags and system_tags
are recalculated. Otherwise only those that are not 'None'
@@ -123,7 +122,7 @@ class _TagsCache:
if not fields:
return
self._delete_redis_keys(company_id, projects=[project], fields=fields)
self._delete_redis_keys(company_id, projects=projects, fields=fields)
def reset_tags(self, company_id: str, projects: Sequence[str]):
self._delete_redis_keys(

View File

@@ -1,8 +1,7 @@
import itertools
from collections import defaultdict
from datetime import datetime, timedelta
from functools import reduce
from itertools import groupby
from itertools import groupby, chain
from operator import itemgetter
from typing import (
Sequence,
@@ -15,15 +14,16 @@ from typing import (
Callable,
Mapping,
Any,
Union,
)
from boltons.iterutils import partition
from mongoengine import Q, Document
from apiserver import database
from apiserver.apierrors import errors
from apiserver.apimodels.projects import ProjectChildrenType
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility, AttributedDocument
from apiserver.database.model import EntityVisibility, AttributedDocument, User
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
@@ -40,16 +40,26 @@ from .sub_projects import (
_ids_with_children,
_ids_with_parents,
_get_project_depth,
ProjectsChildren,
_get_writable_project_from_name,
)
log = config.logger(__file__)
max_depth = config.get("services.projects.sub_projects.max_depth", 10)
reports_project_name = ".reports"
datasets_project_name = ".datasets"
pipelines_project_name = ".pipelines"
reports_tag = "reports"
dataset_tag = "dataset"
pipeline_tag = "pipeline"
class ProjectBLL:
child_classes = (Task, Model)
@classmethod
def merge_project(
cls, company, source_id: str, destination_id: str
cls, company: str, source_id: str, destination_id: str
) -> Tuple[int, int, Set[str]]:
"""
Move all the tasks and sub projects from the source project to the destination
@@ -81,7 +91,7 @@ class ProjectBLL:
)
moved_entities = 0
for entity_type in (Task, Model):
for entity_type in cls.child_classes:
moved_entities += entity_type.objects(
company=company,
project=source_id,
@@ -160,6 +170,7 @@ class ProjectBLL:
now = datetime.utcnow()
affected = set()
p: Project
for p in filter(None, (old_parent, new_parent)):
p.update(last_update=now)
affected.update({p.id, *(p.path or [])})
@@ -174,6 +185,7 @@ class ProjectBLL:
new_name = fields.pop("name", None)
if new_name:
# noinspection PyTypeChecker
new_name, new_location = _validate_project_name(new_name)
old_name, old_location = _validate_project_name(project.name)
if new_location != old_location:
@@ -216,6 +228,18 @@ class ProjectBLL:
raise errors.bad_request.ProjectPathExceedsMax(max_depth=max_depth)
name, location = _validate_project_name(name)
existing = _get_writable_project_from_name(
company=company,
name=name,
)
if existing:
raise errors.bad_request.ExpectedUniqueData(
replacement_msg="Project with the same name already exists",
name=name,
company=company,
)
now = datetime.utcnow()
project = Project(
id=database.utils.id(),
@@ -306,11 +330,12 @@ class ProjectBLL:
description="",
)
extra = (
{"set__last_change": datetime.utcnow()}
if hasattr(entity_cls, "last_change")
else {}
)
extra = {}
if hasattr(entity_cls, "last_change"):
extra["set__last_change"] = datetime.utcnow()
if hasattr(entity_cls, "last_changed_by"):
extra["set__last_changed_by"] = user
entity_cls.objects(company=company, id__in=ids).update(
set__project=project, **extra
)
@@ -331,6 +356,17 @@ class ProjectBLL:
) -> Tuple[Sequence, Sequence]:
archived = EntityVisibility.archived.value
def project_task_fields():
    """
    Return a '$project' pipeline stage that keeps only the listed task fields
    """
    kept_fields = ("project", "status", "system_tags", "started", "completed")
    return {"$project": {field: 1 for field in kept_fields}}
def ensure_valid_fields():
"""
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
@@ -358,6 +394,7 @@ class ProjectBLL:
users=users,
)
},
project_task_fields(),
ensure_valid_fields(),
{
"$group": {
@@ -396,6 +433,18 @@ class ProjectBLL:
"$completed",
{"$gt": ["$completed", time_thresh]},
additional_cond,
{
"$not": {
"$in": [
"$status",
[
TaskStatus.queued,
TaskStatus.in_progress,
TaskStatus.failed,
],
]
}
},
]
},
"then": 1,
@@ -494,6 +543,7 @@ class ProjectBLL:
users=users,
)
},
project_task_fields(),
ensure_valid_fields(),
{
# for each project
@@ -509,7 +559,7 @@ class ProjectBLL:
def aggregate_project_data(
func: Callable[[T, T], T],
project_ids: Sequence[str],
child_projects: Mapping[str, Sequence[Project]],
child_projects: ProjectsChildren,
data: Mapping[str, T],
) -> Dict[str, T]:
"""
@@ -529,7 +579,10 @@ class ProjectBLL:
@classmethod
def get_dataset_stats(
cls, company: str, project_ids: Sequence[str], users: Sequence[str] = None,
cls,
company: str,
project_ids: Sequence[str],
users: Sequence[str] = None,
) -> Dict[str, dict]:
if not project_ids:
return {}
@@ -561,6 +614,140 @@ class ProjectBLL:
for r in Task.aggregate(task_runtime_pipeline)
}
@staticmethod
def _get_projects_children(
    project_ids: Sequence[str],
    search_hidden: bool,
    allowed_ids: Sequence[str],
) -> Tuple[ProjectsChildren, Set[str]]:
    """
    Return the sub projects of the passed projects (loaded with id and name only)
    together with the set of all the sub project ids
    """
    children_by_project = _get_sub_projects(
        project_ids,
        _only=("id", "name"),
        search_hidden=search_hidden,
        allowed_ids=allowed_ids,
    )
    children_ids = set()
    for children in children_by_project.values():
        children_ids.update(child.id for child in children)

    return children_by_project, children_ids
@staticmethod
def _get_children_info(
project_ids: Sequence[str], child_projects: ProjectsChildren
) -> dict:
return {
project: sorted(
[{"id": c.id, "name": c.name} for c in child_projects.get(project, [])],
key=itemgetter("name"),
)
for project in project_ids
}
@classmethod
def _get_project_dataset_stats_core(
    cls,
    company: str,
    project_ids: Sequence[str],
    project_field: str,
    entity_class: Type[AttributedDocument],
    include_children: bool = True,
    filter_: Mapping[str, Any] = None,
    users: Sequence[str] = None,
    selected_project_ids: Sequence[str] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
    """
    Count the entities of entity_class per project and collect their tags

    :param company: company id passed to the match conditions
    :param project_ids: the projects to return the stats for
    :param project_field: the entity field that references the project.
        Used for both matching and grouping
    :param entity_class: the document class that the aggregation runs on
    :param include_children: if set then the entities of the sub projects
        (including hidden ones) are added to the stats of the passed projects
    :param filter_: additional filter passed to the match conditions
    :param users: if passed then passed to the match conditions
    :param selected_project_ids: passed as the allowed ids when
        collecting the sub projects
    :return: a tuple of
        - project id -> {"datasets": {"count": int, "tags": sorted list}}
        - project id -> children info as returned by _get_children_info
    """
    if not project_ids:
        return {}, {}

    child_projects = {}
    project_ids_with_children = set(project_ids)
    if include_children:
        child_projects, children_ids = cls._get_projects_children(
            project_ids,
            search_hidden=True,
            allowed_ids=selected_project_ids,
        )
        project_ids_with_children |= children_ids

    pipeline = [
        {
            "$match": cls.get_match_conditions(
                company=company,
                project_ids=list(project_ids_with_children),
                filter_=filter_,
                users=users,
                project_field=project_field,
            )
        },
        {"$project": {project_field: 1, "tags": 1}},
        {
            "$group": {
                "_id": f"${project_field}",
                "count": {"$sum": 1},
                "tags": {"$push": "$tags"},
            }
        },
    ]
    res = entity_class.aggregate(pipeline)

    # each group holds a list of tag lists: flatten it into a set of unique tags
    project_stats = {
        result["_id"]: {
            "count": result.get("count", 0),
            "tags": set(chain.from_iterable(result.get("tags", []))),
        }
        for result in res
    }

    def concat_dataset_stats(a: dict, b: dict) -> dict:
        # tags are sets here so the fallback must be a set too:
        # 'set | dict' raises TypeError
        return {
            "count": a.get("count", 0) + b.get("count", 0),
            "tags": a.get("tags", set()) | b.get("tags", set()),
        }

    top_project_stats = cls.aggregate_project_data(
        func=concat_dataset_stats,
        project_ids=project_ids,
        child_projects=child_projects,
        data=project_stats,
    )
    # convert the tag sets into sorted lists for the response
    for stat in top_project_stats.values():
        stat["tags"] = sorted(stat.get("tags", ()))

    empty_stats = {"count": 0, "tags": []}
    stats = {
        project: {"datasets": top_project_stats.get(project, empty_stats)}
        for project in project_ids
    }
    return stats, cls._get_children_info(project_ids, child_projects)
@classmethod
def get_project_dataset_stats(
    cls,
    company: str,
    project_ids: Sequence[str],
    include_children: bool = True,
    filter_: Mapping[str, Any] = None,
    users: Sequence[str] = None,
    selected_project_ids: Sequence[str] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
    """
    Return the dataset stats for the passed projects

    The stats are calculated on Project documents grouped by their 'parent'
    field. The dataset system tag is always added to the 'system_tags' filter;
    a 'system_tags' filter that is not a list is replaced.

    :param filter_: optional additional filter. It is copied,
        the caller's mapping and its 'system_tags' list are not modified
    :return: the (stats, children) tuple of _get_project_dataset_stats_core
    """
    # work on copies so that the dataset tag is not injected into the caller's filter
    filter_ = dict(filter_ or {})
    filter_system_tags = filter_.get("system_tags")
    filter_system_tags = (
        list(filter_system_tags) if isinstance(filter_system_tags, list) else []
    )
    if dataset_tag not in filter_system_tags:
        filter_system_tags.append(dataset_tag)
    filter_["system_tags"] = filter_system_tags

    return cls._get_project_dataset_stats_core(
        company=company,
        project_ids=project_ids,
        project_field="parent",
        entity_class=Project,
        include_children=include_children,
        filter_=filter_,
        users=users,
        selected_project_ids=selected_project_ids,
    )
@classmethod
def get_project_stats(
cls,
@@ -571,24 +758,21 @@ class ProjectBLL:
search_hidden: bool = False,
filter_: Mapping[str, Any] = None,
users: Sequence[str] = None,
user_active_project_ids: Sequence[str] = None,
selected_project_ids: Sequence[str] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
if not project_ids:
return {}, {}
child_projects = (
_get_sub_projects(
child_projects = {}
project_ids_with_children = set(project_ids)
if include_children:
child_projects, children_ids = cls._get_projects_children(
project_ids,
_only=("id", "name"),
search_hidden=search_hidden,
allowed_ids=user_active_project_ids,
allowed_ids=selected_project_ids,
)
if include_children
else {}
)
project_ids_with_children = set(project_ids) | {
c.id for c in itertools.chain.from_iterable(child_projects.values())
}
project_ids_with_children |= children_ids
status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines(
company,
project_ids=list(project_ids_with_children),
@@ -641,7 +825,7 @@ class ProjectBLL:
}
def sum_runtime(
a: Mapping[str, Mapping], b: Mapping[str, Mapping]
a: Mapping[str, dict], b: Mapping[str, dict]
) -> Dict[str, dict]:
return {
section: a.get(section, 0) + b.get(section, 0)
@@ -692,14 +876,7 @@ class ProjectBLL:
for project in project_ids
}
children = {
project: sorted(
[{"id": c.id, "name": c.name} for c in child_projects.get(project, [])],
key=itemgetter("name"),
)
for project in project_ids
}
return stats, children
return stats, cls._get_children_info(project_ids, child_projects)
@classmethod
def get_active_users(
@@ -707,7 +884,7 @@ class ProjectBLL:
company,
project_ids: Sequence[str],
user_ids: Optional[Sequence[str]] = None,
) -> Set[str]:
) -> Set[Union[str, type(None)]]:
"""
Get the set of user ids that created tasks/models in the given projects
If project_ids is empty then all projects are examined
@@ -724,7 +901,7 @@ class ProjectBLL:
projects_query &= Q(id__in=project_ids)
res = set(Project.objects(projects_query).distinct(field="user"))
for cls_ in (Task, Model):
for cls_ in cls.child_classes:
res |= set(cls_.objects(query).distinct(field="user"))
return res
@@ -753,46 +930,115 @@ class ProjectBLL:
return tags, system_tags
@classmethod
def get_projects_with_active_user(
def get_projects_with_selected_children(
cls,
company: str,
users: Sequence[str],
users: Sequence[str] = None,
project_ids: Optional[Sequence[str]] = None,
allow_public: bool = True,
children_type: ProjectChildrenType = None,
children_tags: Sequence[str] = None,
children_tags_filter: dict = None,
) -> Tuple[Sequence[str], Sequence[str]]:
"""
Get the projects ids where user created any tasks including all the parents of these projects
Get the projects ids matching children_condition (if passed) or where the passed user created any tasks
including all the parents of these projects
If project ids are specified then filter the results by these project ids
"""
query = Q(user__in=users)
if not (users or children_type):
raise errors.bad_request.ValidationError(
"Either active users or children_condition should be specified"
)
if allow_public:
query &= get_company_or_none_constraint(company)
query = (
get_company_or_none_constraint(company)
if allow_public
else Q(company=company)
)
if users:
query &= Q(user__in=users)
project_query = None
if children_tags_filter:
child_query = query & GetMixin.get_list_filter_query(
"tags", children_tags_filter
)
elif children_tags:
child_query = query & GetMixin.get_list_field_query("tags", children_tags)
else:
query &= Q(company=company)
child_query = query
if children_type == ProjectChildrenType.dataset:
child_queries = {
Project: child_query
& Q(system_tags__in=[dataset_tag], basename__ne=datasets_project_name)
}
elif children_type == ProjectChildrenType.pipeline:
child_queries = {
Project: child_query
& Q(system_tags__in=[pipeline_tag], basename__ne=pipelines_project_name)
}
elif children_type == ProjectChildrenType.report:
child_queries = {Task: child_query & Q(system_tags__in=[reports_tag])}
else:
project_query = query
child_queries = {entity_cls: query for entity_cls in cls.child_classes}
user_projects_query = query
if project_ids:
ids_with_children = _ids_with_children(project_ids)
query &= Q(project__in=ids_with_children)
user_projects_query &= Q(id__in=ids_with_children)
if project_query:
project_query &= Q(id__in=ids_with_children)
for child_cls in child_queries:
child_queries[child_cls] &= (
Q(parent__in=ids_with_children)
if child_cls is Project
else Q(project__in=ids_with_children)
)
res = {p.id for p in Project.objects(user_projects_query).only("id")}
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="project"))
res = (
set(Project.objects(project_query).scalar("id")) if project_query else set()
)
for cls_, query_ in child_queries.items():
res |= set(
cls_.objects(query_).distinct(
field="id" if cls_ is Project else "project"
)
)
res = list(res)
if not res:
return res, res
user_active_project_ids = _ids_with_parents(res)
selected_project_ids = _ids_with_parents(res)
filtered_ids = (
list(set(user_active_project_ids) & set(project_ids))
list(set(selected_project_ids) & set(project_ids))
if project_ids
else list(user_active_project_ids)
else list(selected_project_ids)
)
return filtered_ids, user_active_project_ids
return filtered_ids, selected_project_ids
@staticmethod
def _get_project_query(
    company: str,
    projects: Sequence,
    include_subprojects: bool = True,
    state: Optional[EntityVisibility] = None,
) -> Q:
    """
    Build the query for the entities of the company (as constrained by
    get_company_or_none_constraint), optionally limited to the passed
    projects and to the archived or active state.
    If no projects are passed then no project condition is added
    (hidden entities are not filtered out here)
    """
    query = get_company_or_none_constraint(company)
    if projects:
        project_ids = (
            _ids_with_children(projects) if include_subprojects else projects
        )
        query &= Q(project__in=project_ids)

    archived_tag = EntityVisibility.archived.value
    if state == EntityVisibility.archived:
        query &= Q(system_tags__in=[archived_tag])
    elif state == EntityVisibility.active:
        query &= Q(system_tags__nin=[archived_tag])

    return query
@classmethod
def get_task_parents(
@@ -801,49 +1047,54 @@ class ProjectBLL:
projects: Sequence[str],
include_subprojects: bool,
state: Optional[EntityVisibility] = None,
name: str = None,
) -> Sequence[dict]:
"""
Get list of unique parent tasks sorted by task name for the passed company projects
If projects is None or empty then get parents for all the company tasks
"""
query = Q(company=company_id)
if projects:
if include_subprojects:
projects = _ids_with_children(projects)
query &= Q(project__in=projects)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
elif state == EntityVisibility.active:
query &= Q(system_tags__nin=[EntityVisibility.archived.value])
query = cls._get_project_query(
company_id, projects, include_subprojects=include_subprojects, state=state
)
parent_ids = set(Task.objects(query).distinct("parent"))
if not parent_ids:
return []
parents = Task.get_many_with_join(
parents: Sequence[dict] = Task.get_many_with_join(
company_id,
query=Q(id__in=parent_ids),
query_dict={"name": name} if name else None,
allow_public=True,
override_projection=("id", "name", "project.name"),
)
return sorted(parents, key=itemgetter("name"))
@classmethod
def get_entity_users(
    cls,
    company: str,
    entity_cls: Type[Union[Task, Model]],
    projects: Sequence[str],
    include_subprojects: bool,
) -> Sequence[dict]:
    """
    Get the users referenced by the entities (tasks or models)
    that match the project query for the passed company and projects

    :return: list of {"id": ..., "name": ...} dictionaries, one per found user.
        Empty list if the matching entities reference no users
    """
    scope = cls._get_project_query(
        company, projects, include_subprojects=include_subprojects
    )
    distinct_user_ids = entity_cls.objects(scope).distinct(field="user")
    if distinct_user_ids:
        # load only the fields that are returned to the caller
        found = User.objects(id__in=distinct_user_ids).only("id", "name")
        return [{"id": user.id, "name": user.name} for user in found]

    return []
@classmethod
def get_task_types(cls, company, project_ids: Optional[Sequence]) -> set:
"""
Return the list of unique task types used by company and public tasks
If project ids passed then only tasks from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
query = cls._get_project_query(company, project_ids)
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
@@ -853,10 +1104,7 @@ class ProjectBLL:
Return the list of unique frameworks used by company and public models
If project ids passed then only models from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
query = cls._get_project_query(company, project_ids)
return Model.objects(query).distinct(field="framework")
@staticmethod
@@ -865,10 +1113,11 @@ class ProjectBLL:
project_ids: Sequence[str],
filter_: Mapping[str, Any],
users: Sequence[str],
project_field: str = "project",
):
conditions = {
"company": {"$in": [None, "", company]},
"project": {"$in": project_ids},
"company": {"$in": ["", company]},
project_field: {"$in": project_ids},
}
if users:
conditions["user"] = {"$in": users}
@@ -876,29 +1125,125 @@ class ProjectBLL:
if not filter_:
return conditions
or_conditions = []
for field, field_filter in filter_.items():
if not (
field_filter
and isinstance(field_filter, list)
and all(isinstance(t, str) for t in field_filter)
):
if not (field_filter and isinstance(field_filter, (list, dict))):
raise errors.bad_request.ValidationError(
f"List of strings expected for the field: {field}"
f"Non empty list or dictionary expected for the field: {field}"
)
exclude, include = partition(field_filter, lambda x: x.startswith("-"))
conditions[field] = {
**({"$in": include} if include else {}),
**({"$nin": [e[1:] for e in exclude]} if exclude else {}),
}
if isinstance(field_filter, list):
if not all(isinstance(t, str) for t in field_filter):
raise errors.bad_request.ValidationError(
f"Only string values are allowed in the list filter: {field}"
)
helper = GetMixin.NewListFieldBucketHelper(
field, data=field_filter, legacy=True
)
op = helper.global_operator
db_query = {op: helper.actions}
else:
helper = GetMixin.ListQueryFilter.from_data(field, field_filter)
db_query = helper.db_query
for op, actions in db_query.items():
field_conditions = {}
for action, values in actions.items():
value = list(set(values)) if isinstance(values, list) else values
for key in reversed(action.split("__")):
value = {f"${key}": value}
field_conditions.update(value)
if op == Q.OR and len(field_conditions) > 1:
or_conditions.append(
{
"$or": [
{field: {db_modifier: cond}}
for db_modifier, cond in field_conditions.items()
]
}
)
else:
conditions[field] = field_conditions
if or_conditions:
if len(or_conditions) == 1:
conditions.update(next(iter(or_conditions)))
else:
conditions["$and"] = [c for c in or_conditions]
return conditions
@classmethod
def _calc_own_datasets_core(
    cls,
    company: str,
    project_ids: Sequence[str],
    project_field: str,
    entity_class: Type[AttributedDocument],
    filter_: Mapping[str, Any] = None,
    users: Sequence[str] = None,
) -> Dict[str, dict]:
    """
    Count, per requested project, the entity_class documents
    whose project_field references that project

    :param project_field: name of the document field holding the project id.
        It is used both for matching and for grouping
    :param entity_class: the document class to run the aggregation on
    :param filter_: optional field filters passed to get_match_conditions
    :param users: optional user ids passed to get_match_conditions
    :return: {project id: {"own_datasets": count}} for every requested project.
        Projects without matching documents get 0. Empty dict if no
        project ids were passed
    """
    if not project_ids:
        return {}

    match_stage = {
        "$match": cls.get_match_conditions(
            company=company,
            project_ids=project_ids,
            filter_=filter_,
            users=users,
            project_field=project_field,
        )
    }
    pipeline = [
        match_stage,
        {"$project": {project_field: 1}},
        {"$group": {"_id": f"${project_field}", "count": {"$sum": 1}}},
    ]

    counts = {}
    for row in entity_class.aggregate(pipeline):
        counts[row["_id"]] = row["count"]

    result = {}
    for pid in project_ids:
        result[pid] = {"own_datasets": counts.get(pid, 0)}
    return result
@classmethod
def calc_own_datasets(
    cls,
    company: str,
    project_ids: Sequence[str],
    filter_: Mapping[str, Any] = None,
    users: Sequence[str] = None,
) -> Dict[str, dict]:
    """
    Returns the amount of datasets per requested project

    The counting is done on the Project documents whose "parent" is one of
    the requested projects and that pass the filter with the dataset system
    tag added to it

    :param filter_: optional field filters. The passed mapping and its
        "system_tags" list are not modified
    :param users: optional user ids to limit the counted documents to
    :return: {project id: {"own_datasets": count}} for every requested project
    """
    # Work on copies so that neither the caller's mapping (declared as
    # a read-only Mapping) nor its system_tags list are changed by this call
    effective_filter = dict(filter_ or {})
    system_tags = effective_filter.get("system_tags")
    # a system_tags filter that is not a list is replaced by a plain list
    system_tags = list(system_tags) if isinstance(system_tags, list) else []
    if dataset_tag not in system_tags:
        system_tags.append(dataset_tag)
    effective_filter["system_tags"] = system_tags

    return cls._calc_own_datasets_core(
        company=company,
        project_ids=project_ids,
        project_field="parent",
        entity_class=Project,
        filter_=effective_filter,
        users=users,
    )
@classmethod
def calc_own_contents(
cls,
company: str,
project_ids: Sequence[str],
filter_: Mapping[str, Any] = None,
specific_state: Optional[EntityVisibility] = None,
users: Sequence[str] = None,
) -> Dict[str, dict]:
"""
@@ -909,6 +1254,20 @@ class ProjectBLL:
if not project_ids:
return {}
if specific_state:
filter_ = filter_ or {}
system_tags_filter = filter_.get("system_tags", [])
archived = EntityVisibility.archived.value
non_archived = f"-{EntityVisibility.archived.value}"
if not any(t in system_tags_filter for t in (archived, non_archived)):
filter_ = {k: v for k, v in filter_.items()}
filter_["system_tags"] = [
archived
if specific_state == EntityVisibility.archived
else non_archived,
*system_tags_filter,
]
pipeline = [
{
"$match": cls.get_match_conditions(

View File

@@ -1,25 +1,34 @@
from collections import defaultdict
from datetime import datetime
from typing import Tuple, Set, Sequence
import attr
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.bll.event import EventBLL
from apiserver.bll.task.task_cleanup import (
collect_debug_image_urls,
collect_plot_image_urls,
TaskUrls,
_schedule_for_delete,
schedule_for_delete,
delete_task_events_and_collect_urls,
)
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType, TaskStatus
from .project_bll import (
ProjectBLL,
pipeline_tag,
pipelines_project_name,
dataset_tag,
datasets_project_name,
reports_tag,
)
from .sub_projects import _ids_with_children
log = config.logger(__file__)
event_bll = EventBLL()
async_events_delete = config.get("services.tasks.async_events_delete", False)
@attr.s(auto_attribs=True)
@@ -31,30 +40,83 @@ class DeleteProjectResult:
urls: TaskUrls = None
def _get_child_project_ids(
    project_id: str,
) -> Tuple[Sequence[str], Sequence[str], Sequence[str]]:
    """
    Collect the ids of the project together with all its subprojects and
    pick out of them the pipeline and the dataset projects

    :return: a tuple of
        - ids of the project and all of its children
        - ids out of these that carry the pipeline system tag and whose
          basename is not the pipelines root project name
        - ids out of these that carry the dataset system tag and whose
          basename is not the datasets root project name
    """
    all_ids = _ids_with_children([project_id])

    def ids_tagged(tag: str, excluded_basename: str) -> Sequence[str]:
        matching = Project.objects(
            id__in=all_ids,
            system_tags__in=[tag],
            basename__ne=excluded_basename,
        )
        return list(matching.scalar("id"))

    pipelines = ids_tagged(pipeline_tag, pipelines_project_name)
    datasets = ids_tagged(dataset_tag, datasets_project_name)
    return all_ids, pipelines, datasets
def validate_project_delete(company: str, project_id: str):
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path", "system_tags")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
is_pipeline = "pipeline" in (project.system_tags or [])
project_ids = _ids_with_children([project_id])
project_ids, pipeline_ids, dataset_ids = _get_child_project_ids(project_id)
ret = {}
for cls in (Task, Model):
ret[f"{cls.__name__.lower()}s"] = cls.objects(project__in=project_ids).count()
for cls in (Task, Model):
query = dict(
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
)
name = f"non_archived_{cls.__name__.lower()}s"
if not is_pipeline:
ret[name] = cls.objects(**query).count()
else:
ret[name] = (
cls.objects(**query, type=TaskType.controller).count()
if cls == Task
else 0
if pipeline_ids:
pipelines_with_active_controllers = Task.objects(
project__in=pipeline_ids,
type=TaskType.controller,
system_tags__nin=[EntityVisibility.archived.value],
).distinct("project")
ret["pipelines"] = len(pipelines_with_active_controllers)
else:
ret["pipelines"] = 0
if dataset_ids:
datasets_with_data = Task.objects(
project__in=dataset_ids,
system_tags__nin=[EntityVisibility.archived.value],
).distinct("project")
ret["datasets"] = len(datasets_with_data)
else:
ret["datasets"] = 0
project_ids = list(set(project_ids) - set(pipeline_ids) - set(dataset_ids))
if project_ids:
in_project_query = Q(project__in=project_ids)
for cls in (Task, Model):
query = (
in_project_query & Q(system_tags__nin=[reports_tag])
if cls is Task
else in_project_query
)
ret[f"{cls.__name__.lower()}s"] = cls.objects(query).count()
ret[f"non_archived_{cls.__name__.lower()}s"] = cls.objects(
query & Q(system_tags__nin=[EntityVisibility.archived.value])
).count()
ret["reports"] = Task.objects(
in_project_query & Q(system_tags__in=[reports_tag])
).count()
ret["non_archived_reports"] = Task.objects(
in_project_query
& Q(
system_tags__in=[reports_tag],
system_tags__nin=[EntityVisibility.archived.value],
)
).count()
else:
for cls in (Task, Model):
ret[f"{cls.__name__.lower()}s"] = 0
ret[f"non_archived_{cls.__name__.lower()}s"] = 0
ret["reports"] = 0
ret["non_archived_reports"] = 0
return ret
@@ -65,7 +127,7 @@ def delete_project(
project_id: str,
force: bool,
delete_contents: bool,
delete_external_artifacts=True,
delete_external_artifacts: bool,
) -> Tuple[DeleteProjectResult, Set[str]]:
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path", "system_tags")
@@ -74,43 +136,62 @@ def delete_project(
raise errors.bad_request.InvalidProjectId(id=project_id)
delete_external_artifacts = delete_external_artifacts and config.get(
"services.async_urls_delete.enabled", False
"services.async_urls_delete.enabled", True
)
is_pipeline = "pipeline" in (project.system_tags or [])
project_ids = _ids_with_children([project_id])
project_ids, pipeline_ids, dataset_ids = _get_child_project_ids(project_id)
if not force:
query = dict(
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
)
if not is_pipeline:
if pipeline_ids:
active_controllers = Task.objects(
project__in=pipeline_ids,
type=TaskType.controller,
system_tags__nin=[EntityVisibility.archived.value],
).only("id")
if active_controllers:
raise errors.bad_request.ProjectHasPipelines(
"please archive all the controllers or use force=true",
id=project_id,
)
if dataset_ids:
datasets_with_data = Task.objects(
project__in=dataset_ids,
system_tags__nin=[EntityVisibility.archived.value],
).only("id")
if datasets_with_data:
raise errors.bad_request.ProjectHasDatasets(
"please delete all the dataset versions or use force=true",
id=project_id,
)
regular_projects = list(set(project_ids) - set(pipeline_ids) - set(dataset_ids))
if regular_projects:
for cls, error in (
(Task, errors.bad_request.ProjectHasTasks),
(Model, errors.bad_request.ProjectHasModels),
):
non_archived = cls.objects(**query).only("id")
non_archived = cls.objects(
project__in=regular_projects,
system_tags__nin=[EntityVisibility.archived.value],
).only("id")
if non_archived:
raise error("use force=true to delete", id=project_id)
else:
non_archived = Task.objects(**query, type=TaskType.controller).only("id")
if non_archived:
raise errors.bad_request.ProjectHasTasks(
"please archive all the runs inside the project", id=project_id
)
raise error("use force=true", id=project_id)
if not delete_contents:
for cls in (Model, Task):
updated_count = cls.objects(project__in=project_ids).update(project=None)
res = DeleteProjectResult(disassociated_tasks=updated_count)
disassociated = defaultdict(int)
for cls in ProjectBLL.child_classes:
disassociated[cls] = cls.objects(project__in=project_ids).update(
project=None
)
res = DeleteProjectResult(disassociated_tasks=disassociated[Task])
else:
deleted_models, model_event_urls, model_urls = _delete_models(
company=company, projects=project_ids
company=company, user=user, projects=project_ids
)
deleted_tasks, task_event_urls, artifact_urls = _delete_tasks(
company=company, projects=project_ids
company=company, user=user, projects=project_ids
)
event_urls = task_event_urls | model_event_urls
if delete_external_artifacts:
scheduled = _schedule_for_delete(
scheduled = schedule_for_delete(
task_id=project_id,
company=company,
user=user,
@@ -124,7 +205,6 @@ def delete_project(
deleted_models=deleted_models,
urls=TaskUrls(
model_urls=list(model_urls),
event_urls=list(event_urls),
artifact_urls=list(artifact_urls),
),
)
@@ -135,7 +215,9 @@ def delete_project(
return res, affected
def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]:
def _delete_tasks(
company: str, user: str, projects: Sequence[str]
) -> Tuple[int, Set, Set]:
"""
Delete only the task themselves and their non published version.
Child models under the same project are deleted separately.
@@ -146,14 +228,21 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
if not tasks:
return 0, set(), set()
task_ids = {t.id for t in tasks}
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
task_ids = list({t.id for t in tasks})
now = datetime.utcnow()
Task.objects(parent__in=task_ids, project__nin=projects).update(
parent=None,
last_change=now,
last_changed_by=user,
)
Model.objects(task__in=task_ids, project__nin=projects).update(
task=None,
last_change=now,
last_changed_by=user,
)
event_urls, artifact_urls = set(), set()
artifact_urls = set()
for task in tasks:
event_urls.update(collect_debug_image_urls(company, task.id))
event_urls.update(collect_plot_image_urls(company, task.id))
if task.execution and task.execution.artifacts:
artifact_urls.update(
{
@@ -163,15 +252,16 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
}
)
event_bll.delete_multi_task_events(
company, list(task_ids), async_delete=async_events_delete
event_urls = delete_task_events_and_collect_urls(
company=company, task_ids=task_ids, wait_for_delete=False
)
deleted = tasks.delete()
return deleted, event_urls, artifact_urls
def _delete_models(
company: str, projects: Sequence[str]
company: str, user: str, projects: Sequence[str]
) -> Tuple[int, Set[str], Set[str]]:
"""
Delete project models and update the tasks from other projects
@@ -182,39 +272,53 @@ def _delete_models(
return 0, set(), set()
model_ids = list({m.id for m in models})
deleted = "__DELETED__"
Task._get_collection().update_many(
filter={
"project": {"$nin": projects},
"models.input.model": {"$in": model_ids},
},
update={"$set": {"models.input.$[elem].model": None}},
update={"$set": {"models.input.$[elem].model": deleted}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
model_tasks = list({m.task for m in models if m.task})
if model_tasks:
now = datetime.utcnow()
# update published tasks
Task._get_collection().update_many(
filter={
"_id": {"$in": model_tasks},
"project": {"$nin": projects},
"models.output.model": {"$in": model_ids},
"status": TaskStatus.published,
},
update={
"$set": {
"models.output.$[elem].model": deleted,
"last_change": now,
"last_changed_by": user,
}
},
update={"$set": {"models.output.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
# update unpublished tasks
Task.objects(
id__in=model_tasks,
project__nin=projects,
status__ne=TaskStatus.published,
).update(
pull__models__output__model__in=model_ids,
set__last_change=now,
set__last_changed_by=user,
)
event_urls, model_urls = set(), set()
for m in models:
event_urls.update(collect_debug_image_urls(company, m.id))
event_urls.update(collect_plot_image_urls(company, m.id))
if m.uri:
model_urls.add(m.uri)
event_bll.delete_multi_task_events(
company, model_ids, async_delete=async_events_delete
model_urls = {m.uri for m in models if m.uri}
event_urls = delete_task_events_and_collect_urls(
company=company, task_ids=model_ids, model=True, wait_for_delete=False
)
deleted = models.delete()
return deleted, event_urls, model_urls

View File

@@ -47,7 +47,7 @@ class ProjectQueries:
@staticmethod
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
if allow_public:
return {"company": {"$in": [None, "", company_id]}}
return {"company": {"$in": ["", company_id]}}
return {"company": company_id}
@@ -140,7 +140,12 @@ class ProjectQueries:
name: str,
include_subprojects: bool,
allow_public: bool = True,
pattern: str = None,
page: int = 0,
page_size: int = 500,
) -> ParamValues:
page = max(0, page)
page_size = max(1, page_size)
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
@@ -160,7 +165,20 @@ class ProjectQueries:
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
redis_key = "_".join(
str(part)
for part in (
"hyperparam_values",
company_id,
"_".join(project_ids),
section,
name,
allow_public,
pattern,
page,
page_size,
)
)
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_param_values(
key=redis_key,
@@ -172,19 +190,27 @@ class ProjectQueries:
if cached_res:
return cached_res
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
match_condition = {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
if pattern:
match_condition["$expr"] = {
"$regexMatch": {
"input": f"${key_path}.value",
"regex": pattern,
"options": "i",
}
},
}
pipeline = [
{"$match": match_condition},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
@@ -209,13 +235,19 @@ class ProjectQueries:
@classmethod
def get_unique_metric_variants(
cls, company_id, project_ids: Sequence[str], include_subprojects: bool
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
ids: Sequence[str],
model_metrics: bool = False,
):
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
**({"_id": {"$in": ids}} if ids else {}),
}
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
@@ -246,7 +278,8 @@ class ProjectQueries:
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
result = Task.aggregate(pipeline)
entity_cls = Model if model_metrics else Task
result = entity_cls.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@classmethod
@@ -306,7 +339,11 @@ class ProjectQueries:
key: str,
include_subprojects: bool,
allow_public: bool = True,
page: int = 0,
page_size: int = 500,
) -> ParamValues:
page = max(0, page)
page_size = max(1, page_size)
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
@@ -326,7 +363,7 @@ class ProjectQueries:
if not last_updated_model:
return 0, []
redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}"
redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}_{page}_{page_size}"
last_update = last_updated_model.last_update or datetime.utcnow()
cached_res = self._get_cached_param_values(
key=redis_key, last_update=last_update
@@ -334,7 +371,6 @@ class ProjectQueries:
if cached_res:
return cached_res
max_values = config.get("services.models.metadata_values.max_count", 100)
pipeline = [
{
"$match": {
@@ -346,7 +382,8 @@ class ProjectQueries:
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,

View File

@@ -2,6 +2,8 @@ import itertools
from datetime import datetime
from typing import Tuple, Optional, Sequence, Mapping
from boltons.iterutils import first
from apiserver import database
from apiserver.apierrors import errors
from apiserver.database.model import EntityVisibility
@@ -14,14 +16,16 @@ def _get_project_depth(project_name: str) -> int:
return len(list(filter(None, project_name.split(name_separator))))
def _validate_project_name(project_name: str) -> Tuple[str, str]:
def _validate_project_name(project_name: str, raise_if_empty=True) -> Tuple[str, str]:
"""
Remove redundant '/' characters. Ensure that the project name is not empty
Return the cleaned up project name and location
"""
name_parts = list(filter(None, project_name.split(name_separator)))
name_parts = [p.strip() for p in project_name.split(name_separator) if p]
if not name_parts:
raise errors.bad_request.InvalidProjectName(name=project_name)
if raise_if_empty:
raise errors.bad_request.InvalidProjectName(name=project_name)
return "", ""
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
@@ -34,7 +38,7 @@ def _ensure_project(
If needed auto-create the project and all the missing projects in the path to it
Return the project
"""
name = name.strip(name_separator)
name, location = _validate_project_name(name, raise_if_empty=False)
if not name:
return None
@@ -43,7 +47,6 @@ def _ensure_project(
return project
now = datetime.utcnow()
name, location = _validate_project_name(name)
project = Project(
id=database.utils.id(),
user=user,
@@ -95,10 +98,24 @@ def _get_writable_project_from_name(
"""
Return a project from name. If the project not found then return None
"""
qs = Project.objects(company=company, name=name)
qs = Project.objects(company__in=[company, ""], name=name)
if _only:
if "company" not in _only:
_only = ["company", *_only]
qs = qs.only(*_only)
return qs.first()
projects = list(qs)
if not projects:
return
project = first(p for p in projects if p.company == company)
if not project:
raise errors.bad_request.PublicProjectExists(name=name)
return project
ProjectsChildren = Mapping[str, Sequence[Project]]
def _get_sub_projects(
@@ -106,7 +123,7 @@ def _get_sub_projects(
_only: Sequence[str] = ("id", "path"),
search_hidden=True,
allowed_ids: Sequence[str] = None,
) -> Mapping[str, Sequence[Project]]:
) -> ProjectsChildren:
"""
Return the list of child projects of all the levels for the parent project ids
"""
@@ -140,8 +157,8 @@ def _ids_with_children(project_ids: Sequence[str]) -> Sequence[str]:
"""
Return project ids with the ids of all the subprojects
"""
subprojects = Project.objects(path__in=project_ids).only("id")
return list({*project_ids, *(child.id for child in subprojects)})
children_ids = Project.objects(path__in=project_ids).scalar("id")
return list({*project_ids, *children_ids})
def _update_subproject_names(
@@ -159,14 +176,14 @@ def _update_subproject_names(
now = datetime.utcnow()
for child in children:
child_suffix = name_separator.join(
child.name.split(name_separator)[len(old_name.split(name_separator)) :]
child.name.split(name_separator)[len(old_name.split(name_separator)):]
)
updates = {
"name": name_separator.join((project.name, child_suffix)),
"last_update": now,
}
if update_path:
updates["path"] = project.path + child.path[len(old_path) :]
updates["path"] = project.path + child.path[len(old_path):]
updated += child.update(upsert=False, **updates)
return updated

View File

@@ -9,20 +9,35 @@ RANGE_IGNORE_VALUE = -1
class Builder:
@staticmethod
def dates_range(from_date: Union[int, float], to_date: Union[int, float]) -> dict:
def dates_range(
from_date: Optional[Union[int, float]] = None,
to_date: Optional[Union[int, float]] = None,
) -> dict:
assert (
from_date or to_date
), "range condition requires that at least one of from_date or to_date specified"
conditions = {}
if from_date:
conditions["gte"] = int(from_date)
if to_date:
conditions["lte"] = int(to_date)
return {
"range": {
"timestamp": {
"gte": int(from_date),
"lte": int(to_date),
**conditions,
"format": "epoch_second",
}
}
}
@staticmethod
def terms(field: str, values: Iterable[str]) -> dict:
def terms(field: str, values: Iterable) -> dict:
if isinstance(values, str):
assert not isinstance(values, str), "apparently 'term' should be used here"
return {"terms": {field: list(values)}}
@staticmethod
def term(field: str, value) -> dict:
return {"term": {field: value}}
@staticmethod
def normalize_range(

View File

@@ -1,6 +1,6 @@
from collections import defaultdict
from datetime import datetime
from typing import Callable, Sequence, Optional, Tuple
from typing import Sequence, Optional, Tuple, Union, Iterable
from elasticsearch import Elasticsearch
from mongoengine import Q
@@ -16,6 +16,8 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.queue import Queue, Entry
log = config.logger(__file__)
MOVE_FIRST = "first"
MOVE_LAST = "last"
class QueueBLL(object):
@@ -32,6 +34,7 @@ class QueueBLL(object):
def create(
company_id: str,
name: str,
display_name: str = None,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
metadata: Optional[dict] = None,
@@ -44,6 +47,7 @@ class QueueBLL(object):
company=company_id,
created=now,
name=name,
display_name=display_name,
tags=tags or [],
system_tags=system_tags or [],
metadata=metadata,
@@ -133,44 +137,78 @@ class QueueBLL(object):
self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return Queue.safe_update(company_id, queue_id, update_fields)
def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> None:
def _update_task_status_on_removal_from_queue(
    self,
    company_id: str,
    user_id: str,
    task_ids: Iterable[str],
    queue_id: str,
    reason: str
) -> Sequence[str]:
    """
    Move the tasks that are registered on the queue back to the status
    they had before being enqueued (or to 'created' if it is not known)
    and detach them from the queue

    A failure on a single task is logged and does not stop the processing
    of the remaining tasks

    :param task_ids: ids of the tasks to process. A task that is not found
        for the company with execution.queue equal to queue_id is skipped
    :param reason: the status reason to set on the tasks
    :return: ids of the tasks that were found on the queue
    """
    # imported here and not at the module level
    from apiserver.bll.task import ChangeStatusRequest

    found_ids = []
    for task_id in task_ids:
        try:
            queued_task = Task.get(
                company=company_id,
                id=task_id,
                execution__queue=queue_id,
                _only=[
                    "id",
                    "company",
                    "status",
                    "enqueue_status",
                    "project",
                ],
            )
            if not queued_task:
                continue

            # the task id is reported even if the status change below fails
            found_ids.append(queued_task.id)
            request = ChangeStatusRequest(
                task=queued_task,
                new_status=queued_task.enqueue_status or TaskStatus.created,
                status_reason=reason,
                status_message="",
                user_id=user_id,
                force=True,
            )
            # together with the status change clear the stored enqueue
            # status and the reference to the queue
            request.execute(
                enqueue_status=None,
                unset__execution__queue=1,
            )
        except Exception as ex:
            log.error(
                f"Failed updating task {task_id} status on removal from queue: {queue_id}, {str(ex)}"
            )

    return found_ids
def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> Sequence[str]:
"""
Delete the queue
:raise errors.bad_request.InvalidQueueId: if the queue is not found
:raise errors.bad_request.QueueNotEmpty: if the queue is not empty and 'force' not set
"""
with translate_errors_context():
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
if queue.entries:
if not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
from apiserver.bll.task import ChangeStatusRequest
for item in queue.entries:
try:
task = Task.get_for_writing(
company=company_id,
id=item.task,
_only=["id", "status", "enqueue_status", "project"],
)
if not task:
continue
ChangeStatusRequest(
task=task,
new_status=task.enqueue_status or TaskStatus.created,
status_reason="Queue was deleted",
status_message="",
user_id=user_id,
).execute(enqueue_status=None)
except Exception as ex:
log.exception(
f"Failed dequeuing task {item.task} from queue: {queue_id}"
)
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
if not queue.entries:
queue.delete()
return []
if not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
tasks = self._update_task_status_on_removal_from_queue(
company_id=company_id,
user_id=user_id,
task_ids={item.task for item in queue.entries},
queue_id=queue_id,
reason=f"Queue {queue_id} was deleted",
)
queue.delete()
return tasks
def get_all(
self,
@@ -234,6 +272,7 @@ class QueueBLL(object):
{
"name": w.id,
"ip": w.ip,
"key": w.key,
"task": w.task.to_struct() if w.task else None,
}
for w in queue_workers.get(item["id"], [])
@@ -297,7 +336,36 @@ class QueueBLL(object):
return queue.entries[0]
def remove_task(self, company_id: str, queue_id: str, task_id: str) -> int:
def clear_queue(
    self,
    company_id: str,
    user_id: str,
    queue_id: str,
):
    """
    Remove all the entries from the queue and update the status
    of the tasks that were waiting in it
    Returns the ids of the tasks found on the queue. For the queue
    without entries returns an empty list
    :raise errors.bad_request.InvalidQueueId: if the queue is not found
    """
    queue = Queue.objects(company=company_id, id=queue_id).first()
    if not queue:
        raise errors.bad_request.InvalidQueueId(
            queue=queue_id
        )

    entries = queue.entries
    if not entries:
        return []

    # the tasks are updated before the queue entries are dropped
    removed_tasks = self._update_task_status_on_removal_from_queue(
        company_id=company_id,
        user_id=user_id,
        task_ids={entry.task for entry in entries},
        queue_id=queue_id,
        reason=f"Queue {queue_id} was cleared",
    )

    queue.update(entries=[])
    # reload so that the metrics are reported for the emptied queue
    queue.reload()
    self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])

    return removed_tasks
def remove_task(self, company_id: str, user_id: str, queue_id: str, task_id: str, update_task_status: bool = False) -> int:
"""
Removes the task from the queue and returns the number of removed items
:raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not found in the queue
@@ -312,6 +380,14 @@ class QueueBLL(object):
res = Queue.objects(entries__task=task_id, **query).update_one(
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
)
if res and update_task_status:
self._update_task_status_on_removal_from_queue(
company_id=company_id,
user_id=user_id,
task_ids=[task_id],
queue_id=queue_id,
reason=f"Task was removed from the queue {queue_id}",
)
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
@@ -319,46 +395,131 @@ class QueueBLL(object):
return len(entries_to_remove) if res else 0
def reposition_task(
self,
company_id: str,
queue_id: str,
task_id: str,
pos_func: Callable[[int], int],
self, company_id: str, queue_id: str, task_id: str, move_count: Union[int, str],
) -> int:
"""
Moves the task in the queue to the position calculated by pos_func
Returns the updated task position in the queue
"""
with translate_errors_context():
queue = self.get_queue_with_task(
def get_queue_and_task_position():
q = self.get_queue_with_task(
company_id=company_id, queue_id=queue_id, task_id=task_id
)
return q, next(i for i, e in enumerate(q.entries) if e.task == task_id)
position = next(i for i, e in enumerate(queue.entries) if e.task == task_id)
new_position = pos_func(position)
with translate_errors_context():
queue, position = get_queue_and_task_position()
if move_count == MOVE_FIRST:
new_position = 0
elif move_count == MOVE_LAST:
new_position = len(queue.entries) - 1
else:
new_position = position + move_count
if new_position == position:
return new_position
if new_position != position:
entry = queue.entries[position]
query = dict(id=queue_id, company=company_id)
updated = Queue.objects(entries__task=task_id, **query).update_one(
pull__entries=entry, last_update=datetime.utcnow()
)
if not updated:
raise errors.bad_request.RemovedDuringReposition(
task=task_id, **query
)
inst = {"$push": {"entries": {"$each": [entry.to_proper_dict()]}}}
if new_position >= 0:
inst["$push"]["entries"]["$position"] = new_position
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
__raw__=inst
)
if not res:
raise errors.bad_request.FailedAddingDuringReposition(
task=task_id, **query
)
without_entry = {
"$filter": {
"input": "$entries",
"as": "entry",
"cond": {"$ne": ["$$entry.task", task_id]},
}
}
task_entry = {
"$filter": {
"input": "$entries",
"as": "entry",
"cond": {"$eq": ["$$entry.task", task_id]},
}
}
if move_count == MOVE_FIRST:
operations = [
{
"$set": {
"entries": {"$concatArrays": [task_entry, without_entry]}
}
}
]
elif move_count == MOVE_LAST:
operations = [
{
"$set": {
"entries": {"$concatArrays": [without_entry, task_entry]}
}
}
]
else:
operations = [
{
"$set": {
"new_pos": {
"$add": [
{"$indexOfArray": ["$entries.task", task_id]},
move_count,
]
},
"without_entry": without_entry,
"task_entry": task_entry,
}
},
{
"$set": {
"entries": {
"$switch": {
"branches": [
{
"case": {"$lte": ["$new_pos", 0]},
"then": {
"$concatArrays": [
"$task_entry",
"$without_entry",
]
},
},
{
"case": {
"$gte": [
"$new_pos",
{"$size": "$without_entry"},
]
},
"then": {
"$concatArrays": [
"$without_entry",
"$task_entry",
]
},
},
],
"default": {
"$concatArrays": [
{"$slice": ["$without_entry", "$new_pos"]},
"$task_entry",
{
"$slice": [
"$without_entry",
"$new_pos",
{"$size": "$without_entry"},
]
},
]
},
}
}
}
},
{"$unset": ["new_pos", "without_entry", "task_entry"]},
]
return new_position
updated = Queue.objects(
id=queue_id, company=company_id, entries__task=task_id
).update_one(__raw__=operations)
if not updated:
raise errors.bad_request.FailedAddingDuringReposition(task=task_id)
return get_queue_and_task_position()[1]
def count_entries(self, company: str, queue_id: str) -> Optional[int]:
res = next(
@@ -366,7 +527,7 @@ class QueueBLL(object):
[
{
"$match": {
"company": {"$in": [None, "", company]},
"company": {"$in": ["", company]},
"_id": queue_id,
}
},

View File

@@ -80,7 +80,7 @@ class QueueMetrics:
logged = 0
for q in queues:
queue_doc = make_doc(q)
self.es.index(index=es_index, body=queue_doc)
self.es.index(index=es_index, document=queue_doc)
redis_key = _queue_metrics_key_pattern.format(queue=q.id)
redis.set(redis_key, json.dumps(queue_doc))
logged += 1

View File

@@ -0,0 +1,376 @@
from datetime import datetime, timedelta, timezone
from enum import Enum, auto
from operator import attrgetter
from time import time
from typing import Optional, Sequence, Union
import attr
from boltons.iterutils import chunked_iter, bucketize
from pyhocon import ConfigTree
from apiserver.apimodels.serving import (
ServingContainerEntry,
RegisterRequest,
StatusReportRequest,
)
from apiserver.apimodels.workers import MachineStats
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.redis_manager import redman
from .stats import ServingStats
log = config.logger(__file__)
class ServingBLL:
def __init__(self, redis=None):
    """
    :param redis: the redis connection to work with. When it is not passed
        the "workers" connection of the redis manager is used
    """
    if not redis:
        redis = redman.connection("workers")
    self.redis = redis
    self.conf = config.get("services.serving", ConfigTree())
@staticmethod
def _get_url_key(company: str, url: str):
    """Build redis key of the containers set from company and endpoint url"""
    url_key = f"serving_url_{company}_{url}"
    return url_key
@staticmethod
def _get_container_key(company: str, container_id: str) -> str:
    """Build redis key of the container entry from company and container_id"""
    container_key = f"serving_container_{company}_{container_id}"
    return container_key
def _save_serving_container_entry(self, entry: ServingContainerEntry):
    """
    Store the container entry in redis under its own key and add that key
    to the sorted set of the entry endpoint url.
    The entry expires after its register timeout
    """
    timeout_sec = entry.register_timeout
    self.redis.setex(entry.key, timedelta(seconds=timeout_sec), entry.to_json())

    # the score of the set member is the unix time on which the entry expires
    url_key = self._get_url_key(entry.company_id, entry.endpoint_url)
    expires_at = int(time()) + timeout_sec
    self.redis.zadd(url_key, {entry.key: expires_at})

    # make sure that url set will not get stuck in redis
    # indefinitely in case no more containers report to it
    self.redis.expire(url_key, max(3600, timeout_sec))
def _get_serving_container_entry(
    self, company_id: str, container_id: str
) -> Optional[ServingContainerEntry]:
    """
    Load the entry of the provided container ID from redis.
    Returns None when there is no data under the container key or when
    the stored data cannot be parsed (the failure is logged)
    """
    raw = self.redis.get(self._get_container_key(company_id, container_id))
    if not raw:
        return None

    try:
        return ServingContainerEntry.from_json(raw)
    except Exception as e:
        msg = "Failed parsing container entry"
        log.exception(f"{msg}: {str(e)}")
        return None
def register_serving_container(
    self,
    company_id: str,
    request: RegisterRequest,
    ip: str = "",
) -> ServingContainerEntry:
    """
    Build the container entry from the register request, save it and return it.
    Both the register time and the last activity time are set
    to the current UTC time
    """
    registered_at = datetime.now(timezone.utc)
    entry_fields = dict(
        key=self._get_container_key(company_id, request.container_id),
        company_id=company_id,
        ip=ip,
        register_time=registered_at,
        register_timeout=request.timeout,
        last_activity_time=registered_at,
    )
    entry = ServingContainerEntry(**request.to_struct(), **entry_fields)
    self._save_serving_container_entry(entry)
    return entry
def unregister_serving_container(
    self,
    company_id: str,
    container_id: str,
) -> None:
    """
    Delete the container entry from redis together with its reference
    in the set of its endpoint url
    :raise errors.bad_request.ContainerNotRegistered: if nothing was deleted
        and the "container_auto_unregister" setting is turned off
    """
    entry = self._get_serving_container_entry(company_id, container_id)
    if entry:
        url_key = self._get_url_key(entry.company_id, entry.endpoint_url)
        self.redis.zrem(url_key, entry.key)

    deleted = self.redis.delete(self._get_container_key(company_id, container_id))
    # the setting is consulted only when there was nothing to delete
    if deleted or self.conf.get("container_auto_unregister", True):
        return

    raise errors.bad_request.ContainerNotRegistered(container=container_id)
def container_status_report(
    self,
    company_id: str,
    report: StatusReportRequest,
    ip: str = "",
) -> None:
    """
    Save the status report of the serving container and log its stats to ES.
    For an already registered container the register time and timeout are kept
    and its stored ip is used when no ip is passed.
    An unknown container gets the current time as its register time
    and the "default_container_timeout_sec" setting as its timeout
    :raise errors.bad_request.ContainerNotRegistered: if the container is unknown
        and the "container_auto_register" setting is turned off
    """
    container_id = report.container_id
    now = datetime.now(timezone.utc)
    existing = self._get_serving_container_entry(company_id, container_id)
    if not existing:
        if not self.conf.get("container_auto_register", True):
            raise errors.bad_request.ContainerNotRegistered(container=container_id)
        register_time = now
        register_timeout = int(
            self.conf.get("default_container_timeout_sec", 10 * 60)
        )
    else:
        ip = ip or existing.ip
        register_time = existing.register_time
        register_timeout = existing.register_timeout

    key = self._get_container_key(company_id, container_id)
    updated = ServingContainerEntry(
        **report.to_struct(),
        key=key,
        company_id=company_id,
        ip=ip,
        register_time=register_time,
        register_timeout=register_timeout,
        last_activity_time=now,
    )
    self._save_serving_container_entry(updated)
    ServingStats.log_stats_to_es(updated)
def _get_all(
    self,
    company_id: str,
) -> Sequence[ServingContainerEntry]:
    """
    Return the entries of all the containers stored in redis for the company.
    The keys are scanned by the company container key pattern and their
    values are fetched in chunks of 1000.
    Values that fail to parse are logged and skipped
    """
    all_keys = list(self.redis.scan_iter(self._get_container_key(company_id, "*")))
    entries = []
    for keys_chunk in chunked_iter(all_keys, 1000):
        data = self.redis.mget(keys_chunk)
        if not data:
            continue
        for d in data:
            if d is None:
                # no value for the key anymore (e.g. it expired after the scan).
                # This is not a parsing failure so do not report it as such
                continue
            try:
                entries.append(ServingContainerEntry.from_json(d))
            except Exception as ex:
                log.error(f"Failed parsing container entry {str(ex)}")

    return entries
@attr.s(auto_attribs=True)
class Counter:
    """
    Aggregates the values of a single entry field over container entries.
    Entries are fed with add(), calling the counter returns the aggregated value
    """

    class AggType(Enum):
        avg = auto()
        max = auto()
        total = auto()
        count = auto()

    # the name under which the result is reported
    name: str
    # the name of the entry attribute to aggregate
    field: str
    agg_type: AggType
    # rounding precision of the average. When not set the average is rounded to int
    float_precision: int = None
    _max: Union[int, float, datetime] = attr.field(init=False, default=None)
    _total: Union[int, float] = attr.field(init=False, default=0)
    _count: int = attr.field(init=False, default=0)

    def add(self, entry: ServingContainerEntry):
        """Account for the entry. Entries with no value for the field are ignored"""
        value = getattr(entry, self.field, None)
        if value is None:
            return

        self._count += 1
        if self.agg_type != self.AggType.max:
            self._total += value
            return

        if self._max is None:
            self._max = value
        else:
            self._max = max(self._max, value)

    def __call__(self):
        """Return the aggregated value. The average of no values is None"""
        agg_type = self.agg_type
        if agg_type == self.AggType.count:
            return self._count
        if agg_type == self.AggType.max:
            return self._max
        if agg_type == self.AggType.total:
            return self._total

        if not self._count:
            return None

        avg = self._total / self._count
        if self.float_precision:
            return round(avg, self.float_precision)
        return round(avg)
def _get_summary(self, entries: Sequence[ServingContainerEntry]) -> dict:
    """
    Summarize the container entries of a single endpoint.
    The endpoint name, model and url are taken from the first entry.
    The rest of the fields are aggregated over all the entries:
    uptime_sec and last_update are maximums, requests is a total,
    requests_min and latency_ms are averages
    """
    agg = self.Counter.AggType
    counters = [
        self.Counter(name="uptime_sec", field="uptime_sec", agg_type=agg.max),
        self.Counter(name="requests", field="requests_num", agg_type=agg.total),
        self.Counter(
            name="requests_min",
            field="requests_min",
            agg_type=agg.avg,
            float_precision=2,
        ),
        self.Counter(name="latency_ms", field="latency_ms", agg_type=agg.avg),
        self.Counter(
            name="last_update", field="last_activity_time", agg_type=agg.max
        ),
    ]
    for entry in entries:
        for counter in counters:
            counter.add(entry)

    first_entry = entries[0]
    # the counters results are reported under the counters names
    # ("last_update" included) so no further assignment is needed
    return {
        "endpoint": first_entry.endpoint_name,
        "model": first_entry.model_name,
        "url": first_entry.endpoint_url,
        "instances": len(entries),
        **{counter.name: counter() for counter in counters},
    }
def get_endpoints(self, company_id: str):
    """
    Group instances by urls and return a summary for each url
    Do not return data for "loading" instances that have no url
    """
    by_url = bucketize(self._get_all(company_id), key=attrgetter("endpoint_url"))
    # the instances that have no url are grouped under the None key
    by_url.pop(None, None)

    summaries = []
    for url_entries in by_url.values():
        summaries.append(self._get_summary(url_entries))
    return summaries
def _get_endpoint_entries(
    self, company_id, endpoint_url: Union[str, None]
) -> Sequence[ServingContainerEntry]:
    """
    Return the entries of the containers that are referenced from the set
    of the endpoint url and still report this url.
    The set is cleaned on the way: the members whose score is not in the future
    are removed and so are the keys for which no matching entry was loaded
    """
    url_key = self._get_url_key(company_id, endpoint_url)
    timestamp = int(time())
    # the score of a member is compared with the current unix time
    self.redis.zremrangebyscore(url_key, min=0, max=timestamp)
    container_keys = {key.decode() for key in self.redis.zrange(url_key, 0, -1)}
    if not container_keys:
        return []

    entries = []
    found_keys = set()
    data = self.redis.mget(container_keys) or []
    for d in data:
        if d is None:
            # there is no entry under the key anymore. This is not a parsing
            # failure: the key is removed from the url set below
            continue
        try:
            entry = ServingContainerEntry.from_json(d)
            if entry.endpoint_url == endpoint_url:
                entries.append(entry)
                found_keys.add(entry.key)
        except Exception as ex:
            log.error(f"Failed parsing container entry {str(ex)}")

    missing_keys = container_keys - found_keys
    if missing_keys:
        self.redis.zrem(url_key, *missing_keys)

    return entries
def get_loading_instances(self, company_id: str):
    """
    Return the details of the containers that are listed
    under the None endpoint url for the company
    """

    def describe(entry: ServingContainerEntry) -> dict:
        # the age is counted from the container register time
        age = datetime.now(timezone.utc) - entry.register_time
        return {
            "id": entry.container_id,
            "endpoint": entry.endpoint_name,
            "url": entry.endpoint_url,
            "model": entry.model_name,
            "model_source": entry.model_source,
            "model_version": entry.model_version,
            "preprocess_artifact": entry.preprocess_artifact,
            "input_type": entry.input_type,
            "input_size": entry.input_size,
            "uptime_sec": entry.uptime_sec,
            "age_sec": int(age.total_seconds()),
            "last_update": entry.last_activity_time,
        }

    loading = self._get_endpoint_entries(company_id, None)
    return [describe(entry) for entry in loading]
def get_endpoint_details(self, company_id, endpoint_url: str) -> dict:
    """
    Return the details of the endpoint served under the url together
    with the statistics of each of its container instances.
    The endpoint level fields are taken from the first container entry,
    uptime_sec and last_update are the maximums over all the entries
    :raise errors.bad_request.NoContainersForUrl: if no container entries
        were found for the url
    """
    entries = self._get_endpoint_entries(company_id, endpoint_url)
    if not entries:
        raise errors.bad_request.NoContainersForUrl(url=endpoint_url)

    def get_machine_stats_data(machine_stats: MachineStats) -> dict:
        # a list of usage values is counted by its length, a single value counts as 1
        ret = {"cpu_count": 0, "gpu_count": 0}
        if not machine_stats:
            return ret

        for value, field in (
            (machine_stats.cpu_usage, "cpu_count"),
            (machine_stats.gpu_usage, "gpu_count"),
        ):
            if value is None:
                continue
            ret[field] = len(value) if isinstance(value, (list, tuple)) else 1

        return ret

    # entries that carry no uptime value are left out so that
    # the maximum does not fail on comparing None with a number
    uptimes = [e.uptime_sec for e in entries if e.uptime_sec is not None]

    first_entry = entries[0]
    return {
        "endpoint": first_entry.endpoint_name,
        "model": first_entry.model_name,
        "url": first_entry.endpoint_url,
        "preprocess_artifact": first_entry.preprocess_artifact,
        "input_type": first_entry.input_type,
        "input_size": first_entry.input_size,
        "model_source": first_entry.model_source,
        "model_version": first_entry.model_version,
        "uptime_sec": max(uptimes, default=None),
        "last_update": max(e.last_activity_time for e in entries),
        "instances": [
            {
                "id": entry.container_id,
                "uptime_sec": entry.uptime_sec,
                "requests": entry.requests_num,
                "requests_min": entry.requests_min,
                "latency_ms": entry.latency_ms,
                "last_update": entry.last_activity_time,
                "reference": [ref.to_struct() for ref in entry.reference]
                if isinstance(entry.reference, list)
                else entry.reference,
                **get_machine_stats_data(entry.machine_stats),
            }
            for entry in entries
        ],
    }

View File

@@ -0,0 +1,340 @@
from collections import defaultdict
from datetime import datetime, timezone
from enum import Enum
from typing import Tuple, Optional, Sequence
from elasticsearch import Elasticsearch
from apiserver.apimodels.serving import (
ServingContainerEntry,
GetEndpointMetricsHistoryRequest,
MetricType,
)
from apiserver.apierrors import errors
from apiserver.utilities.dicts import nested_get
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.es_factory import es_factory
class _AggregationType(Enum):
    """The way in which the instances values are combined into the total series"""

    avg = "avg"
    sum = "sum"
class ServingStats:
min_chart_interval = config.get("services.serving.min_chart_interval_sec", 40)
es: Elasticsearch = es_factory.connect("workers")
@classmethod
def _serving_stats_prefix(cls, company_id: str) -> str:
    """Build the es index prefix of the company. The company id is lower-cased"""
    prefix = f"serving_stats_{company_id.lower()}_"
    return prefix
@staticmethod
def _get_es_index_suffix():
    """Build the index name suffix (year-month in UTC) for storing current month data"""
    now = datetime.now(timezone.utc)
    return now.strftime("%Y-%m")
@staticmethod
def _get_average_value(value) -> Tuple[Optional[float], Optional[int]]:
    """
    Return the average and the amount of the passed values.
    A single value is returned as is with the amount of 1.
    For None or for an empty list/tuple both results are None
    """
    if value is None:
        return None, None

    if not isinstance(value, (list, tuple)):
        return value, 1

    count = len(value)
    if count:
        return sum(value) / count, count

    return None, None
@classmethod
def log_stats_to_es(
    cls,
    entry: ServingContainerEntry,
) -> int:
    """
    Write the statistics of the serving container entry to Elastic.
    The document goes to the current month index of the entry company.
    List values of the machine stats are stored as their averages
    :return: The amount of logged documents
    """
    company_id = entry.company_id
    es_index = (
        f"{cls._serving_stats_prefix(company_id)}" f"{cls._get_es_index_suffix()}"
    )

    entry_data = entry.to_struct()
    doc = {"timestamp": es_factory.get_timestamp_millis()}
    for field in (
        "container_id",
        "company_id",
        "endpoint_url",
        "requests_num",
        "requests_min",
        "uptime_sec",
        "latency_ms",
    ):
        doc[field] = entry_data.get(field)

    stats = entry_data.get("machine_stats")
    if stats:
        for category in ("cpu", "gpu"):
            usage, num = cls._get_average_value(stats.get(f"{category}_usage"))
            doc.update({f"{category}_usage": usage, f"{category}_num": num})

        for category in ("memory", "gpu_memory"):
            free, _ = cls._get_average_value(stats.get(f"{category}_free"))
            used, _ = cls._get_average_value(stats.get(f"{category}_used"))
            # a missing part counts as 0 in the total
            total = round((free or 0) + (used or 0), 3)
            doc.update(
                {
                    f"{category}_free": free,
                    f"{category}_used": used,
                    f"{category}_total": total,
                }
            )

        for field in ("disk_free_home", "network_rx", "network_tx"):
            doc[field] = stats.get(field)

    cls.es.index(index=es_index, document=doc)

    return 1
@staticmethod
def round_series(values: Sequence, koeff) -> list:
    """
    Multiply each of the values by koeff and round the result to 2 digits.
    Falsy values (None or 0) are returned as 0
    """
    rounded = []
    for v in values:
        rounded.append(round(v * koeff, 2) if v else 0)
    return rounded
_mb_to_gb = 1 / 1024
# metric type -> (es field, chart title, aggregation type, values multiplier or None)
agg_fields = {
    MetricType.requests: ("requests_num", "Number of Requests", _AggregationType.sum, None),
    MetricType.requests_min: ("requests_min", "Requests per Minute", _AggregationType.sum, None),
    MetricType.latency_ms: ("latency_ms", "Average Latency (ms)", _AggregationType.avg, None),
    MetricType.cpu_count: ("cpu_num", "CPU Count", _AggregationType.sum, None),
    MetricType.gpu_count: ("gpu_num", "GPU Count", _AggregationType.sum, None),
    MetricType.cpu_util: ("cpu_usage", "Average CPU Load (%)", _AggregationType.avg, None),
    MetricType.gpu_util: ("gpu_usage", "Average GPU Utilization (%)", _AggregationType.avg, None),
    MetricType.ram_total: ("memory_total", "RAM Total (GB)", _AggregationType.sum, _mb_to_gb),
    MetricType.ram_used: ("memory_used", "RAM Used (GB)", _AggregationType.sum, _mb_to_gb),
    MetricType.ram_free: ("memory_free", "RAM Free (GB)", _AggregationType.sum, _mb_to_gb),
    MetricType.gpu_ram_total: ("gpu_memory_total", "GPU RAM Total (GB)", _AggregationType.sum, _mb_to_gb),
    MetricType.gpu_ram_used: ("gpu_memory_used", "GPU RAM Used (GB)", _AggregationType.sum, _mb_to_gb),
    MetricType.gpu_ram_free: ("gpu_memory_free", "GPU RAM Free (GB)", _AggregationType.sum, _mb_to_gb),
    MetricType.network_rx: ("network_rx", "Network Throughput RX (MBps)", _AggregationType.sum, None),
    MetricType.network_tx: ("network_tx", "Network Throughput TX (MBps)", _AggregationType.sum, None),
}
@classmethod
def get_endpoint_metrics(
    cls,
    company_id: str,
    metrics_request: GetEndpointMetricsHistoryRequest,
) -> dict:
    """
    Return the history of the requested metric for the endpoint url as a
    series of values per date histogram interval: the total over the endpoint
    instances and, if instance_charts is requested, a series per instance.
    The interval is not allowed to be less than min_chart_interval
    :raise errors.bad_request.FieldsValueError: if from_date is not less than to_date
    :raise NotImplementedError: if there is no aggregation defined for the metric type
    """
    from_date = metrics_request.from_date
    to_date = metrics_request.to_date
    if from_date >= to_date:
        raise errors.bad_request.FieldsValueError(
            "from_date must be less than to_date"
        )

    metric_type = metrics_request.metric_type
    agg_data = cls.agg_fields.get(metric_type)
    if not agg_data:
        # NotImplemented is a comparison sentinel and cannot be raised or called
        raise NotImplementedError(f"Charts for {metric_type} not implemented")

    agg_field, title, agg_type, multiplier = agg_data
    if agg_type == _AggregationType.sum:
        instance_sum_type = "sum_bucket"
    else:
        instance_sum_type = "avg_bucket"

    interval = max(metrics_request.interval, cls.min_chart_interval)
    endpoint_url = metrics_request.endpoint_url
    hist_ret = {
        "computed_interval": interval,
        "total": {
            "title": title,
            "dates": [],
            "values": [],
        },
        "instances": {},
    }
    must_conditions = [
        QueryBuilder.term("company_id", company_id),
        QueryBuilder.term("endpoint_url", endpoint_url),
        QueryBuilder.dates_range(from_date, to_date),
    ]
    query = {"bool": {"must": must_conditions}}
    es_index = f"{cls._serving_stats_prefix(company_id)}*"
    # first find the instances that reported for the url in the date range
    res = cls.es.search(
        index=es_index,
        size=0,
        query=query,
        aggs={"instances": {"terms": {"field": "container_id"}}},
    )
    instance_buckets = nested_get(res, ("aggregations", "instances", "buckets"))
    if not instance_buckets:
        return hist_ret

    instance_keys = {ib["key"] for ib in instance_buckets}
    must_conditions.append(QueryBuilder.terms("container_id", instance_keys))
    query = {"bool": {"must": must_conditions}}
    sample_func = "avg" if metric_type != MetricType.requests else "max"
    aggs = {
        "instances": {
            "terms": {
                "field": "container_id",
                "size": max(len(instance_keys), 10),
            },
            "aggs": {
                "sample": {sample_func: {"field": agg_field}},
            },
        },
        "total_instances": {
            instance_sum_type: {
                "gap_policy": "insert_zeros",
                "buckets_path": "instances>sample",
            }
        },
    }
    hist_params = {}
    if metric_type == MetricType.requests:
        hist_params["min_doc_count"] = 1
    else:
        hist_params["extended_bounds"] = {
            "min": int(from_date) * 1000,
            "max": int(to_date) * 1000,
        }
    aggs = {
        "dates": {
            "date_histogram": {
                "field": "timestamp",
                "fixed_interval": f"{interval}s",
                **hist_params,
            },
            "aggs": aggs,
        }
    }

    filter_path = None
    if not metrics_request.instance_charts:
        # only the totals are needed in the response
        filter_path = "aggregations.dates.buckets.total_instances"

    data = cls.es.search(
        index=es_index,
        size=0,
        query=query,
        aggs=aggs,
        filter_path=filter_path,
    )
    agg_res = data.get("aggregations")
    if not agg_res:
        return hist_ret

    dates_ = []
    total = []
    instances = defaultdict(list)
    # remove last interval if it's incomplete. Allow 10% tolerance
    last_valid_timestamp = (to_date - 0.9 * interval) * 1000
    for point in agg_res["dates"]["buckets"]:
        date_ = point["key"]
        if date_ > last_valid_timestamp:
            break
        dates_.append(date_)
        total.append(nested_get(point, ("total_instances", "value"), 0))
        if metrics_request.instance_charts:
            found_keys = set()
            for instance in nested_get(point, ("instances", "buckets"), []):
                instances[instance["key"]].append(
                    nested_get(instance, ("sample", "value"), 0)
                )
                found_keys.add(instance["key"])
            # keep the instances series aligned with the dates
            for missing_key in instance_keys - found_keys:
                instances[missing_key].append(0)

    koeff = multiplier if multiplier else 1.0
    hist_ret["total"]["dates"] = dates_
    hist_ret["total"]["values"] = cls.round_series(total, koeff)
    hist_ret["instances"] = {
        key: {
            "title": key,
            "dates": dates_,
            "values": cls.round_series(values, koeff),
        }
        for key, values in sorted(instances.items(), key=lambda p: p[0])
    }

    return hist_ret

View File

@@ -8,8 +8,7 @@ from typing import Sequence, Optional
import dpath
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter, Retry
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.bll.util import get_server_uuid
@@ -19,7 +18,7 @@ from apiserver.config.info import get_deployment_type
from apiserver.database.model import Company, User
from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.json import dumps
from apiserver.version import __version__ as current_version
from .resource_monitor import ResourceMonitor, stat_threads
@@ -163,7 +162,7 @@ class StatisticsReporter:
def _get_cardinality_fields(categories: Sequence[dict]) -> dict:
names = {"cpu": "num_cores"}
return {
names[c["key"]]: safe_get(c, "count/value")
names[c["key"]]: nested_get(c, ("count", "value"))
for c in categories
if c["key"] in names
}
@@ -176,21 +175,21 @@ class StatisticsReporter:
}
return {
names[m["key"]]: {
"min": safe_get(m, "min/value"),
"max": safe_get(m, "max/value"),
"avg": safe_get(m, "avg/value"),
"min": nested_get(m, ("min", "value")),
"max": nested_get(m, ("max", "value")),
"avg": nested_get(m, ("avg", "value")),
}
for m in metrics
if m["key"] in names
}
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
buckets = nested_get(res, ("aggregations", "workers", "buckets"), default=[])
return {
b["key"]: {
key: {
"interval_sec": agent_resource_threshold_sec,
**_get_cardinality_fields(safe_get(b, "categories/buckets", [])),
**_get_metric_fields(safe_get(b, "metrics/buckets", [])),
**_get_cardinality_fields(nested_get(b, ("categories", "buckets"), [])),
**_get_metric_fields(nested_get(b, ("metrics", "buckets"), [])),
}
}
for b in buckets
@@ -228,7 +227,7 @@ class StatisticsReporter:
},
}
res = cls._run_worker_stats_query(company_id, es_req)
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
buckets = nested_get(res, ("aggregations", "workers", "buckets"), default=[])
return {
b["key"]: {"last_activity_time": b["last_activity_time"]["value"]}
for b in buckets
@@ -255,6 +254,14 @@ class StatisticsReporter:
**({"last_worker": {"$in": workers}} if workers else {}),
}
},
{
"$project": {
"last_worker": 1,
"last_update": 1,
"started": 1,
"last_iteration": 1,
}
},
{
"$group": {
"_id": "$last_worker" if workers else None,

View File

@@ -1,14 +1,32 @@
import json
import os
import tempfile
from copy import copy
from datetime import datetime
from typing import Optional, Sequence
import attr
from boltons.cacheutils import cachedproperty
from clearml.backend_config.bucket_config import (
S3BucketConfigurations,
AzureContainerConfigurations,
GSBucketConfigurations,
AzureContainerConfig,
GSBucketConfig,
S3BucketConfig,
)
from apiserver.apierrors import errors
from apiserver.apimodels.storage import SetSettingsRequest
from apiserver.config_repo import config
from apiserver.database.model.storage_settings import (
StorageSettings,
GoogleBucketSettings,
AWSSettings,
AzureStorageSettings,
GoogleStorageSettings,
)
from apiserver.database.utils import id as db_id
log = config.logger(__file__)
@@ -32,17 +50,224 @@ class StorageBLL:
def get_azure_settings_for_company(
self,
company_id: str,
db_settings: StorageSettings = None,
query_db: bool = True,
) -> AzureContainerConfigurations:
return copy(self._default_azure_configs)
if not db_settings and query_db:
db_settings = (
StorageSettings.objects(company=company_id).only("azure").first()
)
if not db_settings or not db_settings.azure:
return copy(self._default_azure_configs)
azure = db_settings.azure
return AzureContainerConfigurations(
container_configs=[
AzureContainerConfig(**entry.to_proper_dict())
for entry in (azure.containers or [])
]
)
def get_gs_settings_for_company(
self,
company_id: str,
db_settings: StorageSettings = None,
query_db: bool = True,
json_string: bool = False,
) -> GSBucketConfigurations:
return copy(self._default_gs_configs)
if not db_settings and query_db:
db_settings = (
StorageSettings.objects(company=company_id).only("google").first()
)
if not db_settings or not db_settings.google:
if not json_string:
return copy(self._default_gs_configs)
if self._default_gs_configs._buckets:
buckets = [
attr.evolve(
b,
credentials_json=self._assure_json_string(b.credentials_json),
)
for b in self._default_gs_configs._buckets
]
else:
buckets = self._default_gs_configs._buckets
return GSBucketConfigurations(
buckets=buckets,
default_project=self._default_gs_configs._default_project,
default_credentials=self._assure_json_string(
self._default_gs_configs._default_credentials
),
)
def get_bucket_config(bc: GoogleBucketSettings) -> GSBucketConfig:
data = bc.to_proper_dict()
if not json_string and bc.credentials_json:
data["credentials_json"] = self._assure_json_file(bc.credentials_json)
return GSBucketConfig(**data)
google = db_settings.google
buckets_configs = [get_bucket_config(b) for b in (google.buckets or [])]
return GSBucketConfigurations(
buckets=buckets_configs,
default_project=google.project,
default_credentials=google.credentials_json
if json_string
else self._assure_json_file(google.credentials_json),
)
def get_aws_settings_for_company(
self,
company_id: str,
db_settings: StorageSettings = None,
query_db: bool = True,
) -> S3BucketConfigurations:
return copy(self._default_aws_configs)
if not db_settings and query_db:
db_settings = (
StorageSettings.objects(company=company_id).only("aws").first()
)
if not db_settings or not db_settings.aws:
return copy(self._default_aws_configs)
aws = db_settings.aws
buckets_configs = S3BucketConfig.from_list(
[b.to_proper_dict() for b in (aws.buckets or [])]
)
return S3BucketConfigurations(
buckets=buckets_configs,
default_key=aws.key,
default_secret=aws.secret,
default_region=aws.region,
default_use_credentials_chain=aws.use_credentials_chain,
default_token=aws.token,
default_extra_args={},
)
def _assure_json_file(self, name_or_content: str) -> str:
    """
    Return a file name for the passed json credentials.
    The value is returned as is when it is empty, ends with ".json",
    is an existing path or is not a valid json.
    Otherwise the json content is written to a new temporary ".json" file
    (that is not deleted on close) and the name of that file is returned
    """
    if not name_or_content:
        return name_or_content

    is_file_name = name_or_content.endswith(".json") or os.path.exists(
        name_or_content
    )
    if is_file_name:
        return name_or_content

    try:
        json.loads(name_or_content)
    except Exception:
        return name_or_content

    tmp = tempfile.NamedTemporaryFile(mode="wt", delete=False, suffix=".json")
    with tmp:
        tmp.write(name_or_content)
    return tmp.name
def _assure_json_string(self, name_or_content: str) -> Optional[str]:
    """
    Return the json credentials as a string.
    An empty value or a valid json string is returned as is.
    Any other value is treated as a file name and the content of the file
    is returned. None is returned if the file cannot be read
    """
    if not name_or_content:
        return name_or_content

    try:
        json.loads(name_or_content)
    except Exception:
        pass
    else:
        return name_or_content

    try:
        with open(name_or_content) as fp:
            content = fp.read()
    except Exception:
        return None
    return content
def get_company_settings(self, company_id: str) -> dict:
    """
    Return the storage settings of the company as a dictionary with the
    "aws", "google" and "azure" sections and the settings last update time.
    The settings document is loaded once and passed to the per provider
    getters so that none of them queries the database again
    """
    db_settings = StorageSettings.objects(company=company_id).first()
    aws = self.get_aws_settings_for_company(company_id, db_settings, query_db=False)
    aws_dict = {
        "key": aws._default_key,
        "secret": aws._default_secret,
        "token": aws._default_token,
        "region": aws._default_region,
        "use_credentials_chain": aws._default_use_credentials_chain,
        "buckets": [attr.asdict(b) for b in aws._buckets],
    }

    # the google credentials are requested as json strings
    gs = self.get_gs_settings_for_company(
        company_id, db_settings, query_db=False, json_string=True
    )
    gs_dict = {
        "project": gs._default_project,
        "credentials_json": gs._default_credentials or None,
        "buckets": [attr.asdict(b) for b in gs._buckets],
    }

    # query_db=False as for the other providers: the settings document
    # was already looked up above
    azure = self.get_azure_settings_for_company(
        company_id, db_settings, query_db=False
    )
    azure_dict = {
        "containers": [attr.asdict(ac) for ac in azure._container_configs],
    }

    return {
        "aws": aws_dict,
        "google": gs_dict,
        "azure": azure_dict,
        "last_update": db_settings.last_update if db_settings else None,
    }
def set_company_settings(
    self, company_id: str, settings: SetSettingsRequest
) -> int:
    """
    Store the provided sections ("aws", "azure", "google") of the settings
    for the company. Only the fields known to the matching db model are kept.
    The settings document is created if the company does not have one yet.
    Returns the result of the update call
    :raise errors.bad_request.ValidationError: if the google credentials are
        not a valid json or if no section was provided
    """

    def known_fields(section, model) -> dict:
        return {
            k: v for k, v in section.to_struct().items() if k in model.get_fields()
        }

    update_dict = {}
    for name, model in (
        ("aws", AWSSettings),
        ("azure", AzureStorageSettings),
        ("google", GoogleStorageSettings),
    ):
        section = getattr(settings, name)
        if section:
            update_dict[name] = known_fields(section, model)

    if "google" in update_dict:
        cred_json = update_dict["google"].get("credentials_json")
        if cred_json:
            try:
                json.loads(cred_json)
            except Exception as ex:
                raise errors.bad_request.ValidationError(
                    f"Invalid json credentials: {str(ex)}"
                )

    if not update_dict:
        raise errors.bad_request.ValidationError("No settings were provided")

    # reuse the id of the existing settings document, if any
    existing = StorageSettings.objects(company=company_id).only("id").first()
    settings_id = existing.id if existing else db_id()
    return StorageSettings.objects(id=settings_id).update(
        upsert=True,
        id=settings_id,
        company=company_id,
        last_update=datetime.utcnow(),
        **update_dict,
    )
def reset_company_settings(self, company_id: str, keys: Sequence[str]) -> int:
    """
    Unset the passed keys in the storage settings of the company and set
    the settings last update time. Returns the result of the update call
    """
    unset_cmds = {f"unset__{k}": 1 for k in keys}
    company_settings = StorageSettings.objects(company=company_id)
    return company_settings.update(last_update=datetime.utcnow(), **unset_cmds)

View File

@@ -1,6 +1,5 @@
from .task_bll import TaskBLL
from .utils import (
ChangeStatusRequest,
update_project_time,
validate_status_change,
)

View File

@@ -5,6 +5,7 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.database.utils import hash_field_name
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
@@ -48,12 +49,14 @@ class Artifacts:
def add_or_update_artifacts(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
artifacts: Sequence[ApiArtifact],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
artifacts = {
get_artifact_id(a): Artifact(**a)
@@ -64,18 +67,20 @@ class Artifacts:
f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items()
}
return update_task(task, user_id=user_id, update_cmds=update_cmds)
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
@classmethod
def delete_artifacts(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
artifact_ids: Sequence[ArtifactId],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
artifact_ids = [
get_artifact_id(a)
@@ -85,4 +90,4 @@ class Artifacts:
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
}
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)

View File

@@ -15,6 +15,7 @@ from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.config_repo import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.service_repo.auth import Identity
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
@@ -31,7 +32,10 @@ class HyperParams:
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
only = ("id", "hyperparams")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
company_id=company_id,
task_ids=task_ids,
only=only,
allow_public=True,
)
return {
@@ -63,7 +67,7 @@ class HyperParams:
def delete_params(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
hyperparams: Sequence[HyperParamKey],
force: bool,
@@ -74,6 +78,7 @@ class HyperParams:
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
identity=identity,
)
with_param, without_param = iterutils.partition(
@@ -96,7 +101,7 @@ class HyperParams:
return update_task(
task,
user_id=user_id,
user_id=identity.user,
update_cmds=delete_cmds,
set_last_update=not properties_only,
)
@@ -105,7 +110,7 @@ class HyperParams:
def edit_params(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str,
@@ -117,6 +122,7 @@ class HyperParams:
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
identity=identity,
)
update_cmds = dict()
@@ -135,7 +141,7 @@ class HyperParams:
return update_task(
task,
user_id=user_id,
user_id=identity.user,
update_cmds=update_cmds,
set_last_update=not properties_only,
)
@@ -163,7 +169,10 @@ class HyperParams:
else:
only.append("configuration")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
company_id=company_id,
task_ids=task_ids,
only=only,
allow_public=True,
)
return {
@@ -184,7 +193,7 @@ class HyperParams:
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"company": {"$in": ["", company_id]},
"_id": {"$in": task_ids},
}
},
@@ -209,13 +218,15 @@ class HyperParams:
def edit_configuration(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
configuration: Sequence[Configuration],
replace_configuration: bool,
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
update_cmds = dict()
configuration = {
@@ -228,22 +239,24 @@ class HyperParams:
for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
return update_task(task, user_id=user_id, update_cmds=update_cmds)
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
@classmethod
def delete_configuration(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
configuration: Sequence[str],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)

View File

@@ -1,7 +1,7 @@
from datetime import timedelta, datetime
from time import sleep
from apiserver.bll.task import update_project_time
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model.task.task import TaskStatus, Task
from apiserver.utilities.threads_manager import ThreadsManager
@@ -85,6 +85,7 @@ class NonResponsiveTasksWatchdog:
status_changed=now,
last_update=now,
last_change=now,
last_changed_by="__apiserver__",
)
if updated:
project_ids.add(task.project)

View File

@@ -7,11 +7,12 @@ from redis import StrictRedis
from six import string_types
import apiserver.database.utils as dbutils
from apiserver.apierrors import errors
from apiserver.apierrors import errors, APIError
from apiserver.apimodels.tasks import TaskInputModel
from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
@@ -30,16 +31,21 @@ from apiserver.database.model.task.task import (
TaskModelTypes,
)
from apiserver.database.model import EntityVisibility
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.database.model.queue import Queue
from apiserver.database.utils import (
get_company_or_none_constraint,
id as create_id,
)
from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from apiserver.utilities.dicts import nested_set
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
ChangeStatusRequest,
update_project_time,
deleted_prefix,
get_last_metric_updates,
)
log = config.logger(__file__)
@@ -53,30 +59,13 @@ class TaskBLL:
self.events_es = events_es or es_factory.connect("events")
self.redis: StrictRedis = redis or redman.connection("apiserver")
@staticmethod
def get_task_with_access(
task_id, company_id, only=None, allow_public=False, requires_write_access=False
) -> Task:
"""
Gets a task that has a required write access
:except errors.bad_request.InvalidTaskId: if the task is not found
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
"""
with translate_errors_context():
query = dict(id=task_id, company=company_id)
if requires_write_access:
task = Task.get_for_writing(_only=only, **query)
else:
task = Task.get(_only=only, **query, include_public=allow_public)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
return task
@staticmethod
def get_by_id(
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
company_id,
task_id,
required_status=None,
only_fields=None,
allow_public=False,
):
if only_fields:
if isinstance(only_fields, string_types):
@@ -175,18 +164,36 @@ class TaskBLL:
input_models: Optional[Sequence[TaskInputModel]] = None,
validate_references: bool = False,
new_project_name: str = None,
hyperparams_overrides: Optional[dict] = None,
configuration_overrides: Optional[dict] = None,
) -> Tuple[Task, dict]:
validate_tags(tags, system_tags)
params_dict = {
field: value
for field, value in (
("hyperparams", hyperparams),
("configuration", configuration),
)
if value is not None
}
task: Task = cls.get_by_id(
company_id=company_id, task_id=task_id, allow_public=True
)
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
params_dict = {}
if hyperparams:
params_dict["hyperparams"] = hyperparams
elif hyperparams_overrides:
updated_hyperparams = {
sec: {k: value for k, value in sec_data.items()}
for sec, sec_data in (task.hyperparams or {}).items()
}
for section, section_data in hyperparams_overrides.items():
for key, value in section_data.items():
nested_set(updated_hyperparams, (section, key), value)
params_dict["hyperparams"] = updated_hyperparams
if configuration:
params_dict["configuration"] = configuration
elif configuration_overrides:
updated_configuration = {
k: value for k, value in (task.configuration or {}).items()
}
for key, value in configuration_overrides.items():
updated_configuration[key] = value
params_dict["configuration"] = updated_configuration
now = datetime.utcnow()
if input_models:
@@ -256,6 +263,16 @@ class TaskBLL:
not in [TaskSystemTags.development, EntityVisibility.archived.value]
]
def ensure_int_labels(execution: dict) -> dict:
    """
    Coerce every value under the "model_labels" key of the execution dict to int

    The dict is modified in place and the same object is returned.
    A falsy execution (None or an empty dict) is returned untouched.
    """
    labels = execution.get("model_labels") if execution else None
    if labels:
        execution["model_labels"] = {
            name: int(label_id) for name, label_id in labels.items()
        }
    return execution
parent_task = (
task.parent
if task.parent and not task.parent.startswith(deleted_prefix)
@@ -280,7 +297,7 @@ class TaskBLL:
output=Output(destination=task.output.destination) if task.output else None,
models=Models(input=input_models or task.models.input),
container=escape_dict(container) or task.container,
execution=execution_dict,
execution=ensure_int_labels(execution_dict),
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)
@@ -301,7 +318,7 @@ class TaskBLL:
org_bll.update_tags(
company_id,
Tags.Task,
project=new_task.project,
projects=[new_task.project],
tags=updated_tags,
system_tags=updated_system_tags,
)
@@ -344,6 +361,7 @@ class TaskBLL:
def set_last_update(
task_ids: Collection[str],
company_id: str,
user_id: str,
last_update: datetime,
**extra_updates,
):
@@ -364,6 +382,7 @@ class TaskBLL:
upsert=False,
last_update=last_update,
last_change=last_update,
last_changed_by=user_id,
**updates,
)
return count
@@ -372,6 +391,7 @@ class TaskBLL:
def update_statistics(
task_id: str,
company_id: str,
user_id: str,
last_update: datetime = None,
last_iteration: int = None,
last_iteration_max: int = None,
@@ -388,7 +408,7 @@ class TaskBLL:
task's last iteration value.
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
if the current task's last iteration value is smaller than the provided value.
:param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant).
:param last_scalar_events: Last reported metrics summary for scalar events (value, metric, variant).
:param last_events: Last reported metrics summary (value, metric, event type).
:param extra_updates: Extra task updates to include in this update call.
:return:
@@ -402,81 +422,12 @@ class TaskBLL:
raw_updates = {}
if last_scalar_events is not None:
max_values = config.get("services.tasks.max_last_metrics", 2000)
total_metrics = set()
if max_values:
query = dict(id=task_id)
to_add = sum(len(v) for m, v in last_scalar_events.items())
if to_add <= max_values:
query[f"unique_metrics__{max_values-to_add}__exists"] = True
task = Task.objects(**query).only("unique_metrics").first()
if task and task.unique_metrics:
total_metrics = set(task.unique_metrics)
new_metrics = []
def add_last_metric_conditional_update(
metric_path: str, metric_value, iter_value: int, is_min: bool
):
"""
Build an aggregation for an atomic update of the min or max value and the corresponding iteration
"""
if is_min:
field_prefix = "min"
op = "$gt"
else:
field_prefix = "max"
op = "$lt"
value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
condition = {
"$or": [
{"$lte": [f"${value_field}", None]},
{op: [f"${value_field}", metric_value]},
]
}
raw_updates[value_field] = {
"$cond": [condition, metric_value, f"${value_field}"]
}
value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
"__", "."
)
raw_updates[value_iteration_field] = {
"$cond": [
condition,
iter_value,
f"${value_iteration_field}",
]
}
for metric_key, metric_data in last_scalar_events.items():
for variant_key, variant_data in metric_data.items():
metric = (
f"{variant_data.get('metric')}/{variant_data.get('variant')}"
)
if max_values:
if (
len(total_metrics) >= max_values
and metric not in total_metrics
):
continue
total_metrics.add(metric)
new_metrics.append(metric)
path = f"last_metrics__{metric_key}__{variant_key}"
for key, value in variant_data.items():
if key in ("min_value", "max_value"):
add_last_metric_conditional_update(
metric_path=path,
metric_value=value,
iter_value=variant_data.get(f"{key}_iter", 0),
is_min=(key == "min_value"),
)
elif key in ("metric", "variant", "value"):
extra_updates[f"set__{path}__{key}"] = value
if new_metrics:
extra_updates["add_to_set__unique_metrics"] = new_metrics
get_last_metric_updates(
task_id=task_id,
last_scalar_events=last_scalar_events,
raw_updates=raw_updates,
extra_updates=extra_updates,
)
if last_events is not None:
@@ -497,6 +448,7 @@ class TaskBLL:
ret = TaskBLL.set_last_update(
task_ids=[task_id],
company_id=company_id,
user_id=user_id,
last_update=last_update,
**extra_updates,
)
@@ -505,6 +457,17 @@ class TaskBLL:
return ret
@staticmethod
def remove_task_from_all_queues(
    company_id: str, task_id: str, exclude: str = None
) -> int:
    """
    Pull the task from the entries of all the company queues that contain it
    :param company_id: company that the queues belong to
    :param task_id: task to remove from the queue entries
    :param exclude: optional id of a queue that should be left untouched
    :return: the result of the queues update call
    """
    queue_filter = dict(company=company_id, entries__task=task_id)
    if exclude:
        # skip the queue that the task should stay in
        queue_filter["id__ne"] = exclude

    return Queue.objects(**queue_filter).update(
        pull__entries__task=task_id, last_update=datetime.utcnow()
    )
@classmethod
def dequeue_and_change_status(
cls,
@@ -513,23 +476,36 @@ class TaskBLL:
user_id: str,
status_message: str,
status_reason: str,
remove_from_all_queues=False,
new_status=None,
new_status_for_aborted_task=None,
):
try:
cls.dequeue(task, company_id)
except errors.bad_request.InvalidQueueOrTaskNotQueued:
cls.dequeue(task, company_id=company_id, user_id=user_id, silent_fail=True)
except APIError:
# dequeue may fail if the queue was deleted
pass
if remove_from_all_queues:
cls.remove_task_from_all_queues(company_id=company_id, task_id=task.id)
if task.status not in [TaskStatus.queued, TaskStatus.in_progress]:
return {"updated": 0}
if new_status_for_aborted_task and task.status == TaskStatus.in_progress:
new_status = new_status_for_aborted_task
return ChangeStatusRequest(
task=task,
new_status=task.enqueue_status or TaskStatus.created,
new_status=new_status or task.enqueue_status or TaskStatus.created,
status_reason=status_reason,
status_message=status_message,
user_id=user_id,
force=True,
).execute(enqueue_status=None)
@classmethod
def dequeue(cls, task: Task, company_id: str, silent_fail=False):
def dequeue(cls, task: Task, company_id: str, user_id: str, silent_fail=False):
"""
Dequeue the task from the queue
:param task: task to dequeue
@@ -556,6 +532,9 @@ class TaskBLL:
return {
"removed": queue_bll.remove_task(
company_id=company_id, queue_id=task.execution.queue, task_id=task.id
company_id=company_id,
user_id=user_id,
queue_id=task.execution.queue,
task_id=task.id,
)
}

View File

@@ -1,10 +1,10 @@
from datetime import datetime
from itertools import chain
from operator import attrgetter
from typing import Sequence, Set, Tuple
from typing import Sequence, Set, Tuple, Union
import attr
from boltons.iterutils import partition, bucketize, first
from boltons.iterutils import partition, bucketize, first, chunked_iter
from furl import furl
from mongoengine import NotUniqueError
from pymongo.errors import DuplicateKeyError
@@ -26,14 +26,13 @@ from apiserver.database.utils import id as db_id
log = config.logger(__file__)
event_bll = EventBLL()
async_events_delete = config.get("services.tasks.async_events_delete", False)
@attr.s(auto_attribs=True)
class TaskUrls:
model_urls: Sequence[str]
event_urls: Sequence[str]
artifact_urls: Sequence[str]
event_urls: Sequence[str] = [] # left here is in order not to break the api
def __add__(self, other: "TaskUrls"):
if not other:
@@ -41,7 +40,6 @@ class TaskUrls:
return TaskUrls(
model_urls=list(set(self.model_urls) | set(other.model_urls)),
event_urls=list(set(self.event_urls) | set(other.event_urls)),
artifact_urls=list(set(self.artifact_urls) | set(other.artifact_urls)),
)
@@ -55,8 +53,23 @@ class CleanupResult:
updated_children: int
updated_models: int
deleted_models: int
deleted_model_ids: Set[str]
urls: TaskUrls = None
def to_res_dict(self, return_file_urls: bool) -> dict:
    """
    Convert the cleanup result to a plain dict
    "deleted_model_ids" is never included in the result.
    If return_file_urls is not set then "urls" is skipped during
    the conversion and reported as None instead
    """
    if return_file_urls:
        excluded = ["deleted_model_ids"]
    else:
        excluded = ["deleted_model_ids", "urls"]

    # noinspection PyTypeChecker
    result = attr.asdict(
        self, filter=lambda attrib, value: attrib.name not in excluded
    )
    if not return_file_urls:
        # keep the "urls" key present in the result even when urls are not returned
        result["urls"] = None

    return result
def __add__(self, other: "CleanupResult"):
if not other:
return self
@@ -66,55 +79,87 @@ class CleanupResult:
updated_models=self.updated_models + other.updated_models,
deleted_models=self.deleted_models + other.deleted_models,
urls=self.urls + other.urls if self.urls else other.urls,
deleted_model_ids=self.deleted_model_ids | other.deleted_model_ids,
)
@staticmethod
def empty():
    """
    Create a CleanupResult with all the counters set to 0
    and an empty set of deleted model ids
    """
    zero_counters = dict.fromkeys(
        ("updated_children", "updated_models", "deleted_models"), 0
    )
    return CleanupResult(deleted_model_ids=set(), **zero_counters)
def collect_plot_image_urls(company: str, task_or_model: str) -> Set[str]:
def collect_plot_image_urls(
company: str, task_or_model: Union[str, Sequence[str]]
) -> Set[str]:
urls = set()
next_scroll_id = None
while True:
events, next_scroll_id = event_bll.get_plot_image_urls(
company_id=company, task_id=task_or_model, scroll_id=next_scroll_id
)
if not events:
break
for event in events:
event_urls = event.get(PlotFields.source_urls)
if event_urls:
urls.update(set(event_urls))
task_ids = task_or_model if isinstance(task_or_model, list) else [task_or_model]
for tasks in chunked_iter(task_ids, 100):
next_scroll_id = None
while True:
events, next_scroll_id = event_bll.get_plot_image_urls(
company_id=company, task_ids=tasks, scroll_id=next_scroll_id
)
if not events:
break
for event in events:
event_urls = event.get(PlotFields.source_urls)
if event_urls:
urls.update(set(event_urls))
return urls
def collect_debug_image_urls(company: str, task_or_model: str) -> Set[str]:
def collect_debug_image_urls(
company: str, task_or_model: Union[str, Sequence[str]]
) -> Set[str]:
"""
Return the set of unique image urls
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
"""
after_key = None
urls = set()
while True:
res, after_key = event_bll.get_debug_image_urls(
company_id=company, task_id=task_or_model, after_key=after_key,
)
urls.update(res)
if not after_key:
break
task_ids = task_or_model if isinstance(task_or_model, list) else [task_or_model]
for tasks in chunked_iter(task_ids, 100):
after_key = None
while True:
res, after_key = event_bll.get_debug_image_urls(
company_id=company,
task_ids=tasks,
after_key=after_key,
)
urls.update(res)
if not after_key:
break
return urls
supported_storage_types = {
"https://": StorageType.fileserver,
"http://": StorageType.fileserver,
"s3://": StorageType.s3,
"azure://": StorageType.azure,
"gs://": StorageType.gs,
}
supported_storage_types.update(
{
p: StorageType.fileserver
for p in config.get(
"services.async_urls_delete.fileserver.url_prefixes",
["https://", "http://"],
)
}
)
def _schedule_for_delete(
company: str, user: str, task_id: str, urls: Set[str], can_delete_folders: bool,
def schedule_for_delete(
company: str,
user: str,
task_id: str,
urls: Set[str],
can_delete_folders: bool,
) -> Set[str]:
urls_per_storage = bucketize(
urls,
@@ -176,15 +221,27 @@ def _schedule_for_delete(
return processed_urls
def delete_task_events_and_collect_urls(
    company: str, task_ids: Sequence[str], wait_for_delete: bool, model=False
) -> Set[str]:
    """
    Collect the debug image and plot image urls of the passed tasks (or models)
    and then delete their events
    :return: the union of the collected debug image and plot image urls
    """
    # the urls are collected before the events deletion is requested
    debug_image_urls = collect_debug_image_urls(company, task_ids)
    plot_image_urls = collect_plot_image_urls(company, task_ids)

    event_bll.delete_task_events(
        company, task_ids, model=model, wait_for_delete=wait_for_delete
    )

    return debug_image_urls | plot_image_urls
def cleanup_task(
company: str,
user: str,
task: Task,
force: bool = False,
update_children=True,
return_file_urls=False,
delete_output_models=True,
delete_external_artifacts=True,
) -> CleanupResult:
"""
Validate task deletion and delete/modify all its output.
@@ -195,88 +252,69 @@ def cleanup_task(
published_models, draft_models, in_use_model_ids = verify_task_children_and_ouptuts(
task, force
)
delete_external_artifacts = delete_external_artifacts and config.get(
"services.async_urls_delete.enabled", False
)
event_urls, artifact_urls, model_urls = set(), set(), set()
if return_file_urls or delete_external_artifacts:
event_urls = collect_debug_image_urls(task.company, task.id)
event_urls.update(collect_plot_image_urls(task.company, task.id))
if task.execution and task.execution.artifacts:
artifact_urls = {
a.uri
for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri
}
model_urls = {
m.uri for m in draft_models if m.uri and m.id not in in_use_model_ids
artifact_urls = (
{
a.uri
for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri
}
if task.execution and task.execution.artifacts
else {}
)
model_urls = {m.uri for m in draft_models if m.uri and m.id not in in_use_model_ids}
deleted_task_id = f"{deleted_prefix}{task.id}"
updated_children = 0
now = datetime.utcnow()
if update_children:
updated_children = Task.objects(parent=task.id).update(parent=deleted_task_id)
updated_children = Task.objects(parent=task.id).update(
parent=deleted_task_id,
last_change=now,
last_changed_by=user,
)
deleted_models = 0
updated_models = 0
deleted_model_ids = set()
for models, allow_delete in ((draft_models, True), (published_models, False)):
if not models:
continue
if delete_output_models and allow_delete:
model_ids = set(m.id for m in models if m.id not in in_use_model_ids)
for m_id in model_ids:
if return_file_urls or delete_external_artifacts:
event_urls.update(collect_debug_image_urls(task.company, m_id))
event_urls.update(collect_plot_image_urls(task.company, m_id))
try:
event_bll.delete_task_events(
task.company,
m_id,
allow_locked=True,
model=True,
async_delete=async_events_delete,
)
except errors.bad_request.InvalidModelId as ex:
log.info(f"Error deleting events for the model {m_id}: {str(ex)}")
model_ids = list({m.id for m in models if m.id not in in_use_model_ids})
if model_ids:
deleted_models += Model.objects(id__in=model_ids).delete()
deleted_model_ids.update(model_ids)
deleted_models += Model.objects(id__in=list(model_ids)).delete()
if in_use_model_ids:
Model.objects(id__in=list(in_use_model_ids)).update(unset__task=1)
Model.objects(id__in=list(in_use_model_ids)).update(
unset__task=1,
set__last_change=now,
set__last_changed_by=user,
)
continue
if update_children:
updated_models += Model.objects(id__in=[m.id for m in models]).update(
task=deleted_task_id
task=deleted_task_id,
last_change=now,
last_changed_by=user,
)
else:
Model.objects(id__in=[m.id for m in models]).update(unset__task=1)
event_bll.delete_task_events(
task.company, task.id, allow_locked=force, async_delete=async_events_delete
)
if delete_external_artifacts:
scheduled = _schedule_for_delete(
task_id=task.id,
company=company,
user=user,
urls=event_urls | model_urls | artifact_urls,
can_delete_folders=not in_use_model_ids and not published_models,
)
for urls in (event_urls, model_urls, artifact_urls):
urls.difference_update(scheduled)
Model.objects(id__in=[m.id for m in models]).update(
unset__task=1,
set__last_change=now,
set__last_changed_by=user,
)
return CleanupResult(
deleted_models=deleted_models,
updated_children=updated_children,
updated_models=updated_models,
urls=TaskUrls(
event_urls=list(event_urls),
artifact_urls=list(artifact_urls),
model_urls=list(model_urls),
)
if return_file_urls
else None,
),
deleted_model_ids=deleted_model_ids,
)
@@ -296,7 +334,8 @@ def verify_task_children_and_ouptuts(
model_fields = ["id", "ready", "uri"]
published_models, draft_models = partition(
Model.objects(task=task.id).only(*model_fields), key=attrgetter("ready"),
Model.objects(task=task.id).only(*model_fields),
key=attrgetter("ready"),
)
if not force and published_models:
raise errors.bad_request.TaskCannotBeDeleted(

View File

@@ -7,9 +7,10 @@ from apiserver.bll.task import (
TaskBLL,
validate_status_change,
ChangeStatusRequest,
update_project_time,
)
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
@@ -22,88 +23,170 @@ from apiserver.database.model.task.task import (
ArtifactModes,
Execution,
DEFAULT_LAST_ITERATION,
TaskType,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_set
log = config.logger(__file__)
queue_bll = QueueBLL()
def _get_pipeline_steps_for_controller_task(
    task: Task, company_id: str, only: Sequence[str] = None
) -> Sequence[Task]:
    """
    Return the company tasks whose parent is the passed controller task
    If the task is missing or its type is not controller then
    an empty list is returned
    :param only: optional field names to limit the loaded task fields to
    """
    if task and task.type == TaskType.controller:
        steps = Task.objects(company=company_id, parent=task.id)
        return list(steps.only(*only) if only else steps)

    return []
def archive_task(
task: Union[str, Task],
company_id: str,
user_id: str,
identity: Identity,
status_message: str,
status_reason: str,
include_pipeline_steps: bool,
) -> int:
"""
Dequeue and archive task
Return 1 if successful
"""
user_id = identity.user
fields = (
"id",
"company",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
"type",
)
if isinstance(task, str):
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
task,
company_id=company_id,
only=(
"id",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
),
requires_write_access=True,
identity=identity,
only=fields,
)
try:
TaskBLL.dequeue_and_change_status(
task,
company_id=company_id,
user_id=user_id,
def archive_task_core(task_: Task) -> int:
try:
TaskBLL.dequeue_and_change_status(
task_,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
new_status_for_aborted_task=TaskStatus.stopped,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
return task_.update(
status_message=status_message,
status_reason=status_reason,
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
return task.update(
status_message=status_message,
status_reason=status_reason,
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(
task, company_id, only=fields
)
):
for step in step_tasks:
archive_task_core(step)
return archive_task_core(task)
def unarchive_task(
task: str, company_id: str, user_id: str, status_message: str, status_reason: str,
task_id: str,
company_id: str,
identity: Identity,
status_message: str,
status_reason: str,
include_pipeline_steps: bool,
) -> int:
"""
Unarchive task. Return 1 if successful
"""
task = TaskBLL.get_task_with_access(
task, company_id=company_id, only=("id",), requires_write_access=True,
)
return task.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
fields = ("id", "type")
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=fields,
)
def unarchive_task_core(task_: Task) -> int:
return task_.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=identity.user,
)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(
task, company_id, only=fields
)
):
for step in step_tasks:
unarchive_task_core(step)
return unarchive_task_core(task)
def dequeue_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
status_message: str,
status_reason: str,
remove_from_all_queues: bool = False,
new_status=None,
) -> Tuple[int, dict]:
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
if new_status and new_status not in get_options(TaskStatus):
raise errors.bad_request.ValidationError(f"Invalid task status: {new_status}")
# get the task without write access to make sure that it actually exists
task = Task.get(
id=task_id,
company=company_id,
_only=("id",),
include_public=True,
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
return 1, {"updated": 0}
user_id = identity.user
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=(
"id",
"company",
"execution",
"status",
"project",
"enqueue_status",
),
)
res = TaskBLL.dequeue_and_change_status(
task,
@@ -111,6 +194,8 @@ def dequeue_task(
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=remove_from_all_queues,
new_status=new_status,
)
return 1, res
@@ -118,19 +203,32 @@ def dequeue_task(
def enqueue_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
queue_id: str,
status_message: str,
status_reason: str,
queue_name: str = None,
validate: bool = False,
force: bool = False,
update_execution_queue: bool = True,
) -> Tuple[int, dict]:
if queue_id and queue_name:
raise errors.bad_request.ValidationError(
"Either queue id or queue name should be provided"
)
task = get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=identity
)
if not update_execution_queue:
if not (
task.status == TaskStatus.queued and task.execution and task.execution.queue
):
raise errors.bad_request.ValidationError(
"Cannot skip setting execution queue for a task "
"that is not enqueued or does not have execution queue set"
)
if queue_name:
queue = queue_bll.get_by_name(
company_id=company_id, queue_name=queue_name, only=("id",)
@@ -143,23 +241,21 @@ def enqueue_task(
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
user_id = identity.user
if validate:
TaskBLL.validate(task)
before_enqueue_status = task.status
if task.status == TaskStatus.queued and task.enqueue_status:
before_enqueue_status = task.enqueue_status
res = ChangeStatusRequest(
task=task,
new_status=TaskStatus.queued,
status_reason=status_reason,
status_message=status_message,
allow_same_state_transition=False,
force=force,
user_id=user_id,
).execute(enqueue_status=task.status)
).execute(enqueue_status=before_enqueue_status)
try:
queue_bll.add_task(company_id=company_id, queue_id=queue_id, task_id=task.id)
@@ -176,12 +272,19 @@ def enqueue_task(
raise
# set the current queue ID in the task
if task.execution:
Task.objects(**query).update(execution__queue=queue_id, multi=False)
else:
Task.objects(**query).update(execution=Execution(queue=queue_id), multi=False)
if update_execution_queue:
if task.execution:
Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
else:
Task.objects(id=task_id).update(
execution=Execution(queue=queue_id), multi=False
)
nested_set(res, ("fields", "execution.queue"), queue_id)
nested_set(res, ("fields", "execution.queue"), queue_id)
# make sure that the task is not queued in any other queue
TaskBLL.remove_task_from_all_queues(
company_id=company_id, task_id=task_id, exclude=queue_id
)
return 1, res
@@ -212,18 +315,16 @@ def move_tasks_to_trash(tasks: Sequence[str]) -> int:
def delete_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
move_to_trash: bool,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
status_message: str,
status_reason: str,
delete_external_artifacts: bool,
include_pipeline_steps: bool,
) -> Tuple[int, Task, CleanupResult]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
user_id = identity.user
task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if (
task.status != TaskStatus.created
@@ -237,34 +338,50 @@ def delete_task(
current=task.status,
)
try:
TaskBLL.dequeue_and_change_status(
task,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
def delete_task_core(task_: Task, force_: bool) -> CleanupResult:
try:
TaskBLL.dequeue_and_change_status(
task_,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
res = cleanup_task(
company=company_id,
user=user_id,
task=task_,
force=force_,
delete_output_models=delete_output_models,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
cleanup_res = cleanup_task(
company=company_id,
user=user_id,
task=task,
force=force,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
)
if move_to_trash:
# make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation
task_.last_update = datetime.utcnow()
task_.save()
else:
task_.delete()
return res
task_ids = [task.id]
cleanup_res = CleanupResult.empty()
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id)
):
for step in step_tasks:
cleanup_res += delete_task_core(step, True)
task_ids.append(step.id)
cleanup_res = delete_task_core(task, force)
if move_to_trash:
# make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation
task.save()
else:
task.delete()
move_tasks_to_trash(task_ids)
update_project_time(task.project)
return 1, task, cleanup_res
@@ -273,16 +390,13 @@ def delete_task(
def reset_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
clear_all: bool,
delete_external_artifacts: bool,
) -> Tuple[dict, CleanupResult, dict]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
user_id = identity.user
task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if not force and task.status == TaskStatus.published:
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
@@ -291,20 +405,22 @@ def reset_task(
updates = {}
try:
dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
dequeued = TaskBLL.dequeue(
task, company_id=company_id, user_id=user_id, silent_fail=True
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
TaskBLL.remove_task_from_all_queues(company_id=company_id, task_id=task.id)
cleaned_up = cleanup_task(
company=company_id,
user=user_id,
task=task,
force=force,
update_children=False,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
)
updates.update(
@@ -318,11 +434,17 @@ def reset_task(
unset__output__error=1,
unset__last_worker=1,
unset__last_worker_report=1,
unset__started=1,
unset__completed=1,
unset__published=1,
unset__active_duration=1,
unset__enqueue_status=1,
)
if clear_all:
updates.update(
set__execution=Execution(), unset__script=1,
set__execution=Execution(),
unset__script=1,
)
else:
updates.update(unset__execution__queue=1)
@@ -343,11 +465,6 @@ def reset_task(
status_message="reset",
user_id=user_id,
).execute(
started=None,
completed=None,
published=None,
active_duration=None,
enqueue_status=None,
**updates,
)
@@ -357,15 +474,14 @@ def reset_task(
def publish_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
force: bool,
publish_model_func: Callable[[str, str, str], Any] = None,
publish_model_func: Callable[[str, str, Identity], Any] = None,
status_message: str = "",
status_reason: str = "",
) -> dict:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
user_id = identity.user
task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if not force:
validate_status_change(task.status, TaskStatus.published)
@@ -387,7 +503,7 @@ def publish_task(
.first()
)
if model and not model.ready:
publish_model_func(model.id, company_id, user_id)
publish_model_func(model.id, company_id, identity)
# set task status to published, and update (or set) it's new output (view and models)
return ChangeStatusRequest(
@@ -411,10 +527,11 @@ def publish_task(
def stop_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
user_name: str,
status_reason: str,
force: bool,
include_pipeline_steps: bool,
) -> dict:
"""
Stop a running task. Requires task status 'in_progress' and
@@ -424,20 +541,22 @@ def stop_task(
is set to 'stopping' to allow the worker to stop the task and report by itself
:return: updated task fields
"""
task = TaskBLL.get_task_with_access(
user_id = identity.user
fields = (
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
"execution.queue",
"type",
)
task = get_task_with_write_access(
task_id,
company_id=company_id,
only=(
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
"execution.queue",
),
requires_write_access=True,
identity=identity,
only=fields,
)
def is_run_by_worker(t: Task) -> bool:
@@ -449,32 +568,45 @@ def stop_task(
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
)
is_queued = task.status == TaskStatus.queued
set_stopped = (
is_queued
or TaskSystemTags.development in task.system_tags
or not is_run_by_worker(task)
)
def stop_task_core(task_: Task, force_: bool):
is_queued = task_.status == TaskStatus.queued
set_stopped = (
is_queued
or TaskSystemTags.development in task_.system_tags
or not is_run_by_worker(task_)
)
if set_stopped:
if is_queued:
try:
TaskBLL.dequeue(task, company_id=company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
if set_stopped:
if is_queued:
try:
TaskBLL.dequeue(
task_, company_id=company_id, user_id=user_id, silent_fail=True
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task.status
status_message = TaskStatusMessage.stopping
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task_.status
status_message = TaskStatusMessage.stopping
return ChangeStatusRequest(
task=task,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force,
user_id=user_id,
).execute()
return ChangeStatusRequest(
task=task_,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force_,
user_id=user_id,
).execute()
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(
task, company_id, only=fields
)
):
for step in step_tasks:
stop_task_core(step, True)
return stop_task_core(task, force)

View File

@@ -1,14 +1,19 @@
from datetime import datetime
from typing import Sequence, Union
from typing import Sequence
import attr
import six
from mongoengine import Q
from mongoengine.base import UPDATE_OPERATORS
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.project import Project
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities.attrs import typed_attrs
valid_statuses = get_options(TaskStatus)
@@ -74,8 +79,16 @@ class ChangeStatusRequest(object):
update_project_time(project_id)
# make sure that _raw_ queries are not returned back to the client
fields.pop("__raw__", None)
def is_mongo_operator(field: str) -> bool:
    """
    Tell whether a field key is a mongoengine update instruction, i.e. it has
    the form "<operator>__<field path>" with <operator> being one of the
    mongoengine UPDATE_OPERATORS
    :param field: the key to inspect
    :return: True for update instructions, False otherwise
    """
    operator, _, field_path = field.partition("__")
    # coerce to a real bool: "tail and (...)" returned an empty string for
    # keys that have no "__" separator, contradicting the declared return type
    return bool(field_path) and operator in UPDATE_OPERATORS
# make sure to not return _raw_ queries or any of the update operators
fields = {
key: value
for key, value in fields.items()
if not (key == "__raw__" or is_mongo_operator(key))
}
return dict(updated=updated, fields=fields)
@@ -131,7 +144,12 @@ state_machine = {
TaskStatus.publishing,
TaskStatus.stopped,
},
TaskStatus.failed: {TaskStatus.created, TaskStatus.stopped, TaskStatus.published},
TaskStatus.failed: {
TaskStatus.created,
TaskStatus.stopped,
TaskStatus.published,
TaskStatus.queued,
},
TaskStatus.publishing: {TaskStatus.published},
TaskStatus.published: set(),
TaskStatus.completed: {
@@ -156,25 +174,78 @@ def get_possible_status_changes(current_status):
return possible
def update_project_time(project_ids: Union[str, Sequence[str]]):
if not project_ids:
return
def get_many_tasks_for_writing(
company_id: str,
identity: Identity,
query: Q = None,
only: Sequence = None,
throw_on_forbidden: bool = True,
) -> Sequence[Task]:
if only:
missing = [f for f in ("company",) if f not in only]
if missing:
only = [*only, *missing]
if isinstance(project_ids, str):
project_ids = [project_ids]
result = list(
Task.get_many(
company=company_id,
query=query,
override_projection=only,
allow_public=True,
return_dicts=False,
)
)
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
if not company_id:
return result
forbidden_tasks = {task.id for task in result if not task.company}
if forbidden_tasks:
if throw_on_forbidden:
raise errors.forbidden.NoWritePermission(
f"cannot modify public task(s), ids={tuple(forbidden_tasks)}"
)
result = [task for task in result if task.id not in forbidden_tasks]
return result
def get_task_with_write_access(
    task_id: str,
    company_id: str,
    identity: Identity,
    only=None,
) -> Task:
    """
    Load a single company task through Task.get_for_writing
    :param task_id: id of the task to load
    :param company_id: id of the company that the task is looked up in
    :param identity: identity of the caller (not consulted by this function itself)
    :param only: optional projection, passed to Task.get_for_writing as _only
    :return: the loaded task
    :except errors.bad_request.InvalidTaskId: if no task was returned
    NOTE(review): errors.forbidden.NoWritePermission is presumably raised
    from inside Task.get_for_writing and not from here - confirm
    """
    task = Task.get_for_writing(_only=only, id=task_id, company=company_id)
    if task:
        return task

    raise errors.bad_request.InvalidTaskId(id=task_id, company=company_id)
def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
company_id: str,
task_id: str,
identity: Identity,
allow_all_statuses: bool = False,
force: bool = False,
) -> Task:
"""
Loads only task id and return the task only if it is updatable (status == 'created')
"""
task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
only=("id", "status"),
identity=identity,
)
if allow_all_statuses:
return task
@@ -189,9 +260,152 @@ def get_task_for_update(
return task
def update_task(task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True):
def update_task(
    task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True
):
    """
    Apply the update commands to the task together with the change tracking fields
    :param task: the task document to update
    :param user_id: id of the user, stored in last_changed_by
    :param update_cmds: update instructions that are passed to task.update as is
    :param set_last_update: if set then last_update is stamped too,
        otherwise only last_change and last_changed_by are written
    :return: whatever task.update returns
    """
    timestamp = datetime.utcnow()
    tracking_fields = {"last_change": timestamp, "last_changed_by": user_id}
    if set_last_update:
        # the same timestamp is used for both last_change and last_update
        tracking_fields["last_update"] = timestamp

    return task.update(**update_cmds, **tracking_fields)
def get_last_metric_updates(
    task_id: str,
    last_scalar_events: dict,
    raw_updates: dict,
    extra_updates: dict,
    model_events: bool = False,
):
    """
    Translate the last reported scalar events into update instructions.
    Nothing is returned, the passed dictionaries are filled in place:
    aggregation expressions (min/max/first values and the running mean)
    are written to raw_updates, plain field sets and the list of
    the newly seen metrics are written to extra_updates
    :param task_id: id of the document whose unique_metrics are looked up
    :param last_scalar_events: metric key -> variant key -> last event data
    :param raw_updates: receives dotted field path -> aggregation expression
    :param extra_updates: receives mongoengine style update instructions
    :param model_events: look up unique_metrics on Model instead of Task
    """
    max_metrics = config.get("services.tasks.max_last_metrics", 2000)
    known_metrics = set()
    if max_metrics:
        lookup = {"id": task_id}
        incoming = sum(len(variants) for variants in last_scalar_events.values())
        if incoming <= max_metrics:
            # match the document only if its unique_metrics array has an element
            # at this index; presumably because otherwise the limit cannot be
            # exceeded by this update and the stored list is not needed - verify
            lookup[f"unique_metrics__{max_metrics - incoming}__exists"] = True
        document_cls = Model if model_events else Task
        document = document_cls.objects(**lookup).only("unique_metrics").first()
        if document and document.unique_metrics:
            known_metrics = set(document.unique_metrics)

    added_metrics = []

    def or_zero(field: str) -> dict:
        """Expression reading the field value with 0 as a replacement for null"""
        return {"$ifNull": [f"${field}", 0]}

    def add_last_metric_mean_update(
        metric_path: str,
        metric_count: int,
        metric_total: float,
    ):
        """
        Register the expressions recalculating the mean from the value in db
        and the new data. The count is recalculated by an expression as well
        (and not by inc__) so that it does not get updated in the db
        earlier than the corresponding mean
        """
        dotted_path = metric_path.replace("__", ".")
        mean_value_field = f"{dotted_path}.mean_value"
        count_field = f"{dotted_path}.count"

        stored_total = {
            "$multiply": [or_zero(mean_value_field), or_zero(count_field)]
        }
        updated_total = {"$add": [stored_total, metric_total]}
        updated_count = {"$add": [or_zero(count_field), metric_count]}
        raw_updates[mean_value_field] = {
            "$round": [{"$divide": [updated_total, updated_count]}, 2]
        }
        raw_updates[count_field] = {"$add": [or_zero(count_field), metric_count]}

    def add_last_metric_conditional_update(
        metric_path: str, metric_value, iter_value: int, is_min: bool, is_first: bool
    ):
        """
        Register the expressions for an atomic update of the first, min or max
        value together with the iteration at which this value was reported
        """
        # is_first takes precedence over is_min
        if is_first:
            field_prefix, comparison = "first", None
        elif is_min:
            field_prefix, comparison = "min", "$gt"
        else:
            field_prefix, comparison = "max", "$lt"

        value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
        value_iteration_field = (
            f"{metric_path}__{field_prefix}_value_iteration".replace("__", ".")
        )

        not_stored_yet = {"$lte": [f"${value_field}", None]}
        if comparison:
            should_replace = {
                "$or": [
                    not_stored_yet,
                    {comparison: [f"${value_field}", metric_value]},
                ]
            }
        else:
            should_replace = not_stored_yet

        raw_updates[value_field] = {
            "$cond": [should_replace, metric_value, f"${value_field}"]
        }
        raw_updates[value_iteration_field] = {
            "$cond": [should_replace, iter_value, f"${value_iteration_field}"]
        }

    conditional_keys = ("min_value", "max_value", "first_value")
    plain_keys = ("metric", "variant", "value", "x_axis_label")
    for metric_key, variants in last_scalar_events.items():
        for variant_key, event in variants.items():
            metric = f"{event.get('metric')}/{event.get('variant')}"
            if max_metrics:
                # over the limit only the already known metrics keep being updated
                if metric not in known_metrics and len(known_metrics) >= max_metrics:
                    continue
                known_metrics.add(metric)

            added_metrics.append(metric)
            path = f"last_metrics__{metric_key}__{variant_key}"
            for key, value in event.items():
                if key in conditional_keys:
                    add_last_metric_conditional_update(
                        metric_path=path,
                        metric_value=value,
                        iter_value=event.get(f"{key}_iter", 0),
                        is_min=(key == "min_value"),
                        is_first=(key == "first_value"),
                    )
                elif key in plain_keys:
                    extra_updates[f"set__{path}__{key}"] = value

            count, total = event.get("count"), event.get("total")
            if not (count is None or total is None):
                add_last_metric_mean_update(
                    metric_path=path,
                    metric_count=count,
                    metric_total=total,
                )

    if added_metrics:
        extra_updates["add_to_set__unique_metrics"] = added_metrics

View File

@@ -2,6 +2,7 @@ from datetime import datetime
from apiserver.apierrors import errors
from apiserver.apimodels.users import CreateRequest
from apiserver.config.info import get_version
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.user import User
@@ -14,7 +15,11 @@ class UserBLL:
if user_id and User.objects(id=user_id).only("id"):
raise errors.bad_request.UserIdExists(id=user_id)
user = User(**request.to_struct(), created=datetime.utcnow())
user = User(
**request.to_struct(),
created=datetime.utcnow(),
created_in_version=get_version(),
)
user.save(force_insert=True)
@staticmethod

View File

@@ -1,76 +1,24 @@
import functools
import itertools
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from typing import (
Optional,
Callable,
Dict,
Any,
Set,
Iterable,
Tuple,
Sequence,
TypeVar,
Union,
)
from boltons import iterutils
from apiserver.apierrors import APIError
from apiserver.database.model import AttributedDocument
from apiserver.database.model.project import Project
from apiserver.database.model.settings import Settings
class SetFieldsResolver:
    """
    Keeps a dictionary of set fields instructions and, for the instructions
    that carry a 'min' or 'max' modifier, replaces them with a simple set
    of the field whenever the DB document has no value in that field
    """

    SET_MODIFIERS = ("min", "max")

    def __init__(self, set_fields: Dict[str, Any]):
        # all the instructions, exactly as they were passed
        self.orig_fields = {}
        # "<modifier>__<field>" instruction -> "<field>", min/max instructions only
        self.fields = {}
        self.add_fields(**set_fields)

    def add_fields(self, **set_fields: Any):
        """Register more instructions, overriding the existing ones with the same key"""
        self.orig_fields.update(set_fields)
        for instruction in set_fields:
            modifier, separator, field_name = instruction.partition("__")
            if separator and modifier in self.SET_MODIFIERS:
                self.fields[instruction] = field_name

    def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
        """
        Return the instruction name stripped from its min/max modifier
        if the document has no value in the corresponding field.
        Otherwise the name is returned as is
        """
        bare_name = self.fields.get(name)
        if bare_name is not None and doc.get_field_value(bare_name) is None:
            return bare_name
        return name

    def get_fields(self, doc: AttributedDocument):
        """
        Build the set fields instructions for the given document:
        min/max operations are replaced with a single set
        in case the document does not have the field set
        """
        resolved = {}
        for name, value in self.orig_fields.items():
            resolved[self._get_updated_name(doc, name)] = value
        return resolved

    def get_names(self) -> Set[str]:
        """
        Returns the names of the fields that had min/max modifiers
        in the format suitable for projection (dot separated)
        """
        return {field_name.replace("__", ".") for field_name in self.fields.values()}
@functools.lru_cache()
def get_server_uuid() -> Optional[str]:
    """
    Return the value stored under the "server.uuid" settings key.
    The call is wrapped with lru_cache so repeated calls reuse the first result
    """
    server_uuid = Settings.get_by_key("server.uuid")
    return server_uuid
@@ -132,3 +80,13 @@ def run_batch_operation(
}
)
return results, failures
def update_project_time(project_ids: Union[str, Sequence[str]]):
    """
    Stamp last_update of the passed projects with the current utc time
    :param project_ids: a single project id or a sequence of project ids.
        If empty then nothing is updated
    :return: the result of the Project update call, None if there was nothing to update
    """
    if not project_ids:
        return None

    ids = [project_ids] if isinstance(project_ids, str) else project_ids
    return Project.objects(id__in=ids).update(last_update=datetime.utcnow())

View File

@@ -1,17 +1,18 @@
import itertools
import re
from datetime import datetime, timedelta
from time import time
from typing import Sequence, Set, Optional
import attr
import elasticsearch.helpers
from boltons.iterutils import partition
from boltons.iterutils import partition, chunked_iter
from pyhocon import ConfigTree
from apiserver.es_factory import es_factory
from apiserver.apierrors import APIError
from apiserver.apierrors.errors import bad_request, server_error
from apiserver.apimodels.workers import (
DEFAULT_TIMEOUT,
IdNameEntry,
WorkerEntry,
StatusReportRequest,
@@ -27,15 +28,18 @@ from apiserver.database.model.project import Project
from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
from .stats import WorkerStats
log = config.logger(__file__)
class WorkerBLL:
_key_regex_trans = str.maketrans({"*": ".*", "?": ".?"})
def __init__(self, es=None, redis=None):
self.es_client = es or es_factory.connect("workers")
self.config = config.get("services.workers", ConfigTree())
self.redis = redis or redman.connection("workers")
self._stats = WorkerStats(self.es_client)
@@ -68,7 +72,7 @@ class WorkerBLL:
"""
key = WorkerBLL._get_worker_key(company_id, user_id, worker)
timeout = timeout or DEFAULT_TIMEOUT
timeout = timeout or int(self.config.get("default_worker_timeout_sec", 10 * 60))
queues = queues or []
with translate_errors_context():
@@ -141,8 +145,6 @@ class WorkerBLL:
try:
entry.ip = ip
now = datetime.utcnow()
entry.last_activity_time = now
if tags is not None:
entry.tags = tags
@@ -150,15 +152,16 @@ class WorkerBLL:
entry.system_tags = system_tags
if report.machine_stats:
self._log_stats_to_es(
self.log_stats_to_es(
company_id=company_id,
company_name=entry.company.name,
worker=entry.key,
worker_id=report.worker,
timestamp=report.timestamp,
task=report.task,
machine_stats=report.machine_stats,
)
now = datetime.utcnow()
entry.last_activity_time = now
entry.queue = report.queue
if report.queues:
@@ -175,6 +178,7 @@ class WorkerBLL:
last_worker_report=now,
last_update=now,
last_change=now,
last_changed_by=user_id,
)
# modify(new=True, ...) returns the modified object
task = Task.objects(**query).modify(new=True, **update)
@@ -200,12 +204,41 @@ class WorkerBLL:
finally:
self._save_worker(entry)
def get_count(
    self,
    company_id: str,
    last_seen: Optional[int] = None,
    tags: Sequence[str] = None,
    system_tags: Sequence[str] = None,
    worker_pattern: str = None,
):
    """
    Count the company workers matching the passed filters
    :param company_id: id of the company
    :param last_seen: if set then the count is taken from get_all called with
        this value, otherwise only the matching worker keys are counted
    :param tags: user tags filter
    :param system_tags: system tags filter
    :param worker_pattern: worker id pattern filter
    :return: the amount of the matching workers
    """
    if last_seen:
        matched = self.get_all(
            company_id,
            last_seen=last_seen,
            tags=tags,
            system_tags=system_tags,
            worker_pattern=worker_pattern,
        )
    else:
        # no activity filter: the keys alone are enough for counting
        matched = self._get_keys(
            company_id,
            user_tags=tags,
            system_tags=system_tags,
            worker_pattern=worker_pattern,
        )

    return len(matched)
def get_all(
self,
company_id: str,
last_seen: Optional[int] = None,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerEntry]:
"""
Get all the company workers that were active during the last_seen period
@@ -214,7 +247,12 @@ class WorkerBLL:
:return:
"""
try:
workers = self._get(company_id, user_tags=tags, system_tags=system_tags)
workers = self._get(
company_id,
user_tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
except Exception as e:
raise server_error.DataError("failed loading worker entries", err=e.args[0])
@@ -234,19 +272,18 @@ class WorkerBLL:
last_seen: int,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerResponseEntry]:
helpers = list(
map(
WorkerConversionHelper.from_worker_entry,
self.get_all(
company_id=company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
),
helpers = [
WorkerConversionHelper.from_worker_entry(entry)
for entry in self.get_all(
company_id=company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
)
]
task_ids = set(filter(None, (helper.task_id for helper in helpers)))
all_queues = set(
@@ -260,19 +297,18 @@ class WorkerBLL:
{
"$project": {
"name": 1,
"display_name": 1,
"next_entry": {"$arrayElemAt": ["$entries", 0]},
"num_entries": {"$size": "$entries"},
}
},
]
queues_info = {
res["_id"]: res for res in Queue.objects.aggregate(projection)
}
queues_info = {res["_id"]: res for res in Queue.aggregate(projection)}
task_ids = task_ids.union(
filter(
None,
(
safe_get(info, "next_entry/task")
nested_get(info, ("next_entry", "task"))
for info in queues_info.values()
),
)
@@ -295,8 +331,9 @@ class WorkerBLL:
if not info:
continue
entry.name = info.get("name", None)
entry.display_name = info.get("display_name", None)
entry.num_tasks = info.get("num_entries", 0)
task_id = safe_get(info, "next_entry/task")
task_id = nested_get(info, ("next_entry", "task"))
if task_id:
task = tasks_info.get(task_id, None)
entry.next_task = IdNameEntry(
@@ -306,7 +343,7 @@ class WorkerBLL:
for helper in helpers:
worker = helper.worker
if helper.task_id:
task = tasks_info.get(helper.task_id, None)
task: Task = tasks_info.get(helper.task_id, None)
if task:
worker.task.running_time = (task.active_duration or 0) * 1000
worker.task.last_iteration = task.last_iteration
@@ -396,83 +433,110 @@ class WorkerBLL:
msg = "Failed saving worker entry"
log.exception(msg)
def _get_keys(
self,
company: str,
user: str = "*",
user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[bytes]:
if not (user_tags or system_tags):
match = self._get_worker_key(company, user, worker_pattern or "*")
return list(self.redis.scan_iter(match))
def filter_by_user_and_pattern(in_keys: Set[bytes]) -> Set[bytes]:
if user != "*":
user_bytes = user.encode()
in_keys = {k for k in in_keys if user_bytes in k}
if worker_pattern:
worker_pattern_bytes = (
f"{worker_pattern.translate(self._key_regex_trans)}$".encode()
)
regex = re.compile(worker_pattern_bytes)
in_keys = {k for k in in_keys if regex.search(k)}
return in_keys
worker_keys = set()
for tags, tags_field in (
(user_tags, "tags"),
(system_tags, "systemtags"),
):
if not tags:
continue
timestamp = int(time())
include, exclude = partition(tags, key=lambda x: x[0] != "-")
if include:
tagged_workers = set()
for tag in include:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
tagged_workers.update(self.redis.zrange(tagged_workers_key, 0, -1))
tagged_workers = filter_by_user_and_pattern(tagged_workers)
worker_keys = (
worker_keys.intersection(tagged_workers)
if worker_keys
else tagged_workers
)
if not worker_keys:
return []
if exclude:
if not worker_keys:
all_workers_key = self._get_all_workers_key(company)
self.redis.zremrangebyscore(all_workers_key, min=0, max=timestamp)
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
worker_keys = filter_by_user_and_pattern(worker_keys)
if not worker_keys:
return []
for tag in exclude:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag[1:]
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
worker_keys.difference_update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
if not worker_keys:
return []
return list(worker_keys)
def _get(
self,
company: str,
user: str = "*",
worker_id: str = "*",
user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerEntry]:
"""Get worker entries matching the company and user, worker patterns"""
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
if user == "*":
return in_keys
user_bytes = user.encode()
return {k for k in in_keys if user_bytes in k}
if user_tags or system_tags:
worker_keys = set()
for tags, tags_field in (
(user_tags, "tags"),
(system_tags, "systemtags"),
):
if not tags:
continue
timestamp = int(time())
include, exclude = partition(tags, key=lambda x: x[0] != "-")
if include:
tagged_workers = set()
for tag in include:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
tagged_workers.update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
tagged_workers = filter_by_user(tagged_workers)
worker_keys = (
worker_keys.intersection(tagged_workers)
if worker_keys
else tagged_workers
)
if not worker_keys:
return []
if exclude:
if not worker_keys:
all_workers_key = self._get_all_workers_key(company)
self.redis.zremrangebyscore(
all_workers_key, min=0, max=timestamp
)
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
worker_keys = filter_by_user(worker_keys)
if not worker_keys:
return []
for tag in exclude:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag[1:]
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
worker_keys.difference_update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
if not worker_keys:
return []
else:
match = self._get_worker_key(company, user, "*")
worker_keys = self.redis.scan_iter(match)
entries = []
for key in worker_keys:
data = self.redis.get(key)
for keys in chunked_iter(
self._get_keys(
company,
user=user,
user_tags=user_tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
),
1000,
):
data = self.redis.mget(keys)
if data:
entries.append(WorkerEntry.from_json(data))
entries.extend(WorkerEntry.from_json(d) for d in data if d)
return entries
@@ -481,18 +545,17 @@ class WorkerBLL:
"""Get the index name suffix for storing current month data"""
return datetime.utcnow().strftime("%Y-%m")
def _log_stats_to_es(
def log_stats_to_es(
self,
company_id: str,
company_name: str,
worker: str,
worker_id: str,
timestamp: int,
task: str,
machine_stats: MachineStats,
) -> bool:
) -> int:
"""
Actually writing the worker statistics to Elastic
:return: True if successful, False otherwise
:return: The amount of logged documents
"""
es_index = (
f"{self._stats.worker_stats_prefix_for_company(company_id)}"
@@ -504,8 +567,7 @@ class WorkerBLL:
_index=es_index,
_source=dict(
timestamp=timestamp,
worker=worker,
company=company_name,
worker=worker_id,
task=task,
category=category,
metric=metric,
@@ -530,7 +592,7 @@ class WorkerBLL:
es_res = elasticsearch.helpers.bulk(self.es_client, actions)
added, errors = es_res[:2]
return (added == len(actions)) and not errors
return added
@attr.s(auto_attribs=True)

View File

@@ -1,8 +1,9 @@
from operator import attrgetter
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from typing import Optional, Sequence
from boltons.iterutils import bucketize
from apiserver.apierrors import errors
from apiserver.apierrors.errors import bad_request
from apiserver.apimodels.workers import AggregationType, GetStatsRequest, StatItem
from apiserver.bll.query import Builder as QueryBuilder
@@ -13,6 +14,9 @@ log = config.logger(__file__)
class WorkerStats:
min_chart_interval = config.get("services.workers.min_chart_interval_sec", 40)
_max_metrics_concurrency = config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
def __init__(self, es):
self.es = es
@@ -21,7 +25,7 @@ class WorkerStats:
"""Returns the es index prefix for the company"""
return f"worker_stats_{company_id.lower()}_"
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
def search_company_stats(self, company_id: str, es_req: dict) -> dict:
return self.es.search(
index=f"{self.worker_stats_prefix_for_company(company_id)}*",
body=es_req,
@@ -49,7 +53,7 @@ class WorkerStats:
if worker_ids:
es_req["query"] = QueryBuilder.terms("worker", worker_ids)
res = self._search_company_stats(company_id, es_req)
res = self.search_company_stats(company_id, es_req)
if not res["hits"]["total"]["value"]:
raise bad_request.WorkerStatsNotFound(
@@ -63,6 +67,75 @@ class WorkerStats:
for category in res["aggregations"]["categories"]["buckets"]
}
def _get_worker_stats_per_metric(
    self,
    metric_item: StatItem,
    company_id: str,
    from_date: float,
    to_date: float,
    interval: int,
    split_by_resource: bool,
    worker_ids: Sequence[str],
):
    """
    Query the company worker stats for a single metric/aggregation pair.
    The values are bucketed per worker and then per date histogram interval
    :param metric_item: the metric key and the aggregation to apply to its values
    :param company_id: id of the company whose stats indices are searched
    :param from_date: start of the range, used for the dates range query and
        (multiplied by 1000) as the histogram lower bound
    :param to_date: end of the range, used in the same way as the upper bound
    :param interval: histogram interval in seconds
    :param split_by_resource: add a per "variant" breakdown.
        Honored only for the metrics whose key starts with "gpu_"
    :param worker_ids: if not empty then only these workers are queried
    :return: the result of _extract_results for the returned data
    """
    es_agg_names = {
        AggregationType.avg: "avg",
        AggregationType.min: "min",
        AggregationType.max: "max",
    }
    aggregation = metric_item.aggregation
    value_agg = {
        aggregation.value: {
            es_agg_names[aggregation]: {"field": "value", "missing": 0.0}
        }
    }

    split_by_resource = split_by_resource and metric_item.key.startswith("gpu_")
    per_date_aggs = dict(value_agg)
    if split_by_resource:
        per_date_aggs["split"] = {"terms": {"field": "variant"}, "aggs": value_agg}

    dates_agg = {
        "date_histogram": {
            "field": "timestamp",
            "fixed_interval": f"{interval}s",
            "extended_bounds": {
                "min": int(from_date) * 1000,
                "max": int(to_date) * 1000,
            },
        },
        "aggs": per_date_aggs,
    }
    must_terms = [
        QueryBuilder.dates_range(from_date, to_date),
        QueryBuilder.term("metric", metric_item.key),
    ]
    if worker_ids:
        must_terms.append(QueryBuilder.terms("worker", worker_ids))

    es_req = {
        "size": 0,
        "aggs": {
            "workers": {
                "terms": {"field": "worker"},
                "aggs": {"dates": dates_agg},
            }
        },
        "query": {"bool": {"must": must_terms}},
    }

    with translate_errors_context():
        data = self.search_company_stats(company_id, es_req)

    # do not return the point for the incomplete last interval
    cutoff_date = (to_date - 0.9 * interval) * 1000
    return self._extract_results(data, metric_item, split_by_resource, cutoff_date)
def get_worker_stats(self, company_id: str, request: GetStatsRequest) -> dict:
"""
Get statistics for company workers metrics in the specified time range
@@ -71,119 +144,93 @@ class WorkerStats:
Buckets with no metrics are not returned
Note: all the statistics are retrieved as one ES query
"""
if request.from_date >= request.to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
def get_dates_agg() -> dict:
es_to_agg_types = (
("avg", AggregationType.avg.value),
("min", AggregationType.min.value),
("max", AggregationType.max.value),
from_date = request.from_date
to_date = request.to_date
if from_date >= to_date:
raise errors.bad_request.FieldsValueError(
"from_date must be less than to_date"
)
return {
"dates": {
"date_histogram": {
"field": "timestamp",
"fixed_interval": f"{request.interval}s",
"min_doc_count": 1,
},
"aggs": {
agg_type: {es_agg: {"field": "value"}}
for es_agg, agg_type in es_to_agg_types
},
}
}
interval = max(request.interval, self.min_chart_interval)
with ThreadPoolExecutor(self._max_metrics_concurrency) as pool:
res = list(
pool.map(
partial(
self._get_worker_stats_per_metric,
company_id=company_id,
from_date=from_date,
to_date=to_date,
interval=interval,
split_by_resource=request.split_by_resource,
worker_ids=request.worker_ids,
),
request.items,
)
)
def get_variants_agg() -> dict:
return {
"variants": {"terms": {"field": "variant"}, "aggs": get_dates_agg()}
}
ret = defaultdict(lambda: defaultdict(dict))
for workers in res:
for worker, metrics in workers.items():
for metric, stats in metrics.items():
ret[worker][metric].update(stats)
es_req = {
"size": 0,
"aggs": {
"workers": {
"terms": {"field": "worker"},
"aggs": {
"metrics": {
"terms": {"field": "metric"},
"aggs": get_variants_agg()
if request.split_by_variant
else get_dates_agg(),
}
},
}
},
}
query_terms = [
QueryBuilder.dates_range(request.from_date, request.to_date),
QueryBuilder.terms("metric", {item.key for item in request.items}),
]
if request.worker_ids:
query_terms.append(QueryBuilder.terms("worker", request.worker_ids))
es_req["query"] = {"bool": {"must": query_terms}}
with translate_errors_context():
data = self._search_company_stats(company_id, es_req)
return self._extract_results(data, request.items, request.split_by_variant)
return ret
@staticmethod
def _extract_results(
data: dict, request_items: Sequence[StatItem], split_by_variant: bool
data: dict,
metric_item: StatItem,
split_by_resource: bool,
cutoff_date,
) -> dict:
"""
Clean results returned from elastic search (remove "aggregations", "buckets" etc.),
leave only aggregation types requested by the user and return a clean dictionary
and return a "clean" dictionary of
:param data: aggregation data retrieved from ES
:param request_items: aggs types requested by the user
:param split_by_variant: if False then aggregate by metric type, otherwise metric type + variant
"""
if "aggregations" not in data:
return {}
items_by_key = bucketize(request_items, key=attrgetter("key"))
aggs_per_metric = {
key: [item.aggregation for item in items]
for key, items in items_by_key.items()
}
def extract_metric_results(metric: dict) -> dict:
aggregation = metric_item.aggregation.value
date_buckets = metric["dates"]["buckets"]
length = len(date_buckets)
while length > 0 and date_buckets[length - 1]["key"] >= cutoff_date:
length -= 1
dates = [None] * length
agg_values = [0.0] * length
resource_series = defaultdict(lambda: [0.0] * length)
for idx in range(0, length):
date = date_buckets[idx]
dates[idx] = date["key"]
if aggregation in date:
agg_values[idx] = date[aggregation]["value"] or 0.0
if split_by_resource and "split" in date:
for resource in date["split"]["buckets"]:
series = resource_series[resource["key"]]
if aggregation in resource:
series[idx] = resource[aggregation]["value"] or 0.0
if len(resource_series) == 1:
resource_series = {}
def extract_date_stats(date: dict, metric_key) -> dict:
return {
"date": date["key"],
"count": date["doc_count"],
**{agg: date[agg]["value"] for agg in aggs_per_metric[metric_key]},
}
def extract_metric_results(
metric_or_variant: dict, metric_key: str
) -> Sequence[dict]:
return [
extract_date_stats(date, metric_key)
for date in metric_or_variant["dates"]["buckets"]
if date["doc_count"]
]
def extract_variant_results(metric: dict) -> dict:
metric_key = metric["key"]
return {
variant["key"]: extract_metric_results(variant, metric_key)
for variant in metric["variants"]["buckets"]
}
def extract_worker_results(worker: dict) -> dict:
return {
metric["key"]: extract_variant_results(metric)
if split_by_variant
else extract_metric_results(metric, metric["key"])
for metric in worker["metrics"]["buckets"]
"dates": dates,
"values": agg_values,
**(
{"resource_series": resource_series} if resource_series else {}
),
}
return {
worker["key"]: extract_worker_results(worker)
worker["key"]: {
metric_item.key: {
metric_item.aggregation.value: extract_metric_results(worker)
}
}
for worker in data["aggregations"]["workers"]["buckets"]
}
@@ -203,6 +250,7 @@ class WorkerStats:
"""
if from_date >= to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
interval = max(interval, self.min_chart_interval)
must = [QueryBuilder.dates_range(from_date, to_date)]
if active_only:
@@ -215,6 +263,10 @@ class WorkerStats:
"date_histogram": {
"field": "timestamp",
"fixed_interval": f"{interval}s",
"extended_bounds": {
"min": int(from_date) * 1000,
"max": int(to_date) * 1000,
}
},
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
}
@@ -223,7 +275,7 @@ class WorkerStats:
}
with translate_errors_context():
data = self._search_company_stats(company_id, es_req)
data = self.search_company_stats(company_id, es_req)
if "aggregations" not in data:
return {}

View File

@@ -6,7 +6,7 @@ from functools import reduce
from os import getenv
from os.path import expandvars
from pathlib import Path
from typing import List, Any, TypeVar, Sequence
from typing import List, Any, TypeVar, Sequence, Set
from boltons.iterutils import first
from pyhocon import ConfigTree, ConfigFactory, ConfigValues
@@ -35,6 +35,7 @@ class BasicConfig:
folder: str = None,
verbose: bool = True,
prefix: Sequence[str] = DEFAULT_PREFIXES,
exclude_files_from_base_folder: Sequence[str] = None,
):
folder = (
Path(folder)
@@ -44,6 +45,11 @@ class BasicConfig:
if not folder.is_dir():
raise ValueError("Invalid configuration folder")
self.exclude_files_from_base_folder = (
set(exclude_files_from_base_folder)
if exclude_files_from_base_folder
else set()
)
self.verbose = verbose
self.extra_config_path_override_var = [
@@ -85,7 +91,7 @@ class BasicConfig:
return logging.getLogger(path)
def _read_extra_env_config_values(self) -> ConfigTree:
""" Loads extra configuration from environment-injected values """
"""Loads extra configuration from environment-injected values"""
result = ConfigTree()
for prefix in self.extra_config_values_env_key_prefix:
@@ -125,12 +131,18 @@ class BasicConfig:
def _reload(self) -> ConfigTree:
extra_config_values = self._read_extra_env_config_values()
configs = [self._read_recursive(path) for path in self._paths]
configs = [
self._read_recursive(
path,
exclude_files=(
self.exclude_files_from_base_folder if idx == 0 else None
),
)
for idx, path in enumerate(self._paths)
]
return reduce(
lambda last, config: self._merge_configs(
last, config, copy_trees=True
),
lambda last, config: self._merge_configs(last, config, copy_trees=True),
configs + [extra_config_values],
ConfigTree(),
)
@@ -141,9 +153,14 @@ class BasicConfig:
for key, value in b.items():
override = key.startswith(override_prefix)
if override:
key = key[len(override_prefix):]
key = key[len(override_prefix) :]
# if key is in both a and b and both values are dictionary then merge it otherwise override it
if not override and key in a and isinstance(a[key], ConfigTree) and isinstance(b[key], ConfigTree):
if (
not override
and key in a
and isinstance(a[key], ConfigTree)
and isinstance(b[key], ConfigTree)
):
if copy_trees:
a[key] = a[key].copy()
cls._merge_configs(a[key], b[key], copy_trees=copy_trees)
@@ -156,13 +173,15 @@ class BasicConfig:
a[key] = value
if a.root:
if b.root:
a.history[key] = a.history.get(key, []) + b.history.get(key, [value])
a.history[key] = a.history.get(key, []) + b.history.get(
key, [value]
)
else:
a.history[key] = a.history.get(key, []) + [value]
return a
def _read_recursive(self, conf_root) -> ConfigTree:
def _read_recursive(self, conf_root, exclude_files: Set[str]) -> ConfigTree:
conf = ConfigTree()
if not conf_root:
@@ -180,6 +199,8 @@ class BasicConfig:
print(f"Loading config from {conf_root}")
for file in conf_root.rglob("*.conf"):
if exclude_files and file.name in exclude_files:
continue
key = ".".join(file.relative_to(conf_root).with_suffix("").parts)
conf.put(key, self._read_single_file(file))

View File

@@ -2,8 +2,8 @@
watch: false # Watch for changes (dev only)
debug: false # Debug mode
pretty_json: false # prettify json response
return_stack: true # return stack trace on error
return_stack_to_caller: true # top-level control on whether to return stack trace in an API response
return_stack: false # return stack trace on error
return_stack_to_caller: false # top-level control on whether to return stack trace in an API response
# if 'return_stack' is true and error contains a status code, return stack trace only for these status codes
# valid values are:
@@ -41,10 +41,7 @@
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
# but not declared in a data model
strict: false
aggregate {
allow_disk_use: true
}
ensure_db_version_on_startup: true
}
elastic {
@@ -62,6 +59,9 @@
# verify user tokens
verify_user_tokens: false
# If set then users that were created from secure credentials or fixed user settings and are no longer in these settings will be deleted on startup
delete_missing_autocreated_users: true
# max token expiration timeout in seconds (1 year)
max_expiration_sec: 31536000
@@ -76,6 +76,7 @@
httponly: true # allow only http to access the cookies (no JS etc)
secure: false # not using HTTPS
domain: null # Limit to localhost is not supported
samesite: Lax
max_age: 99999999999
}
@@ -117,6 +118,10 @@
# Timeout in seconds on task status update. If exceeded
# then task can be stopped without communicating to the worker
task_update_timeout: 600
# Timeout in seconds for worker registration (or status report). If a worker did not report for this long,
# it is discarded from the server's table
default_timeout: 600
}
check_for_updates {

View File

@@ -2,10 +2,9 @@ fileserver = "http://localhost:8081"
elastic {
events {
hosts: [{host: "127.0.0.1", port: 9200}]
hosts: [{host: "127.0.0.1", port: 9200, scheme: http}]
args {
timeout: 60
dead_timeout: 10
max_retries: 3
retry_on_timeout: true
}
@@ -13,10 +12,9 @@ elastic {
}
workers {
hosts: [{host:"127.0.0.1", port:9200}]
hosts: [{host:"127.0.0.1", port:9200, scheme: http}]
args {
timeout: 60
dead_timeout: 10
max_retries: 3
retry_on_timeout: true
}

View File

@@ -1,13 +1,13 @@
{
http {
session_secret {
apiserver: "Gx*gB-L2U8!Naqzd#8=7A4&+=In4H(da424H33ZTDQRGF6=FWw"
apiserver: "V8gcW3EneNDcNfO7G_TSUsWe7uLozyacc9_I33o7bxUo8rCN31VLRg"
}
}
auth {
# token sign secret
token_secret: "7E1ua3xP9GT2(cIQOfhjp+gwN6spBeCAmN-XuugYle00I=Wc+u"
token_secret: "Rq8FW84sSqVgq7WvBB_4EzNl9y8z8IGiDXX3C345_a5AZfcwZcwCIA"
}
credentials {
@@ -15,19 +15,29 @@
apiserver {
role: "system"
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
user_secret: "gaOfhDX2-bpkeI7-cwEcaMuGijxaG2UG3jbIvg4DxmVGF0LNI7rgvCb1-ne38IlBo1w"
}
fileserver {
role: "system"
user_key: "GSQWPEKSKNKF354LC9V6BHXKTYFD5I"
user_secret: "tuBXcGQBECsEhcNiK2kiWi750z9r8Z85XrQ9V0c24huTuCb2xf2X1nKG"
}
webserver {
role: "system"
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
user_secret: "XhkH6a6ds9JBnM_MrahYyYdO-wS2bqFSm8gl-V0UZXH26Ydd6Eyi28TeBEoSr6Z3Bes"
revoke_in_fixed_mode: true
}
services_agent {
role: "admin"
user_key: ""
user_secret: ""
}
tests {
role: "user"
display_name: "Default User"
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
user_secret: "LPEJbGJ6bK4tujQcmrD3i1dbMBDdwUwelVa-LG0K0FFmY9bzH_H0Sw"
revoke_in_fixed_mode: true
}
}

View File

@@ -2,3 +2,8 @@ max_page_size: 500
# expiration time in seconds for the redis scroll states in get_many family of apis
scroll_state_expiration_seconds: 600
allow_disk_use {
sort: true
aggregate: true
}

View File

@@ -1,4 +1,4 @@
# if set to True then on task delete/reset external file urls for know storage types are scheduled for async delete
# if set to true then on task delete/reset external file urls for known storage types are scheduled for async delete
# otherwise they are returned to a client for the client side delete
enabled: true
max_retries: 3
@@ -9,4 +9,5 @@ fileserver {
# Can be in the form <schema>://host:port/path or /path
url_prefixes: ["https://files.community-master.hosted.allegro.ai/"]
timeout_sec: 300
token_expiration_sec: 600
}

View File

@@ -32,6 +32,8 @@ events_retrieval {
max_raw_scalars_size: 200000
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
multi_plots_batch_size: 1000
}
# if set then plot str will be checked for the valid json on plot add

View File

@@ -1,7 +1,4 @@
metadata_values {
# maximal amount of distinct model values to retrieve
max_count: 100
# cache ttl sec
cache_ttl_sec: 86400
}

View File

@@ -1,3 +1,9 @@
tags_cache {
expiration_seconds: 3600
}
download {
redis_timeout_sec: 300
batch_size: 500
max_download_items: 50000
max_project_name_length: 60
}

View File

@@ -0,0 +1,7 @@
default_container_timeout_sec: 600
# Auto-register unknown serving containers on status reports and other calls
container_auto_register: true
# Assume unknown serving containers have unregistered (i.e. do not raise unregistered error)
container_auto_unregister: true
# The minimal sampling interval for serving model monitor charts
min_chart_interval_sec: 40

View File

@@ -15,14 +15,15 @@ aws {
# key: "my-access-key"
# secret: "my-secret-key"
# },
{
# This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
host: "localhost:9000"
key: "evg_user"
secret: "evg_pass"
multipart: false
secure: false
}
// {
// # This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
// host: "localhost:9000"
// key: "minioadmin"
// secret: "minioadmin"
// # region: my-server
// multipart: false
// secure: false
// }
]
}
}

View File

@@ -11,9 +11,6 @@ non_responsive_tasks_watchdog {
multi_task_histogram_limit: 100
hyperparam_values {
# maximal amount of distinct hyperparam values to retrieve
max_count: 100
# max allowed outdate time for the cached result
cache_allowed_outdate_sec: 60
@@ -26,4 +23,6 @@ hyperparam_values {
max_last_metrics: 2000
# if set then call to tasks.delete/cleanup does not wait for ES events deletion
async_events_delete: false
async_events_delete: true
# do not use async_delete if the deleted task has amount of events lower than this threshold
async_events_delete_threshold: 100000

View File

@@ -0,0 +1,9 @@
default_worker_timeout_sec: 600
default_cluster_timeout_sec: 600
# The minimal sampling interval for resource dashboard and worker activity charts
min_chart_interval_sec: 40
stats {
max_metrics_concurrency: 4
}

View File

@@ -37,6 +37,8 @@ OVERRIDE_QUERY_ENV_KEY = "CLEARML_MONGODB_SERVICE_QUERY"
class DatabaseEntry(models.Base):
host = StringField(required=True)
alias = StringField()
name = StringField()
db = StringField()
class DatabaseFactory:
@@ -78,10 +80,13 @@ class DatabaseFactory:
missing.append(key)
continue
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
settings = {**db_entries.get(key)}
if not any(field in settings for field in ("name", "db")):
settings["name"] = key
entry = cls._create_db_entry(alias=alias, settings=settings)
if override_connection_string:
con_str = f"{override_connection_string.rstrip('/')}/{key}"
con_str = override_connection_string
log.info(f"Using override mongodb connection string for {alias}: {con_str}")
entry.host = con_str
else:

View File

@@ -5,7 +5,7 @@ from textwrap import shorten
import dpath
from dpath.exceptions import InvalidKeyName
from elasticsearch import ElasticsearchException
from elastic_transport import TransportError, ApiError
from elasticsearch.helpers import BulkIndexError
from jsonmodels.errors import ValidationError as JsonschemaValidationError
from mongoengine.errors import (
@@ -16,7 +16,7 @@ from mongoengine.errors import (
LookUpError,
InvalidQueryError,
)
from pymongo.errors import PyMongoError, NotMasterError
from pymongo.errors import PyMongoError, NotPrimaryError
from apiserver.apierrors import errors
@@ -198,7 +198,7 @@ def translate_errors_context(message=None, **kwargs):
MongoEngineErrorsHandler.invalid_query_error(e, message, **kwargs)
except PyMongoError as e:
raise errors.server_error.InternalError(message, err=str(e))
except NotMasterError as e:
except NotPrimaryError as e:
raise errors.server_error.InternalError(message, err=str(e))
except MakeGetAllQueryError as e:
raise errors.bad_request.ValidationError(e.error, field=e.field)
@@ -210,9 +210,9 @@ def translate_errors_context(message=None, **kwargs):
raise errors.bad_request.ValidationError(e.args[0])
except BulkIndexError as e:
ElasticErrorsHandler.bulk_error(e, message, **kwargs)
except ElasticsearchException as e:
except (TransportError, ApiError) as e:
raise errors.server_error.DataError(e, message, **kwargs)
except InvalidKeyName:
raise errors.server_error.DataError("invalid empty key encountered in data")
except Exception as ex:
except Exception:
raise

View File

@@ -4,6 +4,7 @@ from mongoengine import (
EmbeddedDocumentListField,
EmailField,
DateTimeField,
BooleanField,
)
from apiserver.database import Database, strict
@@ -76,3 +77,6 @@ class User(DbModelMixin, AuthDocument):
email = EmailField(unique=True, sparse=True)
""" Email uniquely identifying the user """
autocreated = BooleanField(default=False)
""" Set to true if the user was auto created based on config settings"""

View File

@@ -1,5 +1,6 @@
import re
from collections import namedtuple
from collections import defaultdict
from datetime import datetime
from functools import reduce, partial
from typing import (
Collection,
@@ -11,17 +12,18 @@ from typing import (
Mapping,
Any,
Callable,
Dict,
List,
Generator,
)
import attr
from boltons.iterutils import first, partition
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField, IntField
from mongoengine import Q, Document, ListField, StringField, IntField, QuerySet
from pymongo.command_cursor import CommandCursor
from apiserver.apierrors import errors, APIError
from apiserver.apierrors.base import BaseError
from apiserver.apierrors.errors.bad_request import FieldsValueError
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database import Database
@@ -39,7 +41,7 @@ from apiserver.redis_manager import redman
from apiserver.utilities.dicts import project_dict, exclude_fields_from_dict
log = config.logger("dbmodel")
mongo_conf = config.get("services._mongo")
ACCESS_REGEX = re.compile(r"^(?P<prefix>>=|>|<=|<)?(?P<value>.*)$")
ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"}
@@ -105,7 +107,18 @@ class GetMixin(PropsMixin):
("_any_", "_or_"): lambda a, b: a | b,
("_all_", "_and_"): lambda a, b: a & b,
}
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
@attr.s(auto_attribs=True)
class MultiFieldParameters:
    """
    Parameters of a multi-field condition: the list of fields and exactly one
    of the 'pattern' or 'datetime' conditions to apply to them
    """

    fields: Sequence[str]
    pattern: str = None
    datetime: Union[list, str] = None

    def __attrs_post_init__(self):
        # one flag per supported condition, True if the condition was passed
        provided = [
            condition is not None for condition in (self.pattern, self.datetime)
        ]
        if True not in provided:
            raise ValueError("Either 'pattern' or 'datetime' should be provided")
        if False not in provided:
            raise ValueError("Only one of the 'pattern' and 'datetime' can be provided")
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {}
@@ -132,90 +145,134 @@ class GetMixin(PropsMixin):
self.range_fields = range_fields
self.pattern_fields = pattern_fields
class ListFieldBucketHelper:
class NewListFieldBucketHelper:
op_prefix = "__$"
_legacy_exclude_prefix = "-"
_legacy_exclude_mongo_op = "nin"
default_mongo_op = "in"
_ops = {
# op -> (mongo_op, sticky)
"not": ("nin", False),
"nop": (default_mongo_op, False),
"all": ("all", True),
"and": ("all", True),
"any": (default_mongo_op, True),
"or": (default_mongo_op, True),
_unary_operators = {
"__$not": False,
}
_reset_operator = "__$nop"
_operators = {
"__$all": Q.AND,
"__$and": Q.AND,
"__$any": Q.OR,
"__$or": Q.OR,
}
default_global_operator = Q.AND
default_context = Q.OR
# not_all modifier currently not supported due to the backwards compatibility
mongo_modifiers = {
Q.AND: {True: "all", False: "nin"},
Q.OR: {True: "in", False: "nin"},
}
def __init__(self, field, legacy=False):
@attr.s(auto_attribs=True)
class Term:
operator: str = None
reset: bool = False
include: bool = True
value: str = None
def __init__(self, field: str, data: Sequence[str], legacy=False):
self._field = field
self._current_op = None
self._sticky = False
self._support_legacy = legacy
self.allow_empty = False
self.global_operator = None
self.actions = defaultdict(list)
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
try:
op = (
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
)
if translate:
tup = self._ops.get(op, None)
return tup[0] if tup else None
return op
except AttributeError:
raise errors.bad_request.FieldsValueError(
"invalid value type, string expected", field=self._field, value=str(v)
)
def _key(self, v) -> Optional[Union[str, bool]]:
if v is None:
self.allow_empty = True
return None
op = self._get_op(v)
if op is not None:
# operator - set state and return None
self._current_op, self._sticky = self._ops.get(
op, (self.default_mongo_op, self._sticky)
)
return None
elif self._current_op:
current_op = self._current_op
if not self._sticky:
self._current_op = None
return current_op
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
self._current_op = None
return False
return self.default_mongo_op
def get_global_op(self, data: Sequence[str]) -> int:
op_to_res = {
"in": Q.OR,
"all": Q.AND,
}
data = (x for x in data if x is not None)
first_op = (
self._get_op(next(data, ""), translate=True) or self.default_mongo_op
)
return op_to_res.get(first_op, self.default_mongo_op)
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
actions = {}
for val in data:
key = self._key(val)
if key is None:
self._support_legacy = legacy
current_context = self.default_context
for d in self._get_next_term(data):
if d.operator is not None:
current_context = d.operator
self._support_legacy = False
if self.global_operator is None:
self.global_operator = d.operator
continue
elif self._support_legacy and key is False:
key = self._legacy_exclude_mongo_op
val = val[len(self._legacy_exclude_prefix) :]
actions.setdefault(key, []).append(val)
return actions
if self.global_operator is None:
self.global_operator = self.default_global_operator
if d.reset:
current_context = self.default_context
self._support_legacy = legacy
continue
if d.value is None:
self.allow_empty = True
continue
self.actions[self.mongo_modifiers[current_context][d.include]].append(
d.value
)
if self.global_operator is None:
self.global_operator = self.default_global_operator
def _get_next_term(self, data: Sequence[str]) -> Generator[Term, None, None]:
    """
    Parse the raw list of values into the stream of terms

    For each element of data yields one of:
    - an empty Term for a None element
    - Term(reset=True) for the reset operator
    - Term(operator=...) for a binary operator (all/and/any/or)
    - Term(value=..., include=...) for a regular value. The value is excluded
      when it follows a unary operator that maps to False or, in the legacy
      mode, when it is prefixed with '-'
    A unary operator is not yielded by itself; it is applied to the value
    that follows it.

    :param data: values and operators as received in the request
    :raise FieldsValueError: on a non-string element, an unsupported operator,
        an operator where a value was expected, an empty value after the legacy
        exclude prefix or a dangling unary operator at the end of data
    """
    unary_operator = None
    for value in data:
        if value is None:
            # None drops the pending unary operator without an error
            unary_operator = None
            yield self.Term()
            continue

        if not isinstance(value, str):
            raise FieldsValueError(
                "invalid value type, string expected",
                field=self._field,
                value=str(value),
            )

        if value == self._reset_operator:
            unary_operator = None
            yield self.Term(reset=True)
            continue

        if value.startswith(self.op_prefix):
            if unary_operator:
                # a unary operator should be followed by a value, not by an operator
                raise FieldsValueError(
                    "Value is expected after",
                    field=self._field,
                    operator=unary_operator,
                )
            if value in self._unary_operators:
                unary_operator = value
                continue

            operator = self._operators.get(value)
            if operator is None:
                raise FieldsValueError(
                    "Unsupported operator",
                    field=self._field,
                    operator=value,
                )
            yield self.Term(operator=operator)
            continue

        if (
            not unary_operator
            and self._support_legacy
            and value.startswith("-")
        ):
            # legacy syntax: '-value' excludes the value
            value = value[1:]
            if not value:
                raise FieldsValueError(
                    "Missing value after the exclude prefix -",
                    field=self._field,
                    value=value,
                )
            yield self.Term(value=value, include=False)
            continue

        term = self.Term(value=value)
        if unary_operator:
            term.include = self._unary_operators[unary_operator]
            unary_operator = None
        yield term

    if unary_operator:
        # report the field here too, consistently with the other errors above
        raise FieldsValueError(
            "Value is expected after",
            field=self._field,
            operator=unary_operator,
        )
get_all_query_options = QueryParameterOptions()
@@ -233,8 +290,8 @@ class GetMixin(PropsMixin):
cls._cache_manager = RedisCacheManager(
state_class=cls.GetManyScrollState,
redis=redman.connection("apiserver"),
expiration_interval=config.get(
"services._mongo.scroll_state_expiration_seconds", 600
expiration_interval=mongo_conf.get(
"scroll_state_expiration_seconds", 600
),
)
@@ -277,6 +334,8 @@ class GetMixin(PropsMixin):
specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
- <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
provided fields match the provided pattern.
- <any|all>: {fields: [<field1>, <field2>, ...], datetime: <datetime condition>} Will query for items where any or all
provided datetime fields match the provided condition.
:return: mongoengine.Q query object
"""
return cls._prepare_query_no_company(
@@ -330,6 +389,46 @@ class GetMixin(PropsMixin):
return cls._try_convert_to_numeric(value)
return value
@classmethod
def _get_dates_query(cls, field: str, data: Union[list, str]) -> Union[Q, dict]:
    """
    Build the query for a datetime field

    A list of exactly 2 values where no value starts with a comparison
    modifier is handled as the simplified range query and the result of
    get_range_field_query is returned.
    In any other case a dictionary of '<field>[__<modifier>]' -> datetime
    conditions is returned. Values that cannot be parsed as a datetime
    are skipped
    """
    conditions = data if isinstance(data, list) else [data]

    if len(conditions) == 2:
        uses_modifier = any(
            condition.startswith(prefix)
            for condition in conditions
            if condition is not None
            for prefix in ACCESS_MODIFIER
        )
        if not uses_modifier:
            return cls.get_range_field_query(field, conditions)

    result = {}
    for condition in conditions:
        match = ACCESS_REGEX.match(condition)
        if match is None:
            continue
        try:
            parsed = parse_datetime(match.group("value"))
            modifier = ACCESS_MODIFIER.get(match.group("prefix"))
            key = f"{field}__{modifier}" if modifier else field
            result[key] = parsed
        except (ValueError, OverflowError):
            # an unparsable date is ignored and produces no condition
            pass

    return result
@classmethod
def _prepare_query_no_company(
cls, parameters=None, parameters_options=QueryParameterOptions()
@@ -360,12 +459,25 @@ class GetMixin(PropsMixin):
parameters = {
k: cls._get_fixed_field_value(k, v) for k, v in parameters.items()
}
filters = parameters.pop("filters", {})
if not isinstance(filters, dict):
raise FieldsValueError(
"invalid value type, string expected",
field=filters,
value=str(filters),
)
opts = parameters_options
for field in opts.pattern_fields:
pattern = parameters.pop(field, None)
if pattern:
dict_query[field] = RegexWrapper(pattern)
for field, data in cls._pop_matching_params(
patterns=opts.list_fields, parameters=filters
).items():
query &= cls.get_list_filter_query(field, data)
parameters.pop(field, None)
for field, data in cls._pop_matching_params(
patterns=opts.list_fields, parameters=parameters
).items():
@@ -387,33 +499,11 @@ class GetMixin(PropsMixin):
for field in opts.datetime_fields or []:
data = parameters.pop(field, None)
if data is not None:
if not isinstance(data, list):
data = [data]
# date time fields also support simplified range queries. Check if this is the case
if len(data) == 2 and not any(
d.startswith(mod)
for d in data
if d is not None
for mod in ACCESS_MODIFIER
):
query &= cls.get_range_field_query(field, data)
else:
for d in data: # type: str
m = ACCESS_REGEX.match(d)
if not m:
continue
try:
value = parse_datetime(m.group("value"))
prefix = m.group("prefix")
modifier = ACCESS_MODIFIER.get(prefix)
f = (
field
if not modifier
else "__".join((field, modifier))
)
dict_query[f] = value
except (ValueError, OverflowError):
pass
dates_q = cls._get_dates_query(field, data)
if isinstance(dates_q, Q):
query &= dates_q
elif isinstance(dates_q, dict):
dict_query.update(dates_q)
for field, value in parameters.items():
for keys, func in cls._multi_field_param_prefix.items():
@@ -425,33 +515,48 @@ class GetMixin(PropsMixin):
raise MakeGetAllQueryError("incorrect field format", field)
if not data.fields:
break
if any("._" in f for f in data.fields):
q = reduce(
lambda a, x: func(
a,
RegexQ(
__raw__={
x: {"$regex": data.pattern, "$options": "i"}
}
if data.pattern is not None:
if any("._" in f for f in data.fields):
q = reduce(
lambda a, x: func(
a,
RegexQ(
__raw__={
x: {"$regex": data.pattern, "$options": "i"}
}
),
),
),
data.fields,
RegexQ(),
)
data.fields,
RegexQ(),
)
else:
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})),
sep_fields,
RegexQ(),
)
else:
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})),
sep_fields,
RegexQ(),
)
date_fields = [field for field in data.fields if field in opts.datetime_fields]
if not date_fields:
break
q = Q()
for date_f in date_fields:
dates_q = cls._get_dates_query(date_f, data.datetime)
if isinstance(dates_q, dict):
dates_q = RegexQ(**dates_q)
q = func(q, dates_q)
query = query & q
except APIError:
raise
except Exception as ex:
raise errors.bad_request.FieldsValueError(
"failed parsing query field", error=str(ex), **({"field": field} if field else {})
"failed parsing query field",
error=str(ex),
**({"field": field} if field else {}),
)
return query & RegexQ(**dict_query)
@@ -487,6 +592,149 @@ class GetMixin(PropsMixin):
return q
@attr.s(auto_attribs=True)
class ListQueryFilter:
    """
    Deserialize filters data and build db_query object that represents it with the corresponding
    mongo engine operations
    Each part has include and exclude lists that map to mongoengine operations as following:
    "any"
        - include -> 'in'
        - exclude -> 'not_all'
        - combined by 'or' operation
    "all"
        - include -> 'all'
        - exclude -> 'nin'
        - combined by 'and' operation
    "op" optional parameter for combining "any" and "all" parts. Can be "and" or "or". The default is "and"

    A None value inside an include list adds the 'size == 0' condition,
    inside an exclude list it adds the 'not size == 0' condition
    """

    _and_op = "and"
    _or_op = "or"
    _allowed_op = [_and_op, _or_op]
    # NOTE(review): this attribute is annotated, so with auto_attribs=True attrs
    # collects it as a regular field. It looks like it is intended to be
    # a class level constant - confirm
    _db_modifiers: Mapping = {
        (Q.OR, True): "in",
        (Q.OR, False): "not__all",
        (Q.AND, True): "all",
        (Q.AND, False): "nin",
    }

    @attr.s(auto_attribs=True)
    class ListFilter:
        # a factory is used so that the instances do not share one default list
        include: Sequence[str] = attr.Factory(list)
        exclude: Sequence[str] = attr.Factory(list)

        @classmethod
        def from_dict(cls, d: Mapping):
            """Converter for the filter parts: None is passed through as is"""
            if d is None:
                return None
            return cls(**d)

    any: ListFilter = attr.ib(converter=ListFilter.from_dict, default=None)
    all: ListFilter = attr.ib(converter=ListFilter.from_dict, default=None)
    op: str = attr.ib(default="and")
    # calculated in __attrs_post_init__: Q operation -> {mongoengine modifier: values}
    db_query: dict = attr.ib(init=False)

    # noinspection PyUnresolvedReferences
    @op.validator
    def op_validator(self, _, value):
        if value not in self._allowed_op:
            raise ValueError(
                f"Invalid list query filter operator: {value}. "
                f"Should be one of {str(self._allowed_op)}"
            )

    @property
    def and_op(self) -> bool:
        """True if the "any" and "all" parts should be combined with 'and'"""
        return self.op == self._and_op

    def __attrs_post_init__(self):
        self.db_query = {}
        for op, conditions in ((Q.OR, self.any), (Q.AND, self.all)):
            if not conditions:
                continue

            operations = {}
            for vals, include in (
                (conditions.include, True),
                (conditions.exclude, False),
            ):
                if not vals:
                    continue

                unique = set(vals)
                if None in unique:
                    # None stands for the empty list and is translated to a size condition
                    # noinspection PyTypeChecker
                    unique.remove(None)
                    if include:
                        operations["size"] = 0
                    else:
                        operations["not__size"] = 0

                if not unique:
                    continue
                operations[self._db_modifiers[(op, include)]] = list(unique)

            self.db_query[op] = operations

    @classmethod
    def from_data(cls, field, data: Mapping):
        """
        Create the filter from the request data
        :param field: the name of the filtered field, used for error reporting only
        :param data: dictionary with the optional "any", "all" and "op" keys
        :raise errors.bad_request.ValidationError: if data is not a dictionary
            or the filter cannot be constructed from it
        """
        if not isinstance(data, dict):
            raise errors.bad_request.ValidationError(
                "invalid filter for field, dictionary expected",
                field=field,
                value=str(data),
            )

        try:
            return cls(**data)
        except Exception as ex:
            raise errors.bad_request.ValidationError(
                field=field,
                value=str(ex),
            )
@classmethod
def get_list_filter_query(
    cls, field: str, data: Mapping
) -> Union[RegexQ, RegexQCombination]:
    """
    Translate the filter passed for a list field into a query
    :param field: the field name. Dots are replaced with the mongoengine '__' separator
    :param data: the filter dictionary in the format accepted by ListQueryFilter
    :return: an empty RegexQ if there are no conditions, a single query
        if there is only one condition, otherwise a combination of the queries
    """
    if not data:
        return RegexQ()

    parsed = cls.ListQueryFilter.from_data(field, data)
    db_field = field.replace(".", "__")

    parts = []
    for op, actions in parsed.db_query.items():
        if not actions:
            continue

        conditions = [
            RegexQ(**{f"{db_field}__{action}": vals})
            for action, vals in actions.items()
            # a plain truth test cannot be used here since 0 is an acceptable value
            if not (vals is None or vals == [])
        ]
        if not conditions:
            continue

        if len(conditions) == 1:
            parts.extend(conditions)
        else:
            parts.append(RegexQCombination(operation=op, children=conditions))

    if not parts:
        return RegexQ()
    if len(parts) == 1:
        return parts[0]

    return RegexQCombination(
        operation=Q.AND if parsed.and_op else Q.OR, children=parts
    )
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
"""
@@ -501,15 +749,15 @@ class GetMixin(PropsMixin):
if not isinstance(data, (list, tuple)):
data = [data]
helper = cls.ListFieldBucketHelper(field, legacy=True)
global_op = helper.get_global_op(data)
actions = helper.get_actions(data)
helper = cls.NewListFieldBucketHelper(field, data=data, legacy=True)
global_op = helper.global_operator
actions = helper.actions
mongoengine_field = field.replace(".", "__")
queries = [
RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))})
for action in filter(None, actions)
RegexQ(**{f"{mongoengine_field}__{action}": list(set(values))})
for action, values in actions.items()
]
if not queries:
@@ -570,7 +818,7 @@ class GetMixin(PropsMixin):
if start is not None:
return start, cls.validate_scroll_size(parameters)
max_page_size = config.get("services._mongo.max_page_size", 500)
max_page_size = mongo_conf.get("max_page_size", 500)
page = parameters.get("page", default_page)
if page is not None and page < 0:
raise errors.bad_request.ValidationError("page must be >=0", field="page")
@@ -595,7 +843,7 @@ class GetMixin(PropsMixin):
@classmethod
def get_projection(cls, parameters, override_projection=None, **__):
""" Extract a projection list from the provided dictionary. Supports an override projection. """
"""Extract a projection list from the provided dictionary. Supports an override projection."""
if override_projection is not None:
return override_projection
if not parameters:
@@ -609,7 +857,8 @@ class GetMixin(PropsMixin):
"""Return include and exclude lists based on passed projection and class definition"""
if projection:
include, exclude = partition(
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
projection,
key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
)
else:
include, exclude = [], []
@@ -748,7 +997,9 @@ class GetMixin(PropsMixin):
@classmethod
def _get_collation_override(cls, field: str) -> Optional[dict]:
return first(
v for k, v in cls._field_collation_overrides.items() if field.startswith(k)
v
for k, v in cls._field_collation_overrides.items()
if field.startswith(k) or field.startswith(f"-{k}")
)
@classmethod
@@ -854,7 +1105,9 @@ class GetMixin(PropsMixin):
projection_fields=projection_fields,
)
return cls.get_data_with_scroll_support(
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
query_dict=query_dict,
data_getter=data_getter,
ret_params=ret_params,
)
return cls._get_many_no_company(
@@ -867,7 +1120,9 @@ class GetMixin(PropsMixin):
@classmethod
def get_many_public(
cls, query: Q = None, projection: Collection[str] = None,
cls,
query: Q = None,
projection: Collection[str] = None,
):
"""
Fetch all public documents matching a provided query.
@@ -880,6 +1135,13 @@ class GetMixin(PropsMixin):
return cls._get_many_no_company(query=_query, override_projection=projection)
@staticmethod
def _get_qs_with_ordering(qs: QuerySet, order_by: Sequence):
    """
    Return the query set ordered by the passed fields.
    If "allow_disk_use.sort" is set in the mongo configuration then
    allow_disk_use is applied to the query set before the ordering.
    """
    disk_use_setting = mongo_conf.get("allow_disk_use.sort", None)
    if disk_use_setting is None:
        return qs.order_by(*order_by)

    return qs.allow_disk_use(disk_use_setting).order_by(*order_by)
@classmethod
def _get_many_no_company(
cls: Union["GetMixin", Document],
@@ -921,7 +1183,7 @@ class GetMixin(PropsMixin):
qs = qs.search_text(search_text)
if order_by:
# add ordering
qs = qs.order_by(*order_by)
qs = cls._get_qs_with_ordering(qs, order_by)
if include:
# add projection
@@ -1013,7 +1275,7 @@ class GetMixin(PropsMixin):
res = cls._get_queries_for_order_field(query, order_field)
if res:
query_sets = [cls.objects(q) for q in res]
query_sets = [qs.order_by(*order_by) for qs in query_sets]
query_sets = [cls._get_qs_with_ordering(qs, order_by) for qs in query_sets]
if order_field and not override_collation:
override_collation = cls._get_collation_override(order_field)
@@ -1078,22 +1340,6 @@ class GetMixin(PropsMixin):
)
return result
@classmethod
def get_many_for_writing(cls, company, *args, **kwargs):
result = cls.get_many(
company=company,
*args,
**dict(return_dicts=False, **kwargs),
allow_public=True,
)
forbidden_objects = {obj.id for obj in result if not obj.company}
if forbidden_objects:
object_name = cls.__name__.lower()
raise errors.forbidden.NoWritePermission(
f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
)
return result
class UpdateMixin(object):
__user_set_allowed_fields = None
@@ -1153,7 +1399,7 @@ class UpdateMixin(object):
class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
""" Provide convenience methods for a subclass of mongoengine.Document """
"""Provide convenience methods for a subclass of mongoengine.Document"""
@classmethod
def aggregate(
@@ -1173,7 +1419,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
kwargs.update(
allowDiskUse=allow_disk_use
if allow_disk_use is not None
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
else mongo_conf.get("allow_disk_use.aggregate", True)
)
return cls.objects.aggregate(pipeline, **kwargs)
@@ -1181,25 +1427,31 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
def set_public(
cls: Type[Document],
company_id: str,
user_id: str,
ids: Sequence[str],
invalid_cls: Type[BaseError],
enabled: bool = True,
):
if enabled:
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
update = dict(set__company_origin=company_id, set__company="")
update: dict = dict(set__company_origin=company_id, set__company="")
else:
items = list(
cls.objects(
id__in=ids, company__in=(None, ""), company_origin=company_id
id__in=ids, company="", company_origin=company_id
).only("id")
)
update = dict(set__company=company_id, unset__company_origin=1)
update: dict = dict(set__company=company_id, unset__company_origin=1)
if len(items) < len(ids):
missing = tuple(set(ids).difference(i.id for i in items))
raise invalid_cls(ids=missing)
if hasattr(cls, "last_change"):
update["set__last_change"] = datetime.utcnow()
if hasattr(cls, "last_changed_by"):
update["set__last_changed_by"] = user_id
return {"updated": cls.objects(id__in=ids).update(**update)}

View File

@@ -3,6 +3,8 @@ from mongoengine import (
DateTimeField,
BooleanField,
EmbeddedDocumentField,
IntField,
ListField,
)
from apiserver.database import Database, strict
@@ -17,12 +19,14 @@ from apiserver.database.model.base import GetMixin
from apiserver.database.model.metadata import MetadataItem
from apiserver.database.model.model_labels import ModelLabels
from apiserver.database.model.project import Project
from apiserver.database.model.task.metrics import MetricEvent
from apiserver.database.model.task.task import Task
class Model(AttributedDocument):
_field_collation_overrides = {
"metadata.": AttributedDocument._numeric_locale,
"last_metrics.": AttributedDocument._numeric_locale,
}
meta = {
@@ -33,10 +37,18 @@ class Model(AttributedDocument):
"project",
"task",
"last_update",
("company", "framework"),
("company", "last_update"),
("company", "name"),
("company", "user"),
("company", "uri"),
# distinct queries support
("company", "tags"),
("company", "system_tags"),
("company", "project", "tags"),
("company", "project", "system_tags"),
("company", "user"),
("company", "project", "user"),
("company", "framework"),
("company", "project", "framework"),
{
"name": "%s.model.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
@@ -67,7 +79,8 @@ class Model(AttributedDocument):
"parent",
"metadata.*",
),
datetime_fields=("last_update",),
range_fields=("created", "last_metrics.*", "last_iteration"),
datetime_fields=("last_update", "last_change"),
)
id = StringField(primary_key=True)
@@ -85,6 +98,8 @@ class Model(AttributedDocument):
labels = ModelLabels()
ready = BooleanField(required=True)
last_update = DateTimeField()
last_change = DateTimeField()
last_changed_by = StringField()
ui_cache = SafeDictField(
default=dict, user_set_allowed=True, exclude_by_default=True
)
@@ -92,6 +107,9 @@ class Model(AttributedDocument):
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)
last_iteration = IntField(default=0)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
# Returns company if it is set, otherwise company_origin, otherwise an empty string
def get_index_company(self) -> str:
return self.company or self.company_origin or ""

View File

@@ -47,6 +47,7 @@ class Queue(DbModelMixin, Document):
name = StrippedStringField(
required=True, unique_with="company", min_length=3, user_set_allowed=True
)
display_name = StringField(user_set_allowed=True)
company = StringField(required=True, reference_field=Company)
created = DateTimeField(required=True)
tags = SafeSortedListField(

View File

@@ -0,0 +1,76 @@
from mongoengine import (
Document,
EmbeddedDocument,
StringField,
DateTimeField,
EmbeddedDocumentListField,
EmbeddedDocumentField,
BooleanField,
)
from apiserver.database import Database, strict
from apiserver.database.model import DbModelMixin
from apiserver.database.model.base import ProperDictMixin
# Embedded document with the settings of a single AWS bucket.
# Used as the item type of the AWSSettings.buckets list.
# All the fields are optional.
class AWSBucketSettings(EmbeddedDocument, ProperDictMixin):
bucket = StringField()
subdir = StringField()
host = StringField()
key = StringField()
secret = StringField()
token = StringField()
multipart = BooleanField()
acl = StringField()
secure = BooleanField()
region = StringField()
verify = BooleanField()
use_credentials_chain = BooleanField()
# Embedded document with the AWS settings: the fields defined directly here
# plus a list of per-bucket settings (AWSBucketSettings).
# NOTE(review): the top level fields are presumably the defaults for buckets
# that are not listed explicitly - confirm against the code that reads them
class AWSSettings(EmbeddedDocument, DbModelMixin):
key = StringField()
secret = StringField()
region = StringField()
token = StringField()
use_credentials_chain = BooleanField()
buckets = EmbeddedDocumentListField(AWSBucketSettings)
# Embedded document with the settings of a single Google storage bucket.
# Used as the item type of the GoogleStorageSettings.buckets list.
class GoogleBucketSettings(EmbeddedDocument, ProperDictMixin):
bucket = StringField()
subdir = StringField()
project = StringField()
credentials_json = StringField()
# Embedded document with the Google storage settings: project and credentials
# defined directly here plus a list of per-bucket settings (GoogleBucketSettings)
class GoogleStorageSettings(EmbeddedDocument, DbModelMixin):
project = StringField()
credentials_json = StringField()
buckets = EmbeddedDocumentListField(GoogleBucketSettings)
# Embedded document with the settings of a single Azure storage container.
# The account name and key are required, the container name is optional.
class AzureStorageContainerSettings(EmbeddedDocument, ProperDictMixin):
account_name = StringField(required=True)
account_key = StringField(required=True)
container_name = StringField()
# Embedded document with the Azure storage settings:
# a list of per-container settings (AzureStorageContainerSettings)
class AzureStorageSettings(EmbeddedDocument, DbModelMixin):
containers = EmbeddedDocumentListField(AzureStorageContainerSettings)
# Top level document holding the storage settings of a company.
# The company field is required and unique, so there is at most one
# document per company. The aws, google and azure parts are optional.
class StorageSettings(DbModelMixin, Document):
meta = {
"db_alias": Database.backend,
"strict": strict,
# NOTE(review): company is also declared with unique=True below, which
# presumably already results in an index on it - confirm whether this
# explicit index is needed
"indexes": [
"company"
],
}
id = StringField(primary_key=True)
company = StringField(required=True, unique=True)
last_update = DateTimeField()
aws: AWSSettings = EmbeddedDocumentField(AWSSettings)
google: GoogleStorageSettings = EmbeddedDocumentField(GoogleStorageSettings)
azure: AzureStorageSettings = EmbeddedDocumentField(AzureStorageSettings)

View File

@@ -5,6 +5,7 @@ from mongoengine import (
LongField,
EmbeddedDocumentField,
IntField,
FloatField,
)
from apiserver.database.fields import SafeMapField
@@ -23,6 +24,11 @@ class MetricEvent(EmbeddedDocument):
min_value_iteration = IntField()
max_value = DynamicField() # for backwards compatibility reasons
max_value_iteration = IntField()
first_value = FloatField()
first_value_iteration = IntField()
count = IntField()
mean_value = FloatField()
x_axis_label = StringField()
class EventStats(EmbeddedDocument):

View File

@@ -19,6 +19,7 @@ from apiserver.database.fields import (
SafeSortedListField,
EmbeddedDocumentListField,
NullableStringField,
NoneType,
)
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import ProperDictMixin, GetMixin
@@ -89,7 +90,9 @@ class Artifact(EmbeddedDocument):
content_size = LongField()
timestamp = LongField()
type_data = EmbeddedDocumentField(ArtifactTypeData)
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
display_data = SafeSortedListField(
ListField(UnionField((int, float, str, NoneType)))
)
class ParamsItem(EmbeddedDocument, ProperDictMixin):
@@ -180,9 +183,8 @@ class Task(AttributedDocument):
"status_changed",
"models.input.model",
("company", "name"),
("company", "user"),
("company", "status", "type"),
("company", "system_tags", "last_update"),
("company", "last_update", "system_tags"),
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
@@ -190,6 +192,17 @@ class Task(AttributedDocument):
"fields": ["company", "project"],
"collation": AttributedDocument._numeric_locale,
},
# distinct queries support
("company", "tags"),
("company", "system_tags"),
("company", "project", "tags"),
("company", "project", "system_tags"),
("company", "user"),
("company", "project", "user"),
("company", "parent"),
("company", "project", "parent"),
("company", "type"),
("company", "project", "type"),
{
"name": "%s.task.main_text_index" % Database.backend,
"fields": [
@@ -227,10 +240,13 @@ class Task(AttributedDocument):
"project",
"parent",
"hyperparams.*",
"execution.queue",
"models.input.model",
),
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"),
range_fields=("created", "started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update", "last_change"),
pattern_fields=("name", "comment", "report"),
fields=("runtime.*",),
)
id = StringField(primary_key=True)
@@ -245,6 +261,7 @@ class Task(AttributedDocument):
status_changed = DateTimeField()
comment = StringField(user_set_allowed=True)
report = StringField()
report_assets = ListField(StringField())
created = DateTimeField(required=True, user_set_allowed=True)
started = DateTimeField()
completed = DateTimeField()
@@ -266,7 +283,7 @@ class Task(AttributedDocument):
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
company_origin = StringField(exclude_by_default=True)
duration = IntField() # task duration in seconds
duration = IntField() # obsolete, do not use
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
runtime = SafeDictField(default=dict)

View File

@@ -20,4 +20,5 @@ class User(DbModelMixin, Document):
given_name = StringField(user_set_allowed=True)
avatar = StringField()
preferences = DynamicField(default="", exclude_by_default=True)
created_in_version = StringField()
created = DateTimeField()

View File

@@ -3,10 +3,14 @@ from concurrent.futures import ThreadPoolExecutor
from itertools import groupby, chain
from typing import Sequence, Dict, Callable
from boltons import iterutils
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.props import PropsMixin
SEP = "."
max_items_per_fetch = config.get("services._mongo.max_page_size", 500)
class _ReferenceProxy(dict):
@@ -278,10 +282,11 @@ class ProjectionHelper(object):
doc_only = list(filter(None, data["only"]))
doc_only = list({"id"} | set(doc_only)) if doc_only else None
for res in projection_func(
doc_type=doc_type, projection=doc_only, ids=ids
):
self._proxy_manager.update(res)
for ids_chunk in iterutils.chunked_iter(ids, max_items_per_fetch):
for res in projection_func(
doc_type=doc_type, projection=doc_only, ids=ids_chunk
):
self._proxy_manager.update(res)
if len(ref_projection) == 1:
do_projection(items[0])

View File

@@ -121,8 +121,8 @@ def init_cls_from_base(cls, instance):
)
def get_company_or_none_constraint(company=None):
return Q(company__in=(company, None, "")) | Q(company__exists=False)
def get_company_or_none_constraint(company=""):
return Q(company__in=list({company, ""}))
def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:

View File

@@ -0,0 +1,25 @@
### Supported API versions
| Release | ApiVersion |
|---------|------------|
| v2.1 | 2.32 |
| v2.0 | 2.31 |
| v1.17 | 2.31 |
| v1.16 | 2.30 |
| v1.15 | 2.29 |
| v1.14 | 2.28 |
| v1.13 | 2.27 |
| v1.12 | 2.26 |
| v1.11 | 2.25 |
| v1.10 | 2.24 |
| v1.9 | 2.23 |
| v1.8 | 2.22 |
| v1.7 | 2.21 |
| v1.6 | 2.20 |
| v1.5 | 2.19 |
| v1.4 | 2.18 |
| v1.3 | 2.17 |
| v1.2 | 2.16 |
| v1.1 | 2.15 |
| v1.0 | 2.14 |
| v0.17 | 2.13 |

View File

@@ -4,34 +4,89 @@ Apply elasticsearch mappings to given hosts.
"""
import argparse
import json
import logging
from pathlib import Path
from typing import Optional, Sequence, Tuple
from elasticsearch import Elasticsearch
from elasticsearch import Elasticsearch, exceptions
HERE = Path(__file__).resolve().parent
logging.getLogger("elasticsearch").setLevel(logging.WARNING)
logging.getLogger("elastic_transport").setLevel(logging.WARNING)
def apply_mappings_to_cluster(
hosts: Sequence, key: Optional[str] = None, es_args: dict = None, http_auth: Tuple = None
hosts: Sequence,
key: Optional[str] = None,
es_args: dict = None,
http_auth: Tuple = None,
):
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
def _send_template(f):
with f.open() as json_data:
data = json.load(json_data)
template_name = f.stem
res = es.indices.put_template(template_name, body=data)
return {"mapping": template_name, "result": res}
def _send_component_template(ct_file):
with ct_file.open() as json_data:
body = json.load(json_data)
template_name = f"{ct_file.stem}"
res = es.cluster.put_component_template(name=template_name, body=body)
return {"component_template": template_name, "result": res}
p = HERE / "mappings"
if key:
files = (p / key).glob("*.json")
else:
files = p.glob("**/*.json")
def _send_index_template(it_file):
with it_file.open() as json_data:
body = json.load(json_data)
template_name = f"{it_file.stem}"
res = es.indices.put_index_template(name=template_name, body=body)
return {"index_template": template_name, "result": res}
# def _send_legacy_template(f):
# with f.open() as json_data:
# data = json.load(json_data)
# template_name = f.stem
# res = es.indices.put_template(name=template_name, body=data)
# return {"mapping": template_name, "result": res}
def _delete_legacy_templates(legacy_folder):
res_list = []
for lt in legacy_folder.glob("*.json"):
template_name = lt.stem
try:
if not es.indices.get_template(name=template_name):
continue
res = es.indices.delete_template(name=template_name)
except exceptions.NotFoundError:
continue
res_list.append({"deleted legacy mapping": template_name, "result": res})
return res_list
es = Elasticsearch(hosts=hosts, http_auth=http_auth, **(es_args or {}))
return [_send_template(f) for f in files]
root = HERE / "index_templates"
if key:
folders = [root / key]
else:
folders = [f for f in root.iterdir() if f.is_dir()]
ret = []
for f in folders:
for ct in (f / "component_templates").glob("*.json"):
ret.append(_send_component_template(ct))
for it in f.glob("*.json"):
ret.append(_send_index_template(it))
legacy_root = HERE / "mappings"
for f in folders:
legacy_f = legacy_root / f.stem
if not legacy_f.exists() or not legacy_f.is_dir():
continue
ret.extend(_delete_legacy_templates(legacy_f))
return ret
# p = HERE / "mappings"
# if key:
# files = (p / key).glob("*.json")
# else:
# files = p.glob("**/*.json")
#
# return [_send_template(f) for f in files]
def parse_args():

View File

@@ -0,0 +1,48 @@
{
"template": {
"settings": {
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"@timestamp": {
"type": "date"
},
"task": {
"type": "keyword"
},
"type": {
"type": "keyword"
},
"worker": {
"type": "keyword"
},
"timestamp": {
"type": "date"
},
"iter": {
"type": "long"
},
"metric": {
"type": "keyword"
},
"variant": {
"type": "keyword"
},
"value": {
"type": "float"
},
"company_id": {
"type": "keyword"
},
"model_event": {
"type": "boolean"
}
}
}
}
}

View File

@@ -0,0 +1,18 @@
{
"index_patterns": "events-log-*",
"template": {
"mappings": {
"properties": {
"msg": {
"type": "text",
"index": false
},
"level": {
"type": "keyword"
}
}
}
},
"priority": 500,
"composed_of": ["events_common"]
}

View File

@@ -0,0 +1,18 @@
{
"index_patterns": "events-plot-*",
"template": {
"mappings": {
"properties": {
"plot_str": {
"type": "text",
"index": false
},
"plot_data": {
"type": "binary"
}
}
}
},
"priority": 500,
"composed_of": ["events_common"]
}

View File

@@ -0,0 +1,17 @@
{
"index_patterns": "events-training_debug_image-*",
"template": {
"mappings": {
"properties": {
"key": {
"type": "keyword"
},
"url": {
"type": "keyword"
}
}
}
},
"priority": 500,
"composed_of": ["events_common"]
}

View File

@@ -0,0 +1,5 @@
{
"index_patterns": "events-training_stats_scalar-*",
"priority": 500,
"composed_of": ["events_common"]
}

View File

@@ -0,0 +1,31 @@
{
"index_patterns": "queue_metrics_*",
"template": {
"settings": {
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"queue": {
"type": "keyword"
},
"average_waiting_time": {
"type": "float"
},
"queue_length": {
"type": "integer"
},
"company_id": {
"type": "keyword"
}
}
}
}
}

View File

@@ -0,0 +1,79 @@
{
"index_patterns": "serving_stats_*",
"template": {
"settings": {
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"container_id": {
"type": "keyword"
},
"company_id": {
"type": "keyword"
},
"endpoint_url": {
"type": "keyword"
},
"requests_num": {
"type": "integer"
},
"requests_min": {
"type": "float"
},
"uptime_sec": {
"type": "integer"
},
"latency_ms": {
"type": "integer"
},
"cpu_usage": {
"type": "float"
},
"cpu_num": {
"type": "integer"
},
"gpu_usage": {
"type": "float"
},
"gpu_num": {
"type": "integer"
},
"memory_used": {
"type": "float"
},
"memory_free": {
"type": "float"
},
"memory_total": {
"type": "float"
},
"gpu_memory_used": {
"type": "float"
},
"gpu_memory_free": {
"type": "float"
},
"gpu_memory_total": {
"type": "float"
},
"disk_free_home": {
"type": "float"
},
"network_rx": {
"type": "float"
},
"network_tx": {
"type": "float"
}
}
}
}
}

View File

@@ -0,0 +1,43 @@
{
"index_patterns": "worker_stats_*",
"template": {
"settings": {
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"worker": {
"type": "keyword"
},
"category": {
"type": "keyword"
},
"metric": {
"type": "keyword"
},
"variant": {
"type": "keyword"
},
"value": {
"type": "float"
},
"unit": {
"type": "keyword"
},
"task": {
"type": "keyword"
},
"company_id": {
"type": "keyword"
}
}
}
}
}

View File

@@ -10,6 +10,8 @@ from apiserver.config_repo import config
from apiserver.elastic.apply_mappings import apply_mappings_to_cluster
log = config.logger(__file__)
logging.getLogger("elasticsearch").setLevel(logging.WARNING)
logging.getLogger("elastic_transport").setLevel(logging.WARNING)
class MissingElasticConfiguration(Exception):
@@ -78,6 +80,18 @@ def check_elastic_empty() -> bool:
err_type=urllib3.exceptions.NewConnectionError, args_prefix=("GET",)
)
def events_legacy_template():
try:
return es.indices.get_template(name="events*")
except exceptions.NotFoundError:
return False
def events_template():
try:
return es.indices.get_index_template(name="events*")
except exceptions.NotFoundError:
return False
try:
es_logger.addFilter(log_filter)
for retry in range(max_retries):
@@ -85,12 +99,9 @@ def check_elastic_empty() -> bool:
es = Elasticsearch(
hosts=cluster_conf.get("hosts", None),
http_auth=es_factory.get_credentials("events", cluster_conf),
**cluster_conf.get("args", {})
**cluster_conf.get("args", {}),
)
return not es.indices.get_template(name="events*")
except exceptions.NotFoundError as ex:
log.error(ex)
return True
return not (events_template() or events_legacy_template())
except exceptions.ConnectionError as ex:
if retry >= max_retries - 1:
raise ElasticConnectionError(
@@ -115,5 +126,7 @@ def init_es_data():
args = cluster_conf.get("args", {})
http_auth = es_factory.get_credentials(name)
res = apply_mappings_to_cluster(hosts_config, name, es_args=args, http_auth=http_auth)
res = apply_mappings_to_cluster(
hosts_config, name, es_args=args, http_auth=http_auth
)
log.info(res)

View File

@@ -1,8 +1,8 @@
{
"index_patterns": "events-*",
"settings": {
"number_of_shards": 1,
"number_of_replicas": 0
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {

View File

@@ -1,8 +1,8 @@
{
"index_patterns": "queue_metrics_*",
"settings": {
"number_of_shards": 1,
"number_of_replicas": 0
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
@@ -20,6 +20,9 @@
},
"queue_length": {
"type": "integer"
},
"company_id": {
"type": "keyword"
}
}
}

View File

@@ -1,8 +1,8 @@
{
"index_patterns": "worker_stats_*",
"settings": {
"number_of_shards": 1,
"number_of_replicas": 0
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
@@ -32,6 +32,9 @@
},
"task": {
"type": "keyword"
},
"company_id": {
"type": "keyword"
}
}
}

View File

@@ -1,3 +1,4 @@
import logging
from datetime import datetime
from functools import lru_cache
from os import getenv
@@ -9,6 +10,8 @@ from elasticsearch import Elasticsearch
from apiserver.config_repo import config
log = config.logger(__file__)
logging.getLogger('elasticsearch').setLevel(logging.WARNING)
logging.getLogger('elastic_transport').setLevel(logging.WARNING)
OVERRIDE_HOST_ENV_KEY = (
"CLEARML_ELASTIC_SERVICE_HOST",
@@ -32,6 +35,7 @@ if OVERRIDE_HOST:
OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
if OVERRIDE_PORT:
OVERRIDE_PORT = int(OVERRIDE_PORT)
log.info(f"Using override elastic port {OVERRIDE_PORT}")
OVERRIDE_USERNAME = first(filter(None, map(getenv, OVERRIDE_USERNAME_ENV_KEY)))
@@ -67,7 +71,7 @@ class MissingPasswordForElasticUser(Exception):
class ESFactory:
@classmethod
def connect(cls, cluster_name):
def connect(cls, cluster_name) -> Elasticsearch:
"""
Returns the es client for the cluster.
Connects to the cluster if did not connect previously

122
apiserver/fix_mongo_urls.py Normal file
View File

@@ -0,0 +1,122 @@
import logging
from argparse import (
ArgumentDefaultsHelpFormatter,
ArgumentParser,
ArgumentTypeError,
)
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
logging.getLogger().setLevel(logging.INFO)
def fix_mongo_urls(mongo_host: str, host_source: str, host_target: str):
    """
    Replace the host prefix of the urls stored in the "backend" Mongo database.

    The following urls are updated:
    - the "uri" field of the documents in the "model" collection
    - the "uri" of each artifact under "execution.artifacts" of the documents
      in the "task" collection

    A url is updated only if it starts with host_source and the prefix ends on
    a path boundary (the url is equal to host_source, the rest of the url starts
    with "/" or host_source itself ends with "/"). This way the source
    http://host:80 does not touch the urls of http://host:8081

    :param mongo_host: Mongo connection string, passed to MongoClient
    :param host_source: the url prefix to replace
    :param host_target: the url prefix to put instead of host_source
    """
    # local import, used only for building the models lookup pattern
    import re

    logging.info(f"Connecting to Mongo on {mongo_host}")
    client = MongoClient(host=mongo_host)
    backend_db: Database = client.backend

    def get_updated_uri(uri: str):
        """Return the uri with the replaced prefix or None if it should not be updated"""
        if not uri or not uri.startswith(host_source):
            return None

        relative_url = uri[len(host_source):]
        if (
            relative_url
            and not relative_url.startswith("/")
            and not host_source.endswith("/")
        ):
            # the match stopped in the middle of the host, port or path segment
            return None

        return f"{host_target.rstrip('/')}/{relative_url.lstrip('/')}"

    model_collection: Collection = backend_db.get_collection("model")
    if model_collection is not None:
        logging.info("Updating model uris")
        models_count = model_collection.count_documents({})
        updated_models = 0
        # the source host is escaped so that it is matched literally: characters
        # like "." or "[" (IPv6 hosts) should not be treated as regex syntax
        source_prefix_regex = "^{}".format(re.escape(host_source))
        for model in model_collection.find(
            {"uri": {"$regex": source_prefix_regex}}, projection=["uri"]
        ):
            updated_uri = get_updated_uri(model.get("uri"))
            if updated_uri:
                result = model_collection.update_one(
                    {"_id": model["_id"]}, {"$set": {"uri": updated_uri}}
                )
                updated_models += result.modified_count
        logging.info(f"Updated {updated_models} models from {models_count}")

    task_collection: Collection = backend_db.get_collection("task")
    if task_collection is not None:
        logging.info("Updating task uris")
        tasks_count = task_collection.count_documents({})
        updated_tasks = 0
        for task in task_collection.find(
            {"execution.artifacts": {"$exists": 1, "$ne": {}}},
            projection=["execution.artifacts"],
        ):
            artifacts = task.get("execution", {}).get("artifacts")
            if not artifacts:
                continue

            uri_updated = False
            for artifact in artifacts.values():
                updated_uri = get_updated_uri(artifact.get("uri"))
                if updated_uri:
                    artifact["uri"] = updated_uri
                    uri_updated = True

            if uri_updated:
                # the whole artifacts mapping is written back in one update
                result = task_collection.update_one(
                    {"_id": task["_id"]}, {"$set": {"execution.artifacts": artifacts}}
                )
                updated_tasks += result.modified_count
        logging.info(f"Updated {updated_tasks} tasks from {tasks_count}")
def normalise_host(host):
    """Return the host without its last character if that character is "/" """
    return host[:-1] if host.endswith("/") else host
def main():
    """
    Command line entry point.
    Parses the arguments, runs fix_mongo_urls with them and logs the completion.
    """

    def valid_url_prefix(url: str):
        # argument type check: the value should contain the url schema separator
        if "://" in url:
            return url
        raise ArgumentTypeError("url schema is missing")

    parser = ArgumentParser(
        description=__doc__, formatter_class=ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--mongo-host",
        "-mh",
        type=str,
        default="mongodb://mongo:27017",
        help="Mongo server host. The default is mongodb://mongo:27017",
    )

    # both fileserver host arguments are mandatory and validated the same way
    host_arg_options = dict(type=valid_url_prefix, required=True)
    parser.add_argument(
        "--host-source",
        "-hs",
        help="Source host for the files uploaded to the fileserver (in the form http://<host>:<port>)",
        **host_arg_options,
    )
    parser.add_argument(
        "--host-target",
        "-ht",
        help="Target host for the files uploaded to the fileserver (in the form http://<host>:<port>)",
        **host_arg_options,
    )

    args = parser.parse_args()
    fix_mongo_urls(
        mongo_host=args.mongo_host,
        host_source=args.host_source,
        host_target=args.host_target,
    )
    logging.info("Completed successfully")
if __name__ == "__main__":
main()

View File

@@ -19,7 +19,9 @@ from google.cloud import storage as google_storage
from mongoengine import Q
from mypy_boto3_s3.service_resource import Bucket as AWSBucket
from apiserver.bll.auth import AuthBLL
from apiserver.bll.storage import StorageBLL
from apiserver.config.info import get_default_company
from apiserver.config_repo import config
from apiserver.database import db
from apiserver.database.model.url_to_delete import UrlToDelete, StorageType, DeletionStatus
@@ -200,6 +202,8 @@ class FileserverStorage(Storage):
res_data = res.json()
return list(res_data.get("deleted", {})), res_data.get("errors", {})
token_expiration_sec = conf.get("fileserver.token_expiration_sec", 600)
def __init__(self, company: str, fileserver_host: str = None):
fileserver_host = fileserver_host or config.get("hosts.fileserver", None)
self.host = fileserver_host.rstrip("/")
@@ -220,13 +224,6 @@ class FileserverStorage(Storage):
self.company = company
# @classmethod
# def validate_fileserver_access(cls, fileserver_host: str):
# res = requests.get(
# url=fileserver_host
# )
# res.raise_for_status()
@property
def name(self) -> str:
return "Fileserver"
@@ -260,7 +257,13 @@ class FileserverStorage(Storage):
def get_client(self, base: str, urls: Sequence[UrlToDelete]) -> Client:
host = base
token = AuthBLL.get_token_for_user(
user_id="__apiserver__",
company_id=get_default_company(),
expiration_sec=self.token_expiration_sec,
).token
session = requests.session()
session.headers.update({"Authorization": "Bearer {}".format(token)})
res = session.get(url=host, timeout=self.Client.timeout)
res.raise_for_status()
@@ -285,6 +288,7 @@ class AzureStorage(Storage):
):
raise ValueError("No path found following container name")
# noinspection PyTypeChecker
return os.path.join(*parsed.path.segments[1:])
@staticmethod
@@ -450,6 +454,7 @@ class AWSStorage(Storage):
else None,
"use_ssl": cfg.secure,
"verify": cfg.verify,
"region_name": cfg.region or None,
}
name = base[len(scheme_prefix(self.scheme)) :]
bucket_name = name[len(cfg.host) + 1 :] if cfg.host else name

View File

@@ -3,7 +3,7 @@ from typing import Sequence, Union
from apiserver.config_repo import config
from apiserver.config.info import get_default_company
from apiserver.database.model.auth import Role
from apiserver.database.model.auth import Role, User as AuthUser
from apiserver.service_repo.auth.fixed_user import FixedUser
from .migration import _apply_migrations, check_mongo_empty, get_last_server_version
from .pre_populate import PrePopulate
@@ -60,16 +60,22 @@ def init_mongo_data():
fixed_mode = FixedUser.enabled()
internal_user_emails = set()
for user, credentials in config.get("secure.credentials", {}).items():
email = f"{user}@example.com"
user_data = {
"name": user,
"role": credentials.role,
"email": f"{user}@example.com",
"email": email,
"key": credentials.user_key,
"secret": credentials.user_secret,
"autocreated": True,
}
internal_user_emails.add(email.lower())
revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False)
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke)
user_id = _ensure_auth_user(
user_data, company_id, log=log, revoke=revoke, internal_user=True
)
if credentials.role == Role.user:
_ensure_backend_user(user_id, company_id, credentials.display_name)
@@ -82,8 +88,20 @@ def init_mongo_data():
for user in FixedUser.from_config():
try:
ensure_fixed_user(user, log=log)
ensure_fixed_user(user, log=log, emails=internal_user_emails)
except Exception as ex:
log.error(f"Failed creating fixed user {user.name}: {ex}")
if internal_user_emails and config.get(
f"apiserver.auth.delete_missing_autocreated_users", True
):
for user in AuthUser.objects(
company=company_id, autocreated=True, email__nin=internal_user_emails
):
log.info(
f"Removing user that is no longer in configuration: {user['id']}\t{user['email']}\t{user['name']}"
)
user.delete()
except Exception as ex:
log.exception("Failed initializing mongodb")
log.exception(f"Failed initializing mongodb: {str(ex)}")

View File

@@ -8,19 +8,28 @@ import pymongo.database
from mongoengine.connection import get_db
from packaging.version import Version, parse
from apiserver.config_repo import config
from apiserver.database import utils
from apiserver.database import Database
from apiserver.database.model.version import Version as DatabaseVersion
from apiserver.utilities.dicts import nested_get
_migrations = "migrations"
_parent_dir = Path(__file__).resolve().parents[1]
_migration_dir = _parent_dir / _migrations
log = config.logger(__file__)
def check_mongo_empty() -> bool:
return not all(
get_db(alias).collection_names() for alias in utils.get_options(Database)
)
for alias in utils.get_options(Database):
collection_names = get_db(alias).list_collection_names()
if collection_names and any(
name in collection_names
for name in ["company", "user", "versions"]
):
return False
return True
def get_last_server_version() -> Version:
@@ -35,6 +44,31 @@ def get_last_server_version() -> Version:
return previous_versions[0] if previous_versions else Version("0.0.0")
def _ensure_mongodb_version():
    """
    Check the MongoDB server version and, if needed, raise the feature
    compatibility version from 5.0 to 6.0.

    The check is skipped entirely when the
    "apiserver.mongo.ensure_db_version_on_startup" setting is false.
    A warning is logged (and nothing is changed) when the server is not
    6.0.x or when the current feature compatibility version is neither
    5.0 nor 6.0.
    """
    check_enabled = config.get("apiserver.mongo.ensure_db_version_on_startup", True)
    if not check_enabled:
        return

    log.info("Checking DB version")
    backend_db: pymongo.database.Database = get_db(Database.backend)
    client = backend_db.client

    # The server binaries themselves must already be 6.0.x
    db_version = client.server_info()["version"]
    if not db_version.startswith("6.0"):
        log.warning(f"Database version should be 6.0.x. Instead: {str(db_version)}")
        return

    # Read the feature compatibility version currently set on the server
    res = client.admin.command(
        {"getParameter": 1, "featureCompatibilityVersion": 1}
    )
    version = nested_get(res, ("featureCompatibilityVersion", "version"))
    log.info(f"DB version: {version}")

    if version == "5.0":
        # Only a 5.0 -> 6.0 step is performed here
        log.info("Upgrading db version from 5.0 to 6.0")
        res = client.admin.command({"setFeatureCompatibilityVersion": "6.0"})
        log.info(res)
    elif version != "6.0":
        log.warning(f"Cannot upgrade DB version. Should be 5.0. {str(res)}")
def _apply_migrations(log: Logger):
"""
Apply migrations as found in the migration dir.
@@ -44,6 +78,8 @@ def _apply_migrations(log: Logger):
log.info(f"Started mongodb migrations")
_ensure_mongodb_version()
if not _migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {_migration_dir}")

File diff suppressed because it is too large Load Diff

View File

@@ -9,34 +9,87 @@ from apiserver.database.model.user import User
from apiserver.service_repo.auth.fixed_user import FixedUser
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: bool = False):
key, secret = user_data.get("key"), user_data.get("secret")
def _ensure_user_credentials(
user: AuthUser,
key: str,
secret: str,
log: Logger,
revoke: bool = False,
internal_user: bool = False,
) -> None:
if revoke:
log.info(f"Revoking credentials for existing user {user.id} ({user.name})")
user.credentials = []
user.save()
return
if not (key and secret):
credentials = None
else:
creds = Credentials(key=key, secret=secret)
if internal_user:
log.info(f"Resetting credentials for existing user {user.id} ({user.name})")
user.credentials = []
user.save()
return
user = AuthUser.objects(credentials__match=creds).first()
if user:
if revoke:
user.credentials = []
user.save()
return user.id
new_credentials = Credentials(key=key, secret=secret)
if internal_user:
log.info(f"Setting credentials for existing user {user.id} ({user.name})")
user.credentials = [new_credentials]
user.save()
return
credentials = [] if revoke else [creds]
if user.credentials is None:
user.credentials = []
if not any((cred.key, cred.secret) == (key, secret) for cred in user.credentials):
log.info(f"Adding credentials for existing user {user.id} ({user.name})")
user.credentials.append(new_credentials)
user.save()
def _ensure_auth_user(
user_data: dict,
company_id: str,
log: Logger,
revoke: bool = False,
internal_user: bool = False,
) -> str:
user_id = user_data.get("id", f"__{user_data['name']}__")
role = user_data["role"]
email = user_data["email"]
autocreated = user_data.get("autocreated", False)
key, secret = user_data.get("key"), user_data.get("secret")
user: AuthUser = AuthUser.objects(id=user_id).first()
if user:
_ensure_user_credentials(
user=user,
key=key,
secret=secret,
log=log,
revoke=revoke,
internal_user=internal_user,
)
if user.role != role or user.email != email or user.autocreated != autocreated:
user.email = email
user.role = role
user.autocreated = autocreated
user.save()
return user.id
credentials = (
[Credentials(key=key, secret=secret)] if not revoke and key and secret else []
)
log.info(f"Creating user: {user_data['name']}")
user = AuthUser(
id=user_id,
name=user_data["name"],
company=company_id,
role=user_data["role"],
email=user_data["email"],
role=role,
email=email,
created=datetime.utcnow(),
credentials=credentials,
autocreated=autocreated,
)
user.save()
@@ -59,23 +112,29 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
return user_id
def ensure_fixed_user(user: FixedUser, log: Logger):
def ensure_fixed_user(user: FixedUser, log: Logger, emails: set):
# noinspection PyTypeChecker
data = attr.asdict(user)
data["id"] = user.user_id
email = f"{user.user_id}@example.com"
data["email"] = email
data["role"] = Role.guest if user.is_guest else Role.user
data["autocreated"] = True
_ensure_auth_user(user_data=data, company_id=user.company, log=log)
db_user = User.objects(company=user.company, id=user.user_id).first()
if db_user:
# noinspection PyBroadException
try:
log.info(f"Updating user name: {user.name}")
given_name, _, family_name = user.name.partition(" ")
db_user.update(name=user.name, given_name=given_name, family_name=family_name)
db_user.update(
name=user.name, given_name=given_name, family_name=family_name
)
except Exception:
pass
return
else:
_ensure_backend_user(user.user_id, user.company, user.name)
data = attr.asdict(user)
data["id"] = user.user_id
data["email"] = f"{user.user_id}@example.com"
data["role"] = Role.guest if user.is_guest else Role.user
_ensure_auth_user(user_data=data, company_id=user.company, log=log)
return _ensure_backend_user(user.user_id, user.company, user.name)
emails.add(email)

View File

@@ -2,7 +2,7 @@ from os import getenv
from boltons.iterutils import first
from redis import StrictRedis
from rediscluster import RedisCluster
from redis.cluster import RedisCluster
from apiserver.apierrors.errors.server_error import ConfigError, GeneralError
from apiserver.config_repo import config
@@ -83,7 +83,7 @@ class RedisManager(object):
def host(self, alias):
r = self.connection(alias)
if isinstance(r, RedisCluster):
connections = first(r.connection_pool._available_connections.values())
connections = r.get_default_node().redis_connection.connection_pool._available_connections
else:
connections = r.connection_pool._available_connections

View File

@@ -1,38 +1,38 @@
attrs>=22.1.0
attrs>=22.1.0,<23
azure-storage-blob>=12.13.1
bcrypt>=3.1.4
boltons>=19.1.0
boto3==1.14.13
boto3-stubs[s3]>=1.24.35
clearml>=1.6.0,<1.7.0
boto3>=1.26
boto3-stubs[s3]>=1.26
clearml>=1.10.3
dpath>=1.4.2,<2.0
elasticsearch==7.13.3
elasticsearch==8.17.0
fastjsonschema>=2.8
flask-compress>=1.4.0
flask-cors>=3.0.5
flask>=0.12.2
funcsigs==1.0.2
flask>=2.3.3
furl>=2.0.0
google-cloud-storage==2.0.0
protobuf==3.19.5
gunicorn>=19.7.1
humanfriendly==4.18
jinja2==2.11.3
google-cloud-storage>=2.8.0
gunicorn>=23.0.0
humanfriendly>=4.17
jinja2
jsonmodels>=2.3
jsonschema>=2.6.0
luqum>=0.10.0
mongoengine==0.24.2
mongoengine==0.29.1
nested_dict>=1.61
packaging==20.3
pillow>=10.3.0 # fix vulnerability derived from clearml 1.18.0
psutil>=5.6.5
pyhocon>=0.3.35
pyhocon>=0.3.35
pyjwt>=2.4.0
pymongo[srv]==3.12.0
pymongo==4.10.1
python-rapidjson>=0.6.3
redis==3.5.3
redis-py-cluster>=2.1.3
redis==5.2.1
requests>=2.13.0
semantic_version>=2.8.3,<3
setuptools>=78.1.1
six
tqdm
validators>=0.12.4
validators>=0.12.4
urllib3>=1.26.18
werkzeug>=3.0.1

Some files were not shown because too many files have changed in this diff Show More