Compare commits

113 Commits

Author SHA1 Message Date
allegroai
6a1fc04d1e Set cookies SameSite value to Lax 2024-02-13 16:18:21 +02:00
allegroai
ee8eb03698 Fix crash when importing events for public company tasks 2024-02-13 16:17:52 +02:00
allegroai
5799baae45 Make sure that APIs that aggregate task/model data from projects can be called for the root project 2024-02-13 16:17:33 +02:00
allegroai
801e536c5e Fix tasks.started to correctly handle null values in the started field 2024-02-13 16:17:02 +02:00
allegroai
6e484ea8f4 Fix missing region parameter when deleting files from minio server 2024-02-13 16:16:24 +02:00
allegroai
a47e65d974 Add input parameters check to multiple APIs 2024-02-13 16:15:55 +02:00
allegroai
702b6dc9c8 Version bump to v1.14.0 2024-01-10 15:31:11 +02:00
allegroai
db15f235e4 Make sure files downloaded from the apiserver are not cached by browsers 2024-01-10 15:31:01 +02:00
allegroai
8c347f8fa9 Fix include and exclude filters not processing "no tags" condition 2024-01-10 15:26:55 +02:00
allegroai
768c3d80ff Remove callback_url_prefix and state parameters from login.supported_modes and does not return urls 2024-01-10 15:26:22 +02:00
allegroai
a5c3ef6385 Fix query filter so that the default operator between different query operations on the same parameter is AND instead of OR 2024-01-10 15:24:37 +02:00
allegroai
11b7a384af Set API version 2.28 2024-01-10 15:23:54 +02:00
allegroai
9a70ade4a6 Support task models with missing model field in data_tool import 2024-01-10 15:18:58 +02:00
allegroai
91ce140901 Add "queue watched" indication to pipelines.start_pipeline 2024-01-10 15:15:43 +02:00
allegroai
49084a9c49 Optimize task statistics for projects dashboard and statistics reporter 2024-01-10 15:13:25 +02:00
allegroai
8a99eb6812 Fix model_metrics parameter name in get_multi_task_metrics schema 2024-01-10 15:12:56 +02:00
allegroai
811ab2bf4f Support exporting users with data tool 2024-01-10 15:12:07 +02:00
allegroai
3752db122b Add events.get_multi_task_metrics 2024-01-10 15:11:27 +02:00
allegroai
439911b84c Upgrade werkzeug and flask dependencies 2024-01-10 15:10:46 +02:00
allegroai
262a301e28 Check for dictionary type for some model and task fields 2024-01-10 15:10:41 +02:00
allegroai
a604451b01 Refactor check for tasks write permission 2024-01-10 15:08:20 +02:00
allegroai
88a7773621 Allow filtering on event metrics in multi-task endpoints get_task_single_value_metrics, multi_task_scalar_metrics_iter_histogram and get_multi_task_plots 2024-01-10 15:07:46 +02:00
allegroai
35c4061992 Support filtering by task or model ids in projects.get_unique_metric_variants 2024-01-10 15:06:21 +02:00
allegroai
4684fd5b74 Version bump to v1.13.0 2023-11-17 09:49:26 +02:00
allegroai
e08123fcc0 Fix workers.activity_report should return 0s for the time when no workers reported 2023-11-17 09:49:18 +02:00
allegroai
e713e876eb Upgrade urllib3 requirement 2023-11-17 09:48:19 +02:00
allegroai
c2cc788319 Added supported API versions doc 2023-11-17 09:47:44 +02:00
allegroai
da8315d0db Allow queries on the list of execution queue ids in tasks.get_all/get_all_ex 2023-11-17 09:47:19 +02:00
allegroai
4ac6f88278 Optimize Workers retrieval
Store worker statistics under worker id and not internal redis key
Fix unit tests
2023-11-17 09:46:44 +02:00
allegroai
a7865ccbec Turn on async task events deletion in case there are more than 100_000 events 2023-11-17 09:45:55 +02:00
allegroai
ec14f327c6 Optimize endpoints that do not require authorization by not validating JWT token 2023-11-17 09:45:22 +02:00
allegroai
a03b24d6b6 Add log info on caller IP if token validation fails 2023-11-17 09:43:59 +02:00
allegroai
cb71ef8e47 Fix missing scroll_id in events.get_scalar_metric_data 2023-11-17 09:43:11 +02:00
allegroai
8678fbc995 Fix properly unset Task fields on task reset 2023-11-17 09:42:39 +02:00
allegroai
58df8f201a Update API to 2.27 2023-11-17 09:40:34 +02:00
allegroai
f4bf16c156 Fix schema for swagger compatibility 2023-11-17 09:39:52 +02:00
allegroai
942f996237 Fix async_delete cannot be configured using configuration files 2023-11-17 09:39:22 +02:00
allegroai
c1e7f8f9c1 Optimize deletion of projects with many tasks 2023-11-17 09:38:32 +02:00
allegroai
274c487b37 Add update_tags api to tasks and models 2023-11-17 09:37:25 +02:00
allegroai
cc0129a800 Add filters parameter for passing user defined list filters for all get_all_ex apis 2023-11-17 09:36:58 +02:00
allegroai
388dd1b01f Fix regression issue with archive tasks display 2023-11-17 09:35:55 +02:00
allegroai
d62ecb5e6e Add last_change and last_change_by DB Model 2023-11-17 09:35:22 +02:00
allegroai
6d507616b3 Add pattern parameter to projects.get_hyperparam_values 2023-11-17 09:34:13 +02:00
allegroai
d0252a6dd9 Make sure that hyperparam/configuration/metadata keys that are contain only empty space are rejected 2023-11-17 09:32:22 +02:00
allegroai
2263e7cc1e Fix regression with archive tasks display 2023-07-31 14:16:08 +03:00
allegroai
81b93e6811 Updated dependency - dnspython is a required dependency of pymongo as of pymongo v4.3 (https://pymongo.readthedocs.io/en/stable/changelog.html#changes-in-version-4-3-4-3-2) 2023-07-27 11:49:40 +03:00
allegroai
491e83d0f1 Version bump to v1.12.0 2023-07-26 18:56:04 +03:00
allegroai
f84cc0a2cb Remove 10 metrics limit in multi-task plot comparison 2023-07-26 18:55:49 +03:00
allegroai
6c5f966ed4 Add new_status field to tasks.dequeue and dequeue_many endpoints 2023-07-26 18:55:05 +03:00
allegroai
4eff657810 Fix debug images not returned for tasks in new db 2023-07-26 18:54:19 +03:00
allegroai
74acaa31df Add explicit refresh interval to ES mappings
Fix queue tests
2023-07-26 18:54:02 +03:00
allegroai
21ed8559bf Fix worker keys not returned in queues.get_all_ex 2023-07-26 18:51:20 +03:00
allegroai
3927604648 Add task names to events.get_single_value_metrics endpoint response 2023-07-26 18:50:53 +03:00
allegroai
f7dcbd96ec Fix deleting model events
Add delete_external_artifacts parameter to projects.delete endpoint
2023-07-26 18:49:54 +03:00
allegroai
5950b81f0b Fix child tasks count for top level pipeline and dataset projects 2023-07-26 18:49:12 +03:00
allegroai
1e51e2e221 Allow projection of more than 500 items 2023-07-26 18:46:58 +03:00
allegroai
4c98b87554 Fix issues with new dependencies 2023-07-26 18:46:28 +03:00
allegroai
c196043d2a Add max_download_items to users.get_current_user endpoint response 2023-07-26 18:45:42 +03:00
allegroai
752020c66a Update API version to 2.26 2023-07-26 18:44:20 +03:00
allegroai
6885d07462 Write UTF-8 BOM into csv download file 2023-07-26 18:43:38 +03:00
allegroai
00552da1b0 Requests context is not needed any more 2023-07-26 18:43:09 +03:00
allegroai
eebe2eeffc Update requirements 2023-07-26 18:42:26 +03:00
allegroai
bc2fe28bdd Add field_mappings to organizations download endpoints 2023-07-26 18:39:41 +03:00
allegroai
ed86750b24 Add scalar field type to jsonmodels 2023-07-26 18:39:06 +03:00
allegroai
6df69afb25 Support "__$or" condition on projects children filtering 2023-07-26 18:38:41 +03:00
allegroai
3f22423c3f Support paging in projects.get_model_metadata_values and get_hyperparam_values endpoints 2023-07-26 18:38:11 +03:00
allegroai
3ad636c468 Exported csv file name now contains the project name (including non-ascii names) 2023-07-26 18:37:20 +03:00
allegroai
5c80336aa9 Project delete and validate_delete now analyses and presents info for datasets and pipelines 2023-07-26 18:36:45 +03:00
allegroai
5cd59ea6e3 Fix csv export handling "," in fields 2023-07-26 18:35:31 +03:00
allegroai
5d3ba4fa73 Fix events.get_multitask_plots to retrieve last iterations per each task metric separately 2023-07-26 18:34:30 +03:00
allegroai
42556c8dbb Pipelines children query now looks for pipeline projects and not tasks 2023-07-26 18:33:35 +03:00
allegroai
dbe1c6f00f Allow configuring multi-plots batch size 2023-07-26 18:33:10 +03:00
allegroai
a17485b1bd Allow dequeueing a deleted task 2023-07-26 18:32:32 +03:00
allegroai
a2b9fed92d Make sure that scroll parameters are ignored when downloading tasks 2023-07-26 18:31:56 +03:00
allegroai
ff34da3c88 Add organization.download_for_get_all endpoint 2023-07-26 18:31:20 +03:00
allegroai
5239755066 Support include_subprojects flag in reports.get_all_ex endpoint 2023-07-26 18:30:34 +03:00
allegroai
8061dfedbb Fix NewListBucketsHelper backwards compatibility 2023-07-26 18:27:51 +03:00
allegroai
011164ce9b Support __$and condition for excluded terms in get_all_ex endpoints list filters 2023-07-26 18:26:49 +03:00
allegroai
8135cf5258 Add include_subprojects to tasks/models.get_all endpoints
Fix escaping metadata for tasks, models and queues
2023-07-26 18:24:49 +03:00
allegroai
a83a932e84 Add pipelines.delete_runs endpoint 2023-07-26 18:23:05 +03:00
allegroai
db021f2863 Add workers.get_count endpoint 2023-07-26 18:21:52 +03:00
allegroai
1b650b1689 Add projects.get_user_names endpoint 2023-07-26 18:21:16 +03:00
allegroai
14d18a7aba Remove obsolete duration field 2023-07-26 18:19:41 +03:00
Olivier Girardot
a7ed46979f Fix handling of the subpaths with nginx templating (#204)
Co-authored-by: ogirardot <olivier.girardot@malt.com>
2023-07-02 16:12:29 +03:00
allegroai
452f606889 Version bump to v1.11 2023-05-25 19:40:07 +03:00
allegroai
fc47ccbf09 Add default services agent user 2023-05-25 19:39:53 +03:00
allegroai
0206811342 Improve empty database check during startup 2023-05-25 19:39:17 +03:00
allegroai
a3ac1049a3 Update ClearML SDK dependency 2023-05-25 19:38:48 +03:00
allegroai
8488f63a3a Add fileserver URL prefixes for async deletion 2023-05-25 19:38:07 +03:00
allegroai
9206a7c57d Schedule external file URLs for deletion on models deletion 2023-05-25 19:36:28 +03:00
allegroai
0c37ced2a1 Fix model Id handling when deleting models for tasks 2023-05-25 19:35:18 +03:00
allegroai
b22f26129e Update requirements 2023-05-25 19:34:19 +03:00
allegroai
d8b998ebd8 Bump API version to 2.25 2023-05-25 19:33:37 +03:00
allegroai
741fa84b52 Fix projects own_tasks does not take task state filter into account 2023-05-25 19:32:52 +03:00
allegroai
d9579891c8 Return only reports from the .reports projects in reports.get_all_ex 2023-05-25 19:31:05 +03:00
allegroai
900414d0de Add option to echo ping payload 2023-05-25 19:30:13 +03:00
allegroai
5449b332d2 Support reports from the root project in reports.get_all_ex 2023-05-25 19:29:46 +03:00
allegroai
875f4b9536 Fix task dequeue will changes status for un-queued/running tasks 2023-05-25 19:28:49 +03:00
allegroai
95b8f22899 Add CLEARML_FILES_HOST to async_delete in windows 2023-05-25 19:27:40 +03:00
allegroai
4058fb9ce5 Migrate to python 3.9 bullseye docker images
Update Mongo driver version
2023-05-25 19:27:14 +03:00
allegroai
cf8e847ed3 Switch to new redis version 2023-05-25 19:22:39 +03:00
allegroai
755cc803d9 Add remove_from_all_queues parameter to tasks.dequeue/dequeue_many endpoints 2023-05-25 19:22:10 +03:00
allegroai
3729afe014 Optimize queues.get_next_task to retrieve required task fields only 2023-05-25 19:21:24 +03:00
allegroai
dff2ed34e8 Support receiving mixed events for both locked and unlocked tasks and models events.add_batch 2023-05-25 19:20:35 +03:00
allegroai
de9651d761 Allow mixing Model and task events in the same events batch 2023-05-25 19:19:45 +03:00
allegroai
818496236b Support filtering by children tags in projects.get_all_ex 2023-05-25 19:19:10 +03:00
allegroai
e99817b28b Task reports can now return single value metrics 2023-05-25 19:18:24 +03:00
allegroai
58465fbc17 Model events are fully supported 2023-05-25 19:17:40 +03:00
allegroai
2e4e060a82 Task move forward/backwards in queue is now atomic 2023-05-25 19:16:33 +03:00
allegroai
5c5d9b6434 Fix numeric hyperparam values are not sorted lexicographically with descending sort order 2023-05-25 19:15:59 +03:00
allegroai
4291ad682a Support filtering by task name in projects.get_task_parent 2023-05-25 19:15:26 +03:00
allegroai
4c22757002 Fix task that is not in queue but has 'queued' status can't be dequeued 2023-05-25 19:14:25 +03:00
allegroai
6e777e80b8 Cleaned up unit tests 2023-05-25 19:13:10 +03:00
138 changed files with 5397 additions and 1864 deletions

View File

@@ -27,7 +27,7 @@
24: ["not_public_object", "object is not public"]
# Auth / Login
75: ["invalid_access_key", "access key not found for user"]
75: ["invalid_access_key", "access key not found"]
# Tasks
100: ["task_error", "general task error"]
@@ -53,6 +53,9 @@
# Reports
150: ["operation_supported_on_reports_only", "passed task is not report"]
# Pipelines
160: ["cannot_remove_all_runs", "at least one pipeline run should be left"]
# Models
200: ["model_error", "general task error"]
201: ["invalid_model_id", "invalid model id"]
@@ -73,12 +76,14 @@
402: ["project_has_tasks", "project has associated tasks"]
403: ["project_not_found", "project not found"]
405: ["project_has_models", "project has associated models"]
406: ["project_has_datasets", "project has associated non-empty datasets"]
407: ["invalid_project_name", "invalid project name"]
408: ["cannot_update_project_location", "Cannot update project location. Use projects.move instead"]
409: ["project_path_exceeds_max", "Project path exceed the maximum allowed depth"]
410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"]
411: ["project_cannot_be_moved_under_itself", "Project can not be moved under itself in the projects hierarchy"]
412: ["project_cannot_be_merged_into_its_child", "Project can not be merged into its own child"]
413: ["project_has_pipelines", "project has associated pipelines with active controllers"]
# Queues
701: ["invalid_queue_id", "invalid queue id"]

View File

@@ -61,6 +61,13 @@ class ListField(fields.ListField):
item.validate()
class ScalarField(fields.BaseField):
"""String field."""
types = (str, int, float, bool)
class DictField(fields.BaseField):
types = (dict,)

View File

@@ -13,6 +13,14 @@ from apiserver.config_repo import config
from apiserver.utilities.stringenum import StringEnum
class TaskRequest(Base):
task: str = StringField(required=True)
class ModelRequest(Base):
model: str = StringField(required=True)
class HistogramRequestBase(Base):
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
@@ -29,6 +37,11 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
model_events: bool = BoolField(default=False)
class GetMetricsAndVariantsRequest(Base):
task: str = StringField(required=True)
model_events: bool = BoolField(default=False)
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
tasks: Sequence[str] = ListField(
items_types=str,
@@ -41,6 +54,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
)
],
)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
model_events: bool = BoolField(default=False)
@@ -50,6 +64,12 @@ class TaskMetric(Base):
variants: Sequence[str] = ListField(items_types=str)
class LegacyMetricEventsRequest(TaskRequest):
iters: int = IntField(default=1, validators=validators.Min(1))
scroll_id: str = StringField()
model_events: bool = BoolField(default=False)
class MetricEventsRequest(Base):
metrics: Sequence[TaskMetric] = ListField(
items_types=TaskMetric, validators=[Length(minimum_value=1)]
@@ -58,7 +78,14 @@ class MetricEventsRequest(Base):
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
model_events: bool = BoolField()
model_events: bool = BoolField(default=False)
class VectorMetricsIterHistogramRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
variant: str = StringField(required=True)
model_events: bool = BoolField(default=False)
class GetVariantSampleRequest(Base):
@@ -109,6 +136,11 @@ class TaskEventsRequest(TaskEventsRequestBase):
model_events: bool = BoolField(default=False)
class LegacyLogEventsRequest(TaskEventsRequestBase):
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.desc)
scroll_id: str = StringField()
class LogEventsRequest(TaskEventsRequestBase):
batch_size: int = IntField(default=5000)
navigate_earlier: bool = BoolField(default=True)
@@ -148,13 +180,30 @@ class MultiTasksRequestBase(Base):
class SingleValueMetricsRequest(MultiTasksRequestBase):
pass
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class TaskMetricsRequest(MultiTasksRequestBase):
event_type: EventType = ActualEnumField(EventType, required=True)
class MultiTaskMetricsRequest(MultiTasksRequestBase):
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
class LegacyMultiTaskEventsRequest(MultiTasksRequestBase):
iters: int = IntField(default=1, validators=validators.Min(1))
scroll_id: str = StringField()
class MultiTaskPlotsRequest(MultiTasksRequestBase):
iters: int = IntField(default=1)
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
last_iters_per_task_metric: bool = BoolField(default=True)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class TaskPlotsRequest(Base):
task: str = StringField(required=True)
iters: int = IntField(default=1)
@@ -164,6 +213,14 @@ class TaskPlotsRequest(Base):
model_events: bool = BoolField(default=False)
class GetScalarMetricDataRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
model_events: bool = BoolField(default=False)
class ClearScrollRequest(Base):
scroll_id: str = StringField()

View File

@@ -5,8 +5,9 @@ from apiserver.apimodels import DictField, callable_default
class GetSupportedModesRequest(Base):
state = StringField(help_text="ASCII base64 encoded application state")
callback_url_prefix = StringField()
pass
# state = StringField(help_text="ASCII base64 encoded application state")
# callback_url_prefix = StringField()
class BasicGuestMode(Base):

View File

@@ -42,12 +42,29 @@ class ModelRequest(models.Base):
model = fields.StringField(required=True)
class TaskRequest(models.Base):
task = fields.StringField(required=True)
class UpdateForTaskRequest(TaskRequest):
uri = fields.StringField()
iteration = fields.IntField()
override_model_id = fields.StringField()
class UpdateModelRequest(ModelRequest):
task = fields.StringField()
iteration = fields.IntField()
class DeleteModelRequest(ModelRequest):
force = fields.BoolField(default=False)
delete_external_artifacts = fields.BoolField(default=True)
class ModelsDeleteManyRequest(BatchRequest):
force = fields.BoolField(default=False)
delete_external_artifacts = fields.BoolField(default=True)
class PublishModelRequest(ModelRequest):

View File

@@ -1,6 +1,11 @@
from jsonmodels import fields, models
from enum import auto
from typing import Sequence
from apiserver.apimodels import DictField
from jsonmodels import fields, models
from jsonmodels.validators import Length
from apiserver.apimodels import DictField, ActualEnumField, ScalarField
from apiserver.utilities.stringenum import StringEnum
class Filter(models.Base):
@@ -23,3 +28,35 @@ class EntitiesCountRequest(models.Base):
active_users = fields.ListField(str)
search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)
class EntityType(StringEnum):
task = auto()
model = auto()
class ValueMapping(models.Base):
key = ScalarField(nullable=True)
value = ScalarField(nullable=True)
class FieldMapping(models.Base):
field = fields.StringField(required=True)
name = fields.StringField()
values: Sequence[ValueMapping] = fields.ListField(items_types=[ValueMapping])
class PrepareDownloadForGetAllRequest(models.Base):
entity_type = ActualEnumField(EntityType)
allow_public = fields.BoolField(default=True)
search_hidden = fields.BoolField(default=False)
only_fields = fields.ListField(
items_types=[str], validators=[Length(1)], required=True
)
field_mappings: Sequence[FieldMapping] = fields.ListField(
items_types=[FieldMapping], validators=[Length(1)], required=True
)
class DownloadForGetAllRequest(models.Base):
prepare_id = fields.StringField(required=True)

View File

@@ -1,4 +1,5 @@
from jsonmodels import models, fields
from jsonmodels.validators import Length
from apiserver.apimodels import ListField
@@ -8,12 +9,13 @@ class Arg(models.Base):
value = fields.StringField(required=True)
class DeleteRunsRequest(models.Base):
project = fields.StringField(required=True)
ids = ListField([str], required=True, validators=[Length(1)])
class StartPipelineRequest(models.Base):
task = fields.StringField(required=True)
queue = fields.StringField(required=True)
args = ListField(Arg)
class StartPipelineResponse(models.Base):
pipeline = fields.StringField(required=True)
enqueued = fields.BoolField(required=True)
verify_watched_queue = fields.BoolField(default=False)

View File

@@ -1,10 +1,11 @@
from enum import Enum
from enum import Enum, auto
from jsonmodels import models, fields
from apiserver.apimodels import ListField, ActualEnumField, DictField
from apiserver.apimodels.organization import TagsRequest
from apiserver.database.model import EntityVisibility
from apiserver.utilities.stringenum import StringEnum
class ProjectRequest(models.Base):
@@ -22,6 +23,7 @@ class MoveRequest(ProjectRequest):
class DeleteRequest(ProjectRequest):
force = fields.BoolField(default=False)
delete_contents = fields.BoolField(default=False)
delete_external_artifacts = fields.BoolField(default=True)
class ProjectOrNoneRequest(models.Base):
@@ -29,6 +31,11 @@ class ProjectOrNoneRequest(models.Base):
include_subprojects = fields.BoolField(default=True)
class GetUniqueMetricsRequest(ProjectOrNoneRequest):
model_metrics = fields.BoolField(default=False)
ids = fields.ListField(str)
class GetParamsRequest(ProjectOrNoneRequest):
page = fields.IntField(default=0)
page_size = fields.IntField(default=500)
@@ -39,23 +46,38 @@ class ProjectTagsRequest(TagsRequest):
class MultiProjectRequest(models.Base):
projects = fields.ListField(str)
projects = fields.ListField(items_types=[str, type(None)])
include_subprojects = fields.BoolField(default=True)
class ProjectTaskParentsRequest(MultiProjectRequest):
tasks_state = ActualEnumField(EntityVisibility)
task_name = fields.StringField()
class ProjectHyperparamValuesRequest(MultiProjectRequest):
class EntityTypeEnum(StringEnum):
task = auto()
model = auto()
class ProjectUserNamesRequest(MultiProjectRequest):
entity = ActualEnumField(EntityTypeEnum, default=EntityTypeEnum.task)
class MultiProjectPagedRequest(MultiProjectRequest):
allow_public = fields.BoolField(default=True)
page = fields.IntField(default=0)
page_size = fields.IntField(default=500)
class ProjectHyperparamValuesRequest(MultiProjectPagedRequest):
section = fields.StringField(required=True)
name = fields.StringField(required=True)
allow_public = fields.BoolField(default=True)
pattern = fields.StringField()
class ProjectModelMetadataValuesRequest(MultiProjectRequest):
class ProjectModelMetadataValuesRequest(MultiProjectPagedRequest):
key = fields.StringField(required=True)
allow_public = fields.BoolField(default=True)
class ProjectChildrenType(Enum):
@@ -77,3 +99,5 @@ class ProjectsGetRequest(models.Base):
search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)
children_type = ActualEnumField(ProjectChildrenType)
children_tags = fields.ListField(str)
children_tags_filter = DictField()

View File

@@ -57,15 +57,27 @@ class EventsRequest(Base):
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class PlotEventsRequest(EventsRequest):
last_iters_per_task_metric: bool = BoolField(default=True)
class ScalarMetricsIterHistogram(HistogramRequestBase):
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class SingleValueMetrics(Base):
pass
class GetTasksDataRequest(Base):
debug_images: EventsRequest = EmbeddedField(EventsRequest)
plots: EventsRequest = EmbeddedField(EventsRequest)
scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(ScalarMetricsIterHistogram)
plots: PlotEventsRequest = EmbeddedField(PlotEventsRequest)
scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(
ScalarMetricsIterHistogram
)
single_value_metrics: SingleValueMetrics = EmbeddedField(SingleValueMetrics)
allow_public = BoolField(default=True)
model_events: bool = BoolField(default=False)
class GetAllRequest(Base):

View File

@@ -6,6 +6,10 @@ class ReportStatsOptionRequest(Base):
enabled = BoolField(default=None, nullable=True)
class GetConfigRequest(Base):
path = StringField()
class ReportStatsOptionResponse(Base):
supported = BoolField(default=True)
enabled = BoolField()

View File

@@ -96,6 +96,11 @@ class UpdateRequest(TaskUpdateRequest):
status_message = StringField(default="")
class DequeueRequest(UpdateRequest):
remove_from_all_queues = BoolField(default=False)
new_status = StringField()
class EnqueueRequest(UpdateRequest):
queue = StringField()
queue_name = StringField()
@@ -274,6 +279,11 @@ class StopManyRequest(TaskBatchRequest):
force = BoolField(default=False)
class DequeueManyRequest(TaskBatchRequest):
remove_from_all_queues = BoolField(default=False)
new_status = StringField()
class EnqueueManyRequest(TaskBatchRequest):
queue = StringField()
queue_name = StringField()
@@ -323,3 +333,8 @@ class DeleteModelsRequest(TaskRequest):
class GetAllReq(models.Base):
allow_public = BoolField(default=True)
search_hidden = BoolField(default=False)
class UpdateTagsRequest(BatchRequest):
add_tags = ListField([str])
remove_tags = ListField([str])

View File

@@ -4,6 +4,10 @@ from jsonmodels.models import Base
from apiserver.apimodels import DictField
class UserRequest(Base):
user = StringField(required=True)
class CreateRequest(Base):
id = StringField(required=True)
name = StringField(required=True)

View File

@@ -12,9 +12,8 @@ from jsonmodels.fields import (
)
from jsonmodels.models import Base
from apiserver.apimodels import make_default, ListField, EnumField, JsonSerializableMixin
DEFAULT_TIMEOUT = 10 * 60
from apiserver.apimodels import ListField, EnumField, JsonSerializableMixin
from apiserver.config_repo import config
class WorkerRequest(Base):
@@ -24,7 +23,10 @@ class WorkerRequest(Base):
class RegisterRequest(WorkerRequest):
timeout = IntField(default=0) # registration timeout in seconds (if not specified, default is 10min)
timeout = IntField(
default=int(config.get("services.workers.default_worker_timeout_sec", 10 * 60))
)
""" registration timeout in seconds (default is 10min) """
queues = ListField(six.string_types) # list of queues this worker listens to
@@ -104,6 +106,10 @@ class GetAllResponse(Base):
workers = ListField(WorkerResponseEntry)
class GetCountRequest(GetAllRequest):
last_seen = IntField(default=0)
class StatsBase(Base):
worker_ids = ListField(str)

View File

@@ -5,7 +5,6 @@ import zlib
from collections import defaultdict
from contextlib import closing
from datetime import datetime
from operator import attrgetter
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
import elasticsearch
@@ -24,12 +23,15 @@ from apiserver.bll.event.event_common import (
get_metric_variants_condition,
uncompress_plot,
get_max_metric_and_variant_counts,
PlotFields,
)
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIterator
from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
from apiserver.bll.model import ModelBLL
from apiserver.bll.task.utils import get_many_tasks_for_writing
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.database.model.model import Model
@@ -41,26 +43,23 @@ from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.redis_manager import redman
from apiserver.service_repo.auth import Identity
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.json import loads
# noinspection PyTypeChecker
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
EVENT_TYPES: Set[str] = set(et.value for et in EventType if et != EventType.all)
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
MAX_LONG = 2 ** 63 - 1
MIN_LONG = -(2 ** 63)
MAX_LONG = 2**63 - 1
MIN_LONG = -(2**63)
log = config.logger(__file__)
class PlotFields:
valid_plot = "valid_plot"
plot_len = "plot_len"
plot_str = "plot_str"
plot_data = "plot_data"
source_urls = "source_urls"
async_task_events_delete = config.get("services.tasks.async_events_delete", False)
async_delete_threshold = config.get(
"services.tasks.async_events_delete_threshold", 100_000
)
class EventBLL(object):
@@ -102,47 +101,96 @@ class EventBLL(object):
return self._metrics
@staticmethod
def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
"""Verify that task exists and can be updated"""
if not task_ids:
def _get_valid_entities(
company_id, ids: Mapping[str, bool], identity: Identity, model=False
) -> Set:
"""Verify that task or model exists and can be updated"""
if not ids:
return set()
with translate_errors_context():
query = Q(id__in=task_ids, company=company_id)
if not allow_locked_tasks:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
res = Task.objects(query).only("id")
return {r.id for r in res}
allow_locked = {id_ for id_, allowed in ids.items() if allowed}
not_locked = {id_ for id_, allowed in ids.items() if not allowed}
res = set()
allow_locked_q = Q()
not_locked_q = (
Q(ready__ne=True) if model else Q(status__nin=LOCKED_TASK_STATUSES)
)
for requested_ids, locked_q in (
(allow_locked, allow_locked_q),
(not_locked, not_locked_q),
):
if not requested_ids:
continue
@staticmethod
def _get_valid_models(company_id, model_ids: Set, allow_locked_models=False) -> Set:
"""Verify that task exists and can be updated"""
if not model_ids:
return set()
query = Q(id__in=requested_ids) & locked_q
if model:
ids = Model.objects(query & Q(company=company_id)).scalar("id")
else:
ids = {
t.id
for t in get_many_tasks_for_writing(
company_id=company_id,
identity=identity,
query=query,
only=("id",),
throw_on_forbidden=False,
)
}
with translate_errors_context():
query = Q(id__in=model_ids, company=company_id)
if not allow_locked_models:
query &= Q(ready__ne=True)
res = Model.objects(query).only("id")
return {r.id for r in res}
res.update(ids)
return res
def add_events(
self, company_id, events, worker, allow_locked=False
self,
company_id: str,
identity: Identity,
events: Sequence[dict],
worker: str,
) -> Tuple[int, int, dict]:
model_events = events[0].get("model_event", False)
user_id = identity.user
task_ids = {}
model_ids = {}
for event in events:
if event.get("model_event", model_events) != model_events:
if event.get("model_event", False):
model = event.pop("model", None)
if model is not None:
event["task"] = model
entity_ids = model_ids
else:
event["model_event"] = False
entity_ids = task_ids
id_ = event.get("task")
allow_locked = event.pop("allow_locked", False)
if not id_:
continue
allowed_for_entity = entity_ids.get(id_)
if allowed_for_entity is None:
entity_ids[id_] = allow_locked
elif allowed_for_entity != allow_locked:
raise errors.bad_request.ValidationError(
"Inconsistent model_event setting in the passed events"
)
if event.pop("allow_locked", allow_locked) != allow_locked:
raise errors.bad_request.ValidationError(
"Inconsistent allow_locked setting in the passed events"
f"Inconsistent allow_locked setting in the passed events for {id_}"
)
found_in_both = set(task_ids).intersection(set(model_ids))
if found_in_both:
raise errors.bad_request.ValidationError(
"Inconsistent model_event setting in the passed events",
tasks=found_in_both,
)
valid_models = self._get_valid_entities(
company_id, ids=model_ids, identity=identity, model=True
)
valid_tasks = self._get_valid_entities(
company_id, ids=task_ids, identity=identity
)
actions: List[dict] = []
task_or_model_ids = set()
used_task_ids = set()
used_model_ids = set()
task_iteration = defaultdict(lambda: 0)
task_last_scalar_events = nested_dict(
3, dict
@@ -152,28 +200,6 @@ class EventBLL(object):
) # task_id -> metric_hash -> event_type -> MetricEvent
errors_per_type = defaultdict(int)
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
if model_events:
for event in events:
model = event.pop("model", None)
if model is not None:
event["task"] = model
valid_entities = self._get_valid_models(
company_id,
model_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_models=allow_locked,
)
entity_name = "model"
else:
valid_entities = self._get_valid_tasks(
company_id,
task_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_tasks=allow_locked,
)
entity_name = "task"
for event in events:
# remove spaces from event type
@@ -187,7 +213,8 @@ class EventBLL(object):
errors_per_type[f"Invalid event type {event_type}"] += 1
continue
if model_events and event_type == EventType.task_log.value:
model_event = event["model_event"]
if model_event and event_type == EventType.task_log.value:
errors_per_type[f"Task log events are not supported for models"] += 1
continue
@@ -196,8 +223,12 @@ class EventBLL(object):
errors_per_type["Event must have a 'task' field"] += 1
continue
if task_or_model_id not in valid_entities:
errors_per_type[f"Invalid {entity_name} id {task_or_model_id}"] += 1
if (model_event and task_or_model_id not in valid_models) or (
not model_event and task_or_model_id not in valid_tasks
):
errors_per_type[
f"Invalid {'model' if model_event else 'task'} id {task_or_model_id}"
] += 1
continue
event["type"] = event_type
@@ -232,7 +263,6 @@ class EventBLL(object):
event["metric"] = event.get("metric") or ""
event["variant"] = event.get("variant") or ""
event["model_event"] = model_events
index_name = get_index_name(company_id, event_type)
es_action = {
@@ -241,31 +271,33 @@ class EventBLL(object):
"_source": event,
}
# for "log" events, don't assing custom _id - whatever is sent, is written (not overwritten)
# for "log" events, don't assign custom _id - whatever is sent, is written (not overwritten)
if event_type != EventType.task_log.value:
es_action["_id"] = self._get_event_id(event)
else:
es_action["_id"] = dbutils.id()
task_or_model_ids.add(task_or_model_id)
if (
iter is not None
and not model_events
and event.get("metric") not in self._skip_iteration_for_metric
):
task_iteration[task_or_model_id] = max(
iter, task_iteration[task_or_model_id]
)
if not model_events:
if model_event:
used_model_ids.add(task_or_model_id)
else:
used_task_ids.add(task_or_model_id)
self._update_last_metric_events_for_task(
last_events=task_last_events[task_or_model_id], event=event,
last_events=task_last_events[task_or_model_id],
event=event,
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_or_model_id],
event=event,
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_or_model_id],
event=event,
)
actions.append(es_action)
@@ -303,31 +335,40 @@ class EventBLL(object):
else:
errors_per_type["Error when indexing events batch"] += 1
if not model_events:
remaining_tasks = set()
now = datetime.utcnow()
for task_or_model_id in task_or_model_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those who's events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_or_model_id,
now=now,
iter_max=task_iteration.get(task_or_model_id),
last_scalar_events=task_last_scalar_events.get(
task_or_model_id
),
last_events=task_last_events.get(task_or_model_id),
)
now = datetime.utcnow()
for model_id in used_model_ids:
ModelBLL.update_statistics(
company_id=company_id,
user_id=user_id,
model_id=model_id,
last_update=now,
last_iteration_max=task_iteration.get(model_id),
last_scalar_events=task_last_scalar_events.get(model_id),
)
remaining_tasks = set()
for task_id in used_task_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those who's events were successful
updated = self._update_task(
company_id=company_id,
user_id=user_id,
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
)
if not updated:
remaining_tasks.add(task_id)
continue
if not updated:
remaining_tasks.add(task_or_model_id)
continue
if remaining_tasks:
TaskBLL.set_last_update(
remaining_tasks, company_id, last_update=now
)
if remaining_tasks:
TaskBLL.set_last_update(
remaining_tasks,
company_id=company_id,
user_id=user_id,
last_update=now,
)
# this is for backwards compatibility with streaming bulk throwing exception on those
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
@@ -457,9 +498,10 @@ class EventBLL(object):
def _update_task(
self,
company_id,
task_id,
now,
company_id: str,
user_id: str,
task_id: str,
now: datetime,
iter_max=None,
last_scalar_events=None,
last_events=None,
@@ -475,8 +517,9 @@ class EventBLL(object):
return False
return TaskBLL.update_statistics(
task_id,
company_id,
task_id=task_id,
company_id=company_id,
user_id=user_id,
last_update=now,
last_iteration_max=iter_max,
last_scalar_events=last_scalar_events,
@@ -484,7 +527,9 @@ class EventBLL(object):
)
def _get_event_id(self, event):
id_values = (str(event[field]) for field in self.event_id_fields if field in event)
id_values = (
str(event[field]) for field in self.event_id_fields if field in event
)
return hashlib.md5("-".join(id_values).encode()).hexdigest()
def scroll_task_events(
@@ -556,11 +601,10 @@ class EventBLL(object):
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type,
)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
max_variants = int(max_variants // last_iterations_per_plot)
@@ -586,7 +630,7 @@ class EventBLL(object):
"events": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": last_iterations_per_plot
"size": last_iterations_per_plot,
}
}
},
@@ -597,11 +641,7 @@ class EventBLL(object):
}
with translate_errors_context():
es_response = search_company_events(
body=es_req,
ignore=404,
**search_args,
)
es_response = search_company_events(body=es_req, ignore=404, **search_args)
aggs_result = es_response.get("aggregations")
if not aggs_result:
@@ -614,9 +654,7 @@ class EventBLL(object):
for hit in variants_bucket["events"]["hits"]["hits"]
]
self.uncompress_plots(events)
return TaskEventsResult(
events=events, total_events=len(events)
)
return TaskEventsResult(events=events, total_events=len(events))
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
"""
@@ -633,9 +671,11 @@ class EventBLL(object):
return events, total_events, next_scroll_id
def get_debug_image_urls(
self, company_id: str, task_id: str, after_key: dict = None
self, company_id: str, task_ids: Sequence[str], after_key: dict = None
) -> Tuple[Sequence[str], Optional[dict]]:
if check_empty_data(self.es, company_id, EventType.metrics_image):
if not task_ids or check_empty_data(
self.es, company_id, EventType.metrics_image
):
return [], None
es_req = {
@@ -651,7 +691,10 @@ class EventBLL(object):
},
"query": {
"bool": {
"must": [{"term": {"task": task_id}}, {"exists": {"field": "url"}}]
"must": [
{"terms": {"task": task_ids}},
{"exists": {"field": "url"}},
]
}
},
}
@@ -669,9 +712,13 @@ class EventBLL(object):
return [bucket["key"]["url"] for bucket in res["buckets"]], res.get("after_key")
def get_plot_image_urls(
self, company_id: str, task_id: str, scroll_id: Optional[str]
self, company_id: str, task_ids: Sequence[str], scroll_id: Optional[str]
) -> Tuple[Sequence[dict], Optional[str]]:
if scroll_id == self.empty_scroll:
if (
scroll_id == self.empty_scroll
or not task_ids
or check_empty_data(self.es, company_id, EventType.metrics_plot)
):
return [], None
if scroll_id:
@@ -686,7 +733,7 @@ class EventBLL(object):
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"terms": {"task": task_ids}},
{"exists": {"field": PlotFields.source_urls}},
]
}
@@ -714,6 +761,7 @@ class EventBLL(object):
size=500,
scroll_id=None,
no_scroll=False,
last_iters_per_task_metric=False,
) -> TaskEventsResult:
if scroll_id == self.empty_scroll:
return TaskEventsResult()
@@ -731,12 +779,7 @@ class EventBLL(object):
if not company_ids:
return TaskEventsResult()
task_ids = (
[task_id]
if isinstance(task_id, str)
else task_id
)
task_ids = [task_id] if isinstance(task_id, str) else task_id
must = []
if metrics:
@@ -745,25 +788,47 @@ class EventBLL(object):
if last_iter_count is None:
must.append({"terms": {"task": task_ids}})
else:
tasks_iters = self.get_last_iters(
company_id=company_ids,
event_type=event_type,
task_id=task_ids,
iters=last_iter_count,
metrics=metrics,
)
should = [
{
"bool": {
"must": [
{"term": {"task": task}},
{"terms": {"iter": last_iters}},
]
if last_iters_per_task_metric:
task_metric_iters = self.get_last_iters_per_metric(
company_id=company_ids,
event_type=event_type,
task_id=task_ids,
iters=last_iter_count,
metrics=metrics,
)
should = [
{
"bool": {
"must": [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"terms": {"iter": last_iters}},
]
}
}
}
for task, last_iters in tasks_iters.items()
if last_iters
]
for (task, metric), last_iters in task_metric_iters.items()
if last_iters
]
else:
tasks_iters = self.get_last_iters(
company_id=company_ids,
event_type=event_type,
task_id=task_ids,
iters=last_iter_count,
metrics=metrics,
)
should = [
{
"bool": {
"must": [
{"term": {"task": task}},
{"terms": {"iter": last_iters}},
]
}
}
for task, last_iters in tasks_iters.items()
if last_iters
]
if not should:
return TaskEventsResult()
must.append({"bool": {"should": should}})
@@ -804,7 +869,8 @@ class EventBLL(object):
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
es_req = {
"size": 0,
@@ -858,8 +924,10 @@ class EventBLL(object):
}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
max_variants = int(max_variants // 2)
es_req = {
"size": 0,
"query": query,
@@ -961,13 +1029,77 @@ class EventBLL(object):
return iterations, vectors
def get_last_iters_per_metric(
self,
company_id: Union[str, Sequence[str]],
event_type: EventType,
task_id: Union[str, Sequence[str]],
iters: int,
metrics: MetricVariants = None,
) -> Mapping[Tuple[str, str], Sequence]:
company_ids = [company_id] if isinstance(company_id, str) else company_id
company_ids = [
c_id
for c_id in set(company_ids)
if not check_empty_data(self.es, c_id, event_type)
]
if not company_ids:
return {}
task_ids = [task_id] if isinstance(task_id, str) else task_id
must = [{"terms": {"task": task_ids}}]
if metrics:
must.append(get_metric_variants_condition(metrics))
max_tasks = min(len(task_ids), 1000)
max_metrics = 10_000 // (max_tasks * iters)
es_req: dict = {
"size": 0,
"aggs": {
"tasks": {
"terms": {"field": "task", "size": max_tasks},
"aggs": {
"metrics": {
"terms": {"field": "metric", "size": max_metrics},
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iters,
"order": {"_key": "desc"},
}
}
},
}
},
}
},
"query": {"bool": {"must": must}},
}
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_ids,
event_type=event_type,
body=es_req,
)
if "aggregations" not in es_res:
return {}
return {
(tb["key"], mb["key"]): [ib["key"] for ib in mb["iters"]["buckets"]]
for tb in es_res["aggregations"]["tasks"]["buckets"]
for mb in tb["metrics"]["buckets"]
}
def get_last_iters(
self,
company_id: Union[str, Sequence[str]],
event_type: EventType,
task_id: Union[str, Sequence[str]],
iters: int,
metrics: MetricVariants = None
metrics: MetricVariants = None,
) -> Mapping[str, Sequence]:
company_ids = [company_id] if isinstance(company_id, str) else company_id
company_ids = [
@@ -983,11 +1115,12 @@ class EventBLL(object):
if metrics:
must.append(get_metric_variants_condition(metrics))
max_tasks = min(len(task_ids), 1000)
es_req: dict = {
"size": 0,
"aggs": {
"tasks": {
"terms": {"field": "task"},
"terms": {"field": "task", "size": max_tasks},
"aggs": {
"iters": {
"terms": {
@@ -1004,7 +1137,10 @@ class EventBLL(object):
with translate_errors_context():
es_res = search_company_events(
self.es, company_id=company_ids, event_type=event_type, body=es_req,
self.es,
company_id=company_ids,
event_type=event_type,
body=es_req,
)
if "aggregations" not in es_res:
@@ -1055,18 +1191,26 @@ class EventBLL(object):
return {"refresh": True}
def delete_task_events(
self, company_id, task_id, allow_locked=False, model=False, async_delete=False,
):
def delete_task_events(self, company_id, task_id, allow_locked=False, model=False):
if model:
self._validate_model_state(
company_id=company_id, model_id=task_id, allow_locked=allow_locked,
company_id=company_id,
model_id=task_id,
allow_locked=allow_locked,
)
else:
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
)
async_delete = async_task_events_delete
if async_delete:
total = self.events_iterator.count_task_events(
event_type=EventType.all,
company_id=company_id,
task_ids=[task_id],
)
if total <= async_delete_threshold:
async_delete = False
es_req = {"query": {"term": {"task": task_id}}}
with translate_errors_context():
es_res = delete_company_events(
@@ -1124,14 +1268,23 @@ class EventBLL(object):
return es_res.get("deleted", 0)
def delete_multi_task_events(
self, company_id: str, task_ids: Sequence[str], async_delete=False
self, company_id: str, task_ids: Sequence[str], model=False
):
"""
Delete mutliple task events. No check is done for tasks write access
Delete multiple task events. No check is done for tasks write access
so it should be checked by the calling code
"""
deleted = 0
with translate_errors_context():
async_delete = async_task_events_delete
if async_delete and len(task_ids) < 100:
total = self.events_iterator.count_task_events(
event_type=EventType.all,
company_id=company_id,
task_ids=task_ids,
)
if total <= async_delete_threshold:
async_delete = False
for tasks in chunked_iter(task_ids, 100):
es_req = {"query": {"terms": {"task": tasks}}}
es_res = delete_company_events(
@@ -1145,7 +1298,7 @@ class EventBLL(object):
deleted += es_res.get("deleted", 0)
if not async_delete:
return es_res.get("deleted", 0)
return deleted
def clear_scroll(self, scroll_id: str):
if scroll_id == self.empty_scroll:

View File

@@ -64,7 +64,7 @@ def get_index_name(company_id: Union[str, Sequence[str]], event_type: str):
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
es_index = get_index_name(company_id, event_type.value)
if not es.indices.exists(es_index):
if not es.indices.exists(index=es_index):
return True
return False

View File

@@ -21,6 +21,7 @@ from apiserver.bll.event.event_common import (
TaskCompanies,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.tools import safe_get
@@ -161,7 +162,9 @@ class EventMetrics:
return res
def get_task_single_value_metrics(
self, companies: TaskCompanies
self,
companies: TaskCompanies,
metric_variants: MetricVariants = None,
) -> Mapping[str, dict]:
"""
For the requested tasks return all the events delivered for the single iteration (-2**31)
@@ -179,7 +182,13 @@ class EventMetrics:
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
task_events = list(
itertools.chain.from_iterable(
pool.map(self._get_task_single_value_metrics, companies.items())
pool.map(
partial(
self._get_task_single_value_metrics,
metric_variants=metric_variants,
),
companies.items(),
)
),
)
@@ -195,19 +204,19 @@ class EventMetrics:
}
def _get_task_single_value_metrics(
self, tasks: Tuple[str, Sequence[str]]
self, tasks: Tuple[str, Sequence[str]], metric_variants: MetricVariants = None
) -> Sequence[dict]:
company_id, task_ids = tasks
must = [
{"terms": {"task": task_ids}},
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"terms": {"task": task_ids}},
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
]
}
},
"query": {"bool": {"must": must}},
}
with translate_errors_context():
es_res = search_company_events(
@@ -280,7 +289,8 @@ class EventMetrics:
query = {"bool": {"must": must}}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
max_variants = int(max_variants // 2)
es_req = {
@@ -366,7 +376,8 @@ class EventMetrics:
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
max_variants = int(max_variants // 2)
es_req = {
@@ -432,7 +443,9 @@ class EventMetrics:
@classmethod
def _get_task_metrics_query(
cls, task_id: str, metrics: Sequence[Tuple[str, str]],
cls,
task_id: str,
metrics: Sequence[Tuple[str, str]],
):
must = cls._task_conditions(task_id)
if metrics:
@@ -451,12 +464,96 @@ class EventMetrics:
return {"bool": {"must": must}}
def get_multi_task_metrics(self, companies: TaskCompanies, event_type: EventType) -> Mapping[str, list]:
"""
For the requested tasks return reported metrics and variants
"""
tasks_ids = {
company: [t.id for t in tasks]
for company, tasks in companies.items()
}
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
companies_res: Sequence = list(
pool.map(
partial(
self._get_multi_task_metrics,
event_type=event_type,
),
tasks_ids.items(),
)
)
if len(companies_res) == 1:
return companies_res[0]
res = defaultdict(set)
for c_res in companies_res:
for m, vars_ in c_res.items():
res[m].update(vars_)
return {
k: list(v)
for k, v in res.items()
}
def _get_multi_task_metrics(
self, company_tasks: Tuple[str, Sequence[str]], event_type: EventType
) -> Mapping[str, list]:
company_id, task_ids = company_tasks
if check_empty_data(self.es, company_id, event_type):
return {}
search_args = dict(
es=self.es,
company_id=company_id,
event_type=event_type,
)
query = QueryBuilder.terms("task", task_ids)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query,
**search_args,
)
es_req = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": max_variants,
"order": {"_key": "asc"},
},
}
}
}
},
}
es_res = search_company_events(
body=es_req,
**search_args,
)
aggs_result = es_res.get("aggregations")
if not aggs_result:
return {}
return {
mb["key"]: [vb["key"] for vb in mb["variants"]["buckets"]]
for mb in aggs_result["metrics"]["buckets"]
}
def get_task_metrics(
self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence:
"""
For the requested tasks return all the metrics that
reported events of the requested types
For the requested tasks return reported metrics per task
"""
if check_empty_data(self.es, company_id, event_type):
return {}

View File

@@ -64,13 +64,13 @@ class EventsIterator:
self,
event_type: EventType,
company_id: str,
task_id: str,
task_ids: Sequence[str],
metric_variants: MetricVariants = None,
) -> int:
if check_empty_data(self.es, company_id, event_type):
return 0
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
query, _ = self._get_initial_query_and_must(task_ids, metric_variants)
es_req = {
"query": query,
}
@@ -100,7 +100,7 @@ class EventsIterator:
For the last key-field value all the events are brought (even if the resulting size exceeds batch_size)
so that events with this value will not be lost between the calls.
"""
query, must = self._get_initial_query_and_must(task_id, metric_variants)
query, must = self._get_initial_query_and_must([task_id], metric_variants)
# retrieve the next batch of events
es_req = {
@@ -158,14 +158,14 @@ class EventsIterator:
@staticmethod
def _get_initial_query_and_must(
task_id: str, metric_variants: MetricVariants = None
task_ids: Sequence[str], metric_variants: MetricVariants = None
) -> Tuple[dict, list]:
if not metric_variants:
must = [{"term": {"task": task_id}}]
query = {"term": {"task": task_id}}
query = {"terms": {"task": task_ids}}
must = [query]
else:
must = [
{"term": {"task": task_id}},
{"terms": {"task": task_ids}},
get_metric_variants_condition(metric_variants),
]
query = {"bool": {"must": must}}

View File

@@ -86,7 +86,7 @@ class MetricEventsIterator:
task_id: company_id
for task_id, company_id in companies.items()
if not check_empty_data(
self.es, company_id=company_id, event_type=EventType.metrics_scalar
self.es, company_id=company_id, event_type=self.event_type
)
}
if not companies:

View File

@@ -5,14 +5,18 @@ from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse
from apiserver.bll.task.utils import deleted_prefix
from apiserver.bll.task.utils import deleted_prefix, get_last_metric_updates
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.service_repo.auth import Identity
from .metadata import Metadata
class ModelBLL:
event_bll = None
@classmethod
def get_company_model_by_id(
cls, company_id: str, model_id: str, only_fields=None
@@ -28,11 +32,7 @@ class ModelBLL:
@staticmethod
def assert_exists(
company_id,
model_ids,
only=None,
allow_public=False,
return_models=True,
company_id, model_ids, only=None, allow_public=False, return_models=True,
) -> Optional[Sequence[Model]]:
model_ids = [model_ids] if isinstance(model_ids, str) else model_ids
ids = set(model_ids)
@@ -58,14 +58,15 @@ class ModelBLL:
cls,
model_id: str,
company_id: str,
user_id: str,
identity: Identity,
force_publish_task: bool = False,
publish_task_func: Callable[[str, str, str, bool], dict] = None,
publish_task_func: Callable[[str, str, Identity, bool], dict] = None,
) -> Tuple[int, ModelTaskPublishResponse]:
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
if model.ready:
raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
user_id = identity.user
published_task = None
if model.task and publish_task_func:
task = (
@@ -75,18 +76,25 @@ class ModelBLL:
)
if task and task.status != TaskStatus.published:
task_publish_res = publish_task_func(
model.task, company_id, user_id, force_publish_task
model.task, company_id, identity, force_publish_task
)
published_task = ModelTaskPublishResponse(
id=model.task, data=task_publish_res
)
updated = model.update(upsert=False, ready=True, last_update=datetime.utcnow())
now = datetime.utcnow()
updated = model.update(
upsert=False,
ready=True,
last_update=now,
last_change=now,
last_changed_by=user_id,
)
return updated, published_task
@classmethod
def delete_model(
cls, model_id: str, company_id: str, force: bool
cls, model_id: str, company_id: str, user_id: str, force: bool, delete_external_artifacts: bool = True,
) -> Tuple[int, Model]:
model = cls.get_company_model_by_id(
company_id=company_id,
@@ -112,49 +120,88 @@ class ModelBLL:
if model.task:
task = Task.objects(id=model.task).first()
if task and task.status == TaskStatus.published:
if not force:
raise errors.bad_request.ModelCreatingTaskExists(
"and published, use force=True to delete", task=model.task
)
if task.models.output and model_id in task.models.output:
now = datetime.utcnow()
if task:
now = datetime.utcnow()
if task.status == TaskStatus.published:
if not force:
raise errors.bad_request.ModelCreatingTaskExists(
"and published, use force=True to delete", task=model.task
)
Task._get_collection().update_one(
filter={"_id": model.task, "models.output.model": model_id},
update={
"$set": {
"models.output.$[elem].model": deleted_model_id,
"output.error": f"model deleted on {now.isoformat()}",
"last_change": now,
"last_changed_by": user_id,
},
"last_change": now,
},
array_filters=[{"elem.model": model_id}],
upsert=False,
)
else:
task.update(
pull__models__output__model=model_id,
set__last_change=now,
set__last_changed_by=user_id,
)
delete_external_artifacts = delete_external_artifacts and config.get(
"services.async_urls_delete.enabled", True
)
if delete_external_artifacts:
from apiserver.bll.task.task_cleanup import (
collect_debug_image_urls,
collect_plot_image_urls,
_schedule_for_delete,
)
urls = set()
urls.update(collect_debug_image_urls(company_id, model_id))
urls.update(collect_plot_image_urls(company_id, model_id))
if model.uri:
urls.add(model.uri)
if urls:
_schedule_for_delete(
task_id=model_id,
company=company_id,
user=user_id,
urls=urls,
can_delete_folders=False,
)
if not cls.event_bll:
from apiserver.bll.event import EventBLL
cls.event_bll = EventBLL()
cls.event_bll.delete_task_events(company_id, model_id, allow_locked=True, model=True)
del_count = Model.objects(id=model_id, company=company_id).delete()
return del_count, model
@classmethod
def archive_model(cls, model_id: str, company_id: str):
def archive_model(cls, model_id: str, company_id: str, user_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
now = datetime.utcnow()
archived = Model.objects(company=company_id, id=model_id).update(
add_to_set__system_tags=EntityVisibility.archived.value,
last_update=datetime.utcnow(),
last_change=now,
last_changed_by=user_id,
)
return archived
@classmethod
def unarchive_model(cls, model_id: str, company_id: str):
def unarchive_model(cls, model_id: str, company_id: str, user_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
now = datetime.utcnow()
unarchived = Model.objects(company=company_id, id=model_id).update(
pull__system_tags=EntityVisibility.archived.value,
last_update=datetime.utcnow(),
last_change=now,
last_changed_by=user_id,
)
return unarchived
@@ -179,12 +226,43 @@ class ModelBLL:
"labels_count": {"$size": {"$objectToArray": "$labels"}}
}
},
{
"$project": {"labels_count": 1},
},
{"$project": {"labels_count": 1}},
]
)
return {
r.pop("_id"): r
for r in result
return {r.pop("_id"): r for r in result}
@staticmethod
def update_statistics(
company_id: str,
user_id: str,
model_id: str,
last_update: datetime = None,
last_iteration_max: int = None,
last_scalar_events: Dict[str, Dict[str, dict]] = None,
):
last_update = last_update or datetime.utcnow()
updates = {
"last_update": datetime.utcnow(),
"last_change": last_update,
"last_changed_by": user_id,
}
if last_iteration_max is not None:
updates.update(max__last_iteration=last_iteration_max)
raw_updates = {}
if last_scalar_events is not None:
raw_updates = {}
if last_scalar_events is not None:
get_last_metric_updates(
task_id=model_id,
last_scalar_events=last_scalar_events,
raw_updates=raw_updates,
extra_updates=updates,
model_events=True,
)
ret = Model.objects(id=model_id).update_one(**updates)
if ret and raw_updates:
Model.objects(id=model_id).update_one(__raw__=[{"$set": raw_updates}])
return ret

View File

@@ -5,7 +5,6 @@ from mongoengine import Document
from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem
from apiserver.database.model.base import GetMixin
from apiserver.service_repo import APICall
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
@@ -87,13 +86,13 @@ class Metadata:
return paths
@classmethod
def escape_query_parameters(cls, call: APICall) -> dict:
if not call.data:
return call.data
def escape_query_parameters(cls, call_data: dict) -> dict:
if not call_data:
return call_data
keys = list(call.data)
keys = list(call_data)
call_data = {
safe_key: call.data[key]
safe_key: call_data[key]
for key, safe_key in zip(keys, Metadata.escape_paths(keys))
}

View File

@@ -1,8 +1,11 @@
from collections import defaultdict
from enum import Enum
from typing import Sequence, Dict
from typing import Sequence, Dict, Type
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model.model import AttributedDocument
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
@@ -22,6 +25,51 @@ class OrgBLL:
self._task_tags = _TagsCache(Task, self.redis)
self._model_tags = _TagsCache(Model, self.redis)
def edit_entity_tags(
self,
company_id,
entity_cls: Type[AttributedDocument],
entity_ids: Sequence[str],
add_tags: Sequence[str],
remove_tags: Sequence[str],
) -> int:
if entity_cls not in (Task, Model):
raise errors.bad_request.ValidationError(
"Tags editing can be called on tasks or models only"
)
if not entity_ids:
raise errors.bad_request.ValidationError(
"No entity ids provided for editing tags"
)
if not (add_tags or remove_tags):
raise errors.bad_request.ValidationError(
"Either add tags or remove tags should be provided"
)
updated = 0
if add_tags:
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
add_to_set__tags=add_tags
)
if remove_tags:
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
pull_all__tags=remove_tags
)
if not updated:
return 0
projects = entity_cls.objects(company=company_id, id__in=entity_ids).distinct(
"project"
)
update_project_time(project_ids=projects)
self.update_tags(
company_id,
entity=Tags.Task if entity_cls is Task else Tags.Model,
projects=projects,
tags=add_tags or remove_tags
)
return updated
def get_tags(
self,
company_id: str,
@@ -50,10 +98,10 @@ class OrgBLL:
return ret
def update_tags(
self, company_id: str, entity: Tags, project: str, tags=None, system_tags=None,
self, company_id: str, entity: Tags, projects: Sequence[str], tags=None, system_tags=None,
):
tags_cache = self._get_tags_cache_for_entity(entity)
tags_cache.update_tags(company_id, project, tags, system_tags)
tags_cache.update_tags(company_id, projects, tags, system_tags)
def reset_tags(self, company_id: str, entity: Tags, projects: Sequence[str]):
tags_cache = self._get_tags_cache_for_entity(entity)

View File

@@ -107,7 +107,7 @@ class _TagsCache:
return ret
def update_tags(self, company_id: str, project: str, tags=None, system_tags=None):
def update_tags(self, company_id: str, projects: Sequence[str], tags=None, system_tags=None):
"""
Updates tags. If reset is set then both tags and system_tags
are recalculated. Otherwise only those that are not 'None'
@@ -123,7 +123,7 @@ class _TagsCache:
if not fields:
return
self._delete_redis_keys(company_id, projects=[project], fields=fields)
self._delete_redis_keys(company_id, projects=projects, fields=fields)
def reset_tags(self, company_id: str, projects: Sequence[str]):
self._delete_redis_keys(

View File

@@ -14,16 +14,16 @@ from typing import (
Callable,
Mapping,
Any,
Union,
)
from boltons.iterutils import partition
from mongoengine import Q, Document
from apiserver import database
from apiserver.apierrors import errors
from apiserver.apimodels.projects import ProjectChildrenType
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility, AttributedDocument
from apiserver.database.model import EntityVisibility, AttributedDocument, User
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
@@ -58,7 +58,7 @@ class ProjectBLL:
@classmethod
def merge_project(
cls, company, source_id: str, destination_id: str
cls, company: str, source_id: str, destination_id: str
) -> Tuple[int, int, Set[str]]:
"""
Move all the tasks and sub projects from the source project to the destination
@@ -315,11 +315,12 @@ class ProjectBLL:
description="",
)
extra = (
{"set__last_change": datetime.utcnow()}
if hasattr(entity_cls, "last_change")
else {}
)
extra = {}
if hasattr(entity_cls, "last_change"):
extra["set__last_change"] = datetime.utcnow()
if hasattr(entity_cls, "last_changed_by"):
extra["set__last_changed_by"] = user
entity_cls.objects(company=company, id__in=ids).update(
set__project=project, **extra
)
@@ -340,6 +341,17 @@ class ProjectBLL:
) -> Tuple[Sequence, Sequence]:
archived = EntityVisibility.archived.value
def project_task_fields():
return {
"$project": {
"project": 1,
"status": 1,
"system_tags": 1,
"started": 1,
"completed": 1,
}
}
def ensure_valid_fields():
"""
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
@@ -367,6 +379,7 @@ class ProjectBLL:
users=users,
)
},
project_task_fields(),
ensure_valid_fields(),
{
"$group": {
@@ -515,6 +528,7 @@ class ProjectBLL:
users=users,
)
},
project_task_fields(),
ensure_valid_fields(),
{
# for each project
@@ -550,7 +564,10 @@ class ProjectBLL:
@classmethod
def get_dataset_stats(
cls, company: str, project_ids: Sequence[str], users: Sequence[str] = None,
cls,
company: str,
project_ids: Sequence[str],
users: Sequence[str] = None,
) -> Dict[str, dict]:
if not project_ids:
return {}
@@ -584,7 +601,9 @@ class ProjectBLL:
@staticmethod
def _get_projects_children(
project_ids: Sequence[str], search_hidden: bool, allowed_ids: Sequence[str],
project_ids: Sequence[str],
search_hidden: bool,
allowed_ids: Sequence[str],
) -> Tuple[ProjectsChildren, Set[str]]:
child_projects = _get_sub_projects(
project_ids,
@@ -628,7 +647,9 @@ class ProjectBLL:
project_ids_with_children = set(project_ids)
if include_children:
child_projects, children_ids = cls._get_projects_children(
project_ids, search_hidden=True, allowed_ids=selected_project_ids,
project_ids,
search_hidden=True,
allowed_ids=selected_project_ids,
)
project_ids_with_children |= children_ids
@@ -901,6 +922,8 @@ class ProjectBLL:
project_ids: Optional[Sequence[str]] = None,
allow_public: bool = True,
children_type: ProjectChildrenType = None,
children_tags: Sequence[str] = None,
children_tags_filter: dict = None,
) -> Tuple[Sequence[str], Sequence[str]]:
"""
Get the projects ids matching children_condition (if passed) or where the passed user created any tasks
@@ -921,15 +944,27 @@ class ProjectBLL:
query &= Q(user__in=users)
project_query = None
if children_tags_filter:
child_query = query & GetMixin.get_list_filter_query(
"tags", children_tags_filter
)
elif children_tags:
child_query = query & GetMixin.get_list_field_query("tags", children_tags)
else:
child_query = query
if children_type == ProjectChildrenType.dataset:
child_queries = {
Project: query
Project: child_query
& Q(system_tags__in=[dataset_tag], basename__ne=datasets_project_name)
}
elif children_type == ProjectChildrenType.pipeline:
child_queries = {Task: query & Q(system_tags__in=[pipeline_tag])}
child_queries = {
Project: child_query
& Q(system_tags__in=[pipeline_tag], basename__ne=pipelines_project_name)
}
elif children_type == ProjectChildrenType.report:
child_queries = {Task: query & Q(system_tags__in=[reports_tag])}
child_queries = {Task: child_query & Q(system_tags__in=[reports_tag])}
else:
project_query = query
child_queries = {entity_cls: query for entity_cls in cls.child_classes}
@@ -946,14 +981,12 @@ class ProjectBLL:
)
res = (
{p.id for p in Project.objects(project_query).only("id")}
if project_query
else set()
set(Project.objects(project_query).scalar("id")) if project_query else set()
)
for cls_, query_ in child_queries.items():
res |= set(
cls_.objects(query_).distinct(
field="parent" if cls_ is Project else "project"
field="id" if cls_ is Project else "project"
)
)
@@ -970,20 +1003,14 @@ class ProjectBLL:
return filtered_ids, selected_project_ids
@classmethod
def get_task_parents(
cls,
company_id: str,
projects: Sequence[str],
include_subprojects: bool,
@staticmethod
def _get_project_query(
company: str,
projects: Sequence,
include_subprojects: bool = True,
state: Optional[EntityVisibility] = None,
) -> Sequence[dict]:
"""
Get list of unique parent tasks sorted by task name for the passed company projects
If projects is None or empty then get parents for all the company tasks
"""
query = Q(company=company_id)
) -> Q:
query = get_company_or_none_constraint(company)
if projects:
if include_subprojects:
projects = _ids_with_children(projects)
@@ -996,6 +1023,25 @@ class ProjectBLL:
elif state == EntityVisibility.active:
query &= Q(system_tags__nin=[EntityVisibility.archived.value])
return query
@classmethod
def get_task_parents(
cls,
company_id: str,
projects: Sequence[str],
include_subprojects: bool,
state: Optional[EntityVisibility] = None,
name: str = None,
) -> Sequence[dict]:
"""
Get list of unique parent tasks sorted by task name for the passed company projects
If projects is None or empty then get parents for all the company tasks
"""
query = cls._get_project_query(
company_id, projects, include_subprojects=include_subprojects, state=state
)
parent_ids = set(Task.objects(query).distinct("parent"))
if not parent_ids:
return []
@@ -1003,23 +1049,37 @@ class ProjectBLL:
parents = Task.get_many_with_join(
company_id,
query=Q(id__in=parent_ids),
query_dict={"name": name} if name else None,
allow_public=True,
override_projection=("id", "name", "project.name"),
)
return sorted(parents, key=itemgetter("name"))
@classmethod
def get_entity_users(
cls,
company: str,
entity_cls: Type[Union[Task, Model]],
projects: Sequence[str],
include_subprojects: bool,
) -> Sequence[dict]:
query = cls._get_project_query(
company, projects, include_subprojects=include_subprojects
)
user_ids = entity_cls.objects(query).distinct(field="user")
if not user_ids:
return []
users = User.objects(id__in=user_ids).only("id", "name")
return [{"id": u.id, "name": u.name} for u in users]
@classmethod
def get_task_types(cls, company, project_ids: Optional[Sequence]) -> set:
"""
Return the list of unique task types used by company and public tasks
If project ids passed then only tasks from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
query = cls._get_project_query(company, project_ids)
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
@@ -1029,10 +1089,7 @@ class ProjectBLL:
Return the list of unique frameworks used by company and public models
If project ids passed then only models from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
query = cls._get_project_query(company, project_ids)
return Model.objects(query).distinct(field="framework")
@staticmethod
@@ -1053,20 +1110,52 @@ class ProjectBLL:
if not filter_:
return conditions
or_conditions = []
for field, field_filter in filter_.items():
if not (
field_filter
and isinstance(field_filter, list)
and all(isinstance(t, str) for t in field_filter)
):
if not (field_filter and isinstance(field_filter, (list, dict))):
raise errors.bad_request.ValidationError(
f"List of strings expected for the field: {field}"
f"Non empty list or dictionary expected for the field: {field}"
)
exclude, include = partition(field_filter, lambda x: x.startswith("-"))
conditions[field] = {
**({"$in": include} if include else {}),
**({"$nin": [e[1:] for e in exclude]} if exclude else {}),
}
if isinstance(field_filter, list):
if not all(isinstance(t, str) for t in field_filter):
raise errors.bad_request.ValidationError(
f"Only string values are allowed in the list filter: {field}"
)
helper = GetMixin.NewListFieldBucketHelper(
field, data=field_filter, legacy=True
)
op = helper.global_operator
db_query = {op: helper.actions}
else:
helper = GetMixin.ListQueryFilter.from_data(field, field_filter)
db_query = helper.db_query
for op, actions in db_query.items():
field_conditions = {}
for action, values in actions.items():
value = list(set(values)) if isinstance(values, list) else values
for key in reversed(action.split("__")):
value = {f"${key}": value}
field_conditions.update(value)
if op == Q.OR and len(field_conditions) > 1:
or_conditions.append(
{
"$or": [
{field: {db_modifier: cond}}
for db_modifier, cond in field_conditions.items()
]
}
)
else:
conditions[field] = field_conditions
if or_conditions:
if len(or_conditions) == 1:
conditions = next(iter(or_conditions))
else:
conditions["$and"] = [c for c in or_conditions]
return conditions
@@ -1139,6 +1228,7 @@ class ProjectBLL:
company: str,
project_ids: Sequence[str],
filter_: Mapping[str, Any] = None,
specific_state: Optional[EntityVisibility] = None,
users: Sequence[str] = None,
) -> Dict[str, dict]:
"""
@@ -1149,6 +1239,20 @@ class ProjectBLL:
if not project_ids:
return {}
if specific_state:
filter_ = filter_ or {}
system_tags_filter = filter_.get("system_tags", [])
archived = EntityVisibility.archived.value
non_archived = f"-{EntityVisibility.archived.value}"
if not any(t in system_tags_filter for t in (archived, non_archived)):
filter_ = {k: v for k, v in filter_.items()}
filter_["system_tags"] = [
archived
if specific_state == EntityVisibility.archived
else non_archived,
*system_tags_filter,
]
pipeline = [
{
"$match": cls.get_match_conditions(

View File

@@ -1,7 +1,9 @@
from collections import defaultdict
from datetime import datetime
from typing import Tuple, Set, Sequence
import attr
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.bll.event import EventBLL
@@ -15,13 +17,19 @@ from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType
from .project_bll import ProjectBLL
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType, TaskStatus
from .project_bll import (
ProjectBLL,
pipeline_tag,
pipelines_project_name,
dataset_tag,
datasets_project_name,
reports_tag,
)
from .sub_projects import _ids_with_children
log = config.logger(__file__)
event_bll = EventBLL()
async_events_delete = config.get("services.tasks.async_events_delete", False)
@attr.s(auto_attribs=True)
@@ -33,30 +41,83 @@ class DeleteProjectResult:
urls: TaskUrls = None
def _get_child_project_ids(
project_id: str,
) -> Tuple[Sequence[str], Sequence[str], Sequence[str]]:
project_ids = _ids_with_children([project_id])
pipeline_ids = list(
Project.objects(
id__in=project_ids,
system_tags__in=[pipeline_tag],
basename__ne=pipelines_project_name,
).scalar("id")
)
dataset_ids = list(
Project.objects(
id__in=project_ids,
system_tags__in=[dataset_tag],
basename__ne=datasets_project_name,
).scalar("id")
)
return project_ids, pipeline_ids, dataset_ids
def validate_project_delete(company: str, project_id: str):
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path", "system_tags")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
is_pipeline = "pipeline" in (project.system_tags or [])
project_ids = _ids_with_children([project_id])
project_ids, pipeline_ids, dataset_ids = _get_child_project_ids(project_id)
ret = {}
for cls in ProjectBLL.child_classes:
ret[f"{cls.__name__.lower()}s"] = cls.objects(project__in=project_ids).count()
for cls in ProjectBLL.child_classes:
query = dict(
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
)
name = f"non_archived_{cls.__name__.lower()}s"
if not is_pipeline:
ret[name] = cls.objects(**query).count()
else:
ret[name] = (
cls.objects(**query, type=TaskType.controller).count()
if cls == Task
else 0
if pipeline_ids:
pipelines_with_active_controllers = Task.objects(
project__in=pipeline_ids,
type=TaskType.controller,
system_tags__nin=[EntityVisibility.archived.value],
).distinct("project")
ret["pipelines"] = len(pipelines_with_active_controllers)
else:
ret["pipelines"] = 0
if dataset_ids:
datasets_with_data = Task.objects(
project__in=dataset_ids,
system_tags__nin=[EntityVisibility.archived.value],
).distinct("project")
ret["datasets"] = len(datasets_with_data)
else:
ret["datasets"] = 0
project_ids = list(set(project_ids) - set(pipeline_ids) - set(dataset_ids))
if project_ids:
in_project_query = Q(project__in=project_ids)
for cls in (Task, Model):
query = (
in_project_query & Q(system_tags__nin=[reports_tag])
if cls is Task
else in_project_query
)
ret[f"{cls.__name__.lower()}s"] = cls.objects(query).count()
ret[f"non_archived_{cls.__name__.lower()}s"] = cls.objects(
query & Q(system_tags__nin=[EntityVisibility.archived.value])
).count()
ret["reports"] = Task.objects(
in_project_query & Q(system_tags__in=[reports_tag])
).count()
ret["non_archived_reports"] = Task.objects(
in_project_query
& Q(
system_tags__in=[reports_tag],
system_tags__nin=[EntityVisibility.archived.value],
)
).count()
else:
for cls in (Task, Model):
ret[f"{cls.__name__.lower()}s"] = 0
ret[f"non_archived_{cls.__name__.lower()}s"] = 0
ret["reports"] = 0
ret["non_archived_reports"] = 0
return ret
@@ -67,7 +128,7 @@ def delete_project(
project_id: str,
force: bool,
delete_contents: bool,
delete_external_artifacts=True,
delete_external_artifacts: bool,
) -> Tuple[DeleteProjectResult, Set[str]]:
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path", "system_tags")
@@ -76,40 +137,58 @@ def delete_project(
raise errors.bad_request.InvalidProjectId(id=project_id)
delete_external_artifacts = delete_external_artifacts and config.get(
"services.async_urls_delete.enabled", False
"services.async_urls_delete.enabled", True
)
is_pipeline = "pipeline" in (project.system_tags or [])
project_ids = _ids_with_children([project_id])
project_ids, pipeline_ids, dataset_ids = _get_child_project_ids(project_id)
if not force:
query = dict(
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
)
if not is_pipeline:
if pipeline_ids:
active_controllers = Task.objects(
project__in=pipeline_ids,
type=TaskType.controller,
system_tags__nin=[EntityVisibility.archived.value],
).only("id")
if active_controllers:
raise errors.bad_request.ProjectHasPipelines(
"please archive all the controllers or use force=true",
id=project_id,
)
if dataset_ids:
datasets_with_data = Task.objects(
project__in=dataset_ids,
system_tags__nin=[EntityVisibility.archived.value],
).only("id")
if datasets_with_data:
raise errors.bad_request.ProjectHasDatasets(
"please delete all the dataset versions or use force=true",
id=project_id,
)
regular_projects = list(set(project_ids) - set(pipeline_ids) - set(dataset_ids))
if regular_projects:
for cls, error in (
(Task, errors.bad_request.ProjectHasTasks),
(Model, errors.bad_request.ProjectHasModels),
):
non_archived = cls.objects(**query).only("id")
non_archived = cls.objects(
project__in=regular_projects,
system_tags__nin=[EntityVisibility.archived.value],
).only("id")
if non_archived:
raise error("use force=true to delete", id=project_id)
else:
non_archived = Task.objects(**query, type=TaskType.controller).only("id")
if non_archived:
raise errors.bad_request.ProjectHasTasks(
"please archive all the runs inside the project", id=project_id
)
raise error("use force=true", id=project_id)
if not delete_contents:
disassociated = defaultdict(int)
for cls in ProjectBLL.child_classes:
disassociated[cls] = cls.objects(project__in=project_ids).update(project=None)
disassociated[cls] = cls.objects(project__in=project_ids).update(
project=None
)
res = DeleteProjectResult(disassociated_tasks=disassociated[Task])
else:
deleted_models, model_event_urls, model_urls = _delete_models(
company=company, projects=project_ids
company=company, user=user, projects=project_ids
)
deleted_tasks, task_event_urls, artifact_urls = _delete_tasks(
company=company, projects=project_ids
company=company, user=user, projects=project_ids
)
event_urls = task_event_urls | model_event_urls
if delete_external_artifacts:
@@ -138,7 +217,9 @@ def delete_project(
return res, affected
def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]:
def _delete_tasks(
company: str, user: str, projects: Sequence[str]
) -> Tuple[int, Set, Set]:
"""
Delete only the task themselves and their non published version.
Child models under the same project are deleted separately.
@@ -149,14 +230,24 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
if not tasks:
return 0, set(), set()
task_ids = {t.id for t in tasks}
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
task_ids = list({t.id for t in tasks})
now = datetime.utcnow()
Task.objects(parent__in=task_ids, project__nin=projects).update(
parent=None,
last_change=now,
last_changed_by=user,
)
Model.objects(task__in=task_ids, project__nin=projects).update(
task=None,
last_change=now,
last_changed_by=user,
)
event_urls, artifact_urls = set(), set()
event_urls = collect_debug_image_urls(company, task_ids) | collect_plot_image_urls(
company, task_ids
)
artifact_urls = set()
for task in tasks:
event_urls.update(collect_debug_image_urls(company, task.id))
event_urls.update(collect_plot_image_urls(company, task.id))
if task.execution and task.execution.artifacts:
artifact_urls.update(
{
@@ -166,15 +257,13 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
}
)
event_bll.delete_multi_task_events(
company, list(task_ids), async_delete=async_events_delete
)
event_bll.delete_multi_task_events(company, task_ids)
deleted = tasks.delete()
return deleted, event_urls, artifact_urls
def _delete_models(
company: str, projects: Sequence[str]
company: str, user: str, projects: Sequence[str]
) -> Tuple[int, Set[str], Set[str]]:
"""
Delete project models and update the tasks from other projects
@@ -185,39 +274,54 @@ def _delete_models(
return 0, set(), set()
model_ids = list({m.id for m in models})
deleted = "__DELETED__"
Task._get_collection().update_many(
filter={
"project": {"$nin": projects},
"models.input.model": {"$in": model_ids},
},
update={"$set": {"models.input.$[elem].model": None}},
update={"$set": {"models.input.$[elem].model": deleted}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
model_tasks = list({m.task for m in models if m.task})
if model_tasks:
now = datetime.utcnow()
# update published tasks
Task._get_collection().update_many(
filter={
"_id": {"$in": model_tasks},
"project": {"$nin": projects},
"models.output.model": {"$in": model_ids},
"status": TaskStatus.published,
},
update={
"$set": {
"models.output.$[elem].model": deleted,
"last_change": now,
"last_changed_by": user,
}
},
update={"$set": {"models.output.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
# update unpublished tasks
Task.objects(
id__in=model_tasks,
project__nin=projects,
status__ne=TaskStatus.published,
).update(
pull__models__output__model__in=model_ids,
set__last_change=now,
set__last_changed_by=user,
)
event_urls, model_urls = set(), set()
for m in models:
event_urls.update(collect_debug_image_urls(company, m.id))
event_urls.update(collect_plot_image_urls(company, m.id))
if m.uri:
model_urls.add(m.uri)
event_bll.delete_multi_task_events(
company, model_ids, async_delete=async_events_delete
event_urls = collect_debug_image_urls(company, model_ids) | collect_plot_image_urls(
company, model_ids
)
model_urls = {m.uri for m in models if m.uri}
event_bll.delete_multi_task_events(company, model_ids, model=True)
deleted = models.delete()
return deleted, event_urls, model_urls

View File

@@ -140,7 +140,12 @@ class ProjectQueries:
name: str,
include_subprojects: bool,
allow_public: bool = True,
pattern: str = None,
page: int = 0,
page_size: int = 500,
) -> ParamValues:
page = max(0, page)
page_size = max(1, page_size)
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
@@ -160,7 +165,20 @@ class ProjectQueries:
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
redis_key = "_".join(
str(part)
for part in (
"hyperparam_values",
company_id,
"_".join(project_ids),
section,
name,
allow_public,
pattern,
page,
page_size,
)
)
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_param_values(
key=redis_key,
@@ -172,19 +190,27 @@ class ProjectQueries:
if cached_res:
return cached_res
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
match_condition = {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
if pattern:
match_condition["$expr"] = {
"$regexMatch": {
"input": f"${key_path}.value",
"regex": pattern,
"options": "i",
}
},
}
pipeline = [
{"$match": match_condition},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
@@ -209,13 +235,19 @@ class ProjectQueries:
@classmethod
def get_unique_metric_variants(
cls, company_id, project_ids: Sequence[str], include_subprojects: bool
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
ids: Sequence[str],
model_metrics: bool = False,
):
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
**({"_id": {"$in": ids}} if ids else {}),
}
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
@@ -246,7 +278,8 @@ class ProjectQueries:
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
result = Task.aggregate(pipeline)
entity_cls = Model if model_metrics else Task
result = entity_cls.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@classmethod
@@ -306,7 +339,11 @@ class ProjectQueries:
key: str,
include_subprojects: bool,
allow_public: bool = True,
page: int = 0,
page_size: int = 500,
) -> ParamValues:
page = max(0, page)
page_size = max(1, page_size)
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
@@ -326,7 +363,7 @@ class ProjectQueries:
if not last_updated_model:
return 0, []
redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}"
redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}_{page}_{page_size}"
last_update = last_updated_model.last_update or datetime.utcnow()
cached_res = self._get_cached_param_values(
key=redis_key, last_update=last_update
@@ -334,7 +371,6 @@ class ProjectQueries:
if cached_res:
return cached_res
max_values = config.get("services.models.metadata_values.max_count", 100)
pipeline = [
{
"$match": {
@@ -346,7 +382,8 @@ class ProjectQueries:
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,

View File

@@ -144,8 +144,8 @@ def _ids_with_children(project_ids: Sequence[str]) -> Sequence[str]:
"""
Return project ids with the ids of all the subprojects
"""
subprojects = Project.objects(path__in=project_ids).only("id")
return list({*project_ids, *(child.id for child in subprojects)})
children_ids = Project.objects(path__in=project_ids).scalar("id")
return list({*project_ids, *children_ids})
def _update_subproject_names(

View File

@@ -1,6 +1,6 @@
from collections import defaultdict
from datetime import datetime
from typing import Callable, Sequence, Optional, Tuple
from typing import Sequence, Optional, Tuple, Union
from elasticsearch import Elasticsearch
from mongoengine import Q
@@ -16,6 +16,8 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.queue import Queue, Entry
log = config.logger(__file__)
MOVE_FIRST = "first"
MOVE_LAST = "last"
class QueueBLL(object):
@@ -150,10 +152,16 @@ class QueueBLL(object):
for item in queue.entries:
try:
task = Task.get_for_writing(
task = Task.get(
company=company_id,
id=item.task,
_only=["id", "status", "enqueue_status", "project"],
_only=[
"id",
"company",
"status",
"enqueue_status",
"project",
],
)
if not task:
continue
@@ -164,6 +172,7 @@ class QueueBLL(object):
status_reason="Queue was deleted",
status_message="",
user_id=user_id,
force=True,
).execute(enqueue_status=None)
except Exception as ex:
log.exception(
@@ -234,6 +243,7 @@ class QueueBLL(object):
{
"name": w.id,
"ip": w.ip,
"key": w.key,
"task": w.task.to_struct() if w.task else None,
}
for w in queue_workers.get(item["id"], [])
@@ -319,46 +329,131 @@ class QueueBLL(object):
return len(entries_to_remove) if res else 0
def reposition_task(
self,
company_id: str,
queue_id: str,
task_id: str,
pos_func: Callable[[int], int],
self, company_id: str, queue_id: str, task_id: str, move_count: Union[int, str],
) -> int:
"""
Moves the task in the queue to the position calculated by pos_func
Returns the updated task position in the queue
"""
with translate_errors_context():
queue = self.get_queue_with_task(
def get_queue_and_task_position():
q = self.get_queue_with_task(
company_id=company_id, queue_id=queue_id, task_id=task_id
)
return q, next(i for i, e in enumerate(q.entries) if e.task == task_id)
position = next(i for i, e in enumerate(queue.entries) if e.task == task_id)
new_position = pos_func(position)
with translate_errors_context():
queue, position = get_queue_and_task_position()
if move_count == MOVE_FIRST:
new_position = 0
elif move_count == MOVE_LAST:
new_position = len(queue.entries) - 1
else:
new_position = position + move_count
if new_position == position:
return new_position
if new_position != position:
entry = queue.entries[position]
query = dict(id=queue_id, company=company_id)
updated = Queue.objects(entries__task=task_id, **query).update_one(
pull__entries=entry, last_update=datetime.utcnow()
)
if not updated:
raise errors.bad_request.RemovedDuringReposition(
task=task_id, **query
)
inst = {"$push": {"entries": {"$each": [entry.to_proper_dict()]}}}
if new_position >= 0:
inst["$push"]["entries"]["$position"] = new_position
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
__raw__=inst
)
if not res:
raise errors.bad_request.FailedAddingDuringReposition(
task=task_id, **query
)
without_entry = {
"$filter": {
"input": "$entries",
"as": "entry",
"cond": {"$ne": ["$$entry.task", task_id]},
}
}
task_entry = {
"$filter": {
"input": "$entries",
"as": "entry",
"cond": {"$eq": ["$$entry.task", task_id]},
}
}
if move_count == MOVE_FIRST:
operations = [
{
"$set": {
"entries": {"$concatArrays": [task_entry, without_entry]}
}
}
]
elif move_count == MOVE_LAST:
operations = [
{
"$set": {
"entries": {"$concatArrays": [without_entry, task_entry]}
}
}
]
else:
operations = [
{
"$set": {
"new_pos": {
"$add": [
{"$indexOfArray": ["$entries.task", task_id]},
move_count,
]
},
"without_entry": without_entry,
"task_entry": task_entry,
}
},
{
"$set": {
"entries": {
"$switch": {
"branches": [
{
"case": {"$lte": ["$new_pos", 0]},
"then": {
"$concatArrays": [
"$task_entry",
"$without_entry",
]
},
},
{
"case": {
"$gte": [
"$new_pos",
{"$size": "$without_entry"},
]
},
"then": {
"$concatArrays": [
"$without_entry",
"$task_entry",
]
},
},
],
"default": {
"$concatArrays": [
{"$slice": ["$without_entry", "$new_pos"]},
"$task_entry",
{
"$slice": [
"$without_entry",
"$new_pos",
{"$size": "$without_entry"},
]
},
]
},
}
}
}
},
{"$unset": ["new_pos", "without_entry", "task_entry"]},
]
return new_position
updated = Queue.objects(
id=queue_id, company=company_id, entries__task=task_id
).update_one(__raw__=operations)
if not updated:
raise errors.bad_request.FailedAddingDuringReposition(task=task_id)
return get_queue_and_task_position()[1]
def count_entries(self, company: str, queue_id: str) -> Optional[int]:
res = next(

View File

@@ -80,7 +80,7 @@ class QueueMetrics:
logged = 0
for q in queues:
queue_doc = make_doc(q)
self.es.index(index=es_index, body=queue_doc)
self.es.index(index=es_index, document=queue_doc)
redis_key = _queue_metrics_key_pattern.format(queue=q.id)
redis.set(redis_key, json.dumps(queue_doc))
logged += 1

View File

@@ -8,8 +8,7 @@ from typing import Sequence, Optional
import dpath
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter, Retry
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.bll.util import get_server_uuid
@@ -255,6 +254,14 @@ class StatisticsReporter:
**({"last_worker": {"$in": workers}} if workers else {}),
}
},
{
"$project": {
"last_worker": 1,
"last_update": 1,
"started": 1,
"last_iteration": 1,
}
},
{
"$group": {
"_id": "$last_worker" if workers else None,

View File

@@ -1,6 +1,5 @@
from .task_bll import TaskBLL
from .utils import (
ChangeStatusRequest,
update_project_time,
validate_status_change,
)

View File

@@ -5,6 +5,7 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.database.utils import hash_field_name
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
@@ -48,12 +49,14 @@ class Artifacts:
def add_or_update_artifacts(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
artifacts: Sequence[ApiArtifact],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
artifacts = {
get_artifact_id(a): Artifact(**a)
@@ -64,18 +67,20 @@ class Artifacts:
f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items()
}
return update_task(task, user_id=user_id, update_cmds=update_cmds)
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
@classmethod
def delete_artifacts(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
artifact_ids: Sequence[ArtifactId],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
artifact_ids = [
get_artifact_id(a)
@@ -85,4 +90,4 @@ class Artifacts:
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
}
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)

View File

@@ -15,6 +15,7 @@ from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.config_repo import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.service_repo.auth import Identity
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
@@ -31,7 +32,10 @@ class HyperParams:
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
only = ("id", "hyperparams")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
company_id=company_id,
task_ids=task_ids,
only=only,
allow_public=True,
)
return {
@@ -63,7 +67,7 @@ class HyperParams:
def delete_params(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
hyperparams: Sequence[HyperParamKey],
force: bool,
@@ -74,6 +78,7 @@ class HyperParams:
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
identity=identity,
)
with_param, without_param = iterutils.partition(
@@ -96,7 +101,7 @@ class HyperParams:
return update_task(
task,
user_id=user_id,
user_id=identity.user,
update_cmds=delete_cmds,
set_last_update=not properties_only,
)
@@ -105,7 +110,7 @@ class HyperParams:
def edit_params(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str,
@@ -117,6 +122,7 @@ class HyperParams:
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
identity=identity,
)
update_cmds = dict()
@@ -135,7 +141,7 @@ class HyperParams:
return update_task(
task,
user_id=user_id,
user_id=identity.user,
update_cmds=update_cmds,
set_last_update=not properties_only,
)
@@ -163,7 +169,10 @@ class HyperParams:
else:
only.append("configuration")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
company_id=company_id,
task_ids=task_ids,
only=only,
allow_public=True,
)
return {
@@ -209,13 +218,15 @@ class HyperParams:
def edit_configuration(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
configuration: Sequence[Configuration],
replace_configuration: bool,
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
update_cmds = dict()
configuration = {
@@ -228,22 +239,24 @@ class HyperParams:
for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
return update_task(task, user_id=user_id, update_cmds=update_cmds)
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
@classmethod
def delete_configuration(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
configuration: Sequence[str],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)

View File

@@ -1,7 +1,7 @@
from datetime import timedelta, datetime
from time import sleep
from apiserver.bll.task import update_project_time
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model.task.task import TaskStatus, Task
from apiserver.utilities.threads_manager import ThreadsManager
@@ -85,6 +85,7 @@ class NonResponsiveTasksWatchdog:
status_changed=now,
last_update=now,
last_change=now,
last_changed_by="__apiserver__",
)
if updated:
project_ids.add(task.project)

View File

@@ -7,11 +7,12 @@ from redis import StrictRedis
from six import string_types
import apiserver.database.utils as dbutils
from apiserver.apierrors import errors
from apiserver.apierrors import errors, APIError
from apiserver.apimodels.tasks import TaskInputModel
from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
@@ -30,7 +31,11 @@ from apiserver.database.model.task.task import (
TaskModelTypes,
)
from apiserver.database.model import EntityVisibility
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.database.model.queue import Queue
from apiserver.database.utils import (
get_company_or_none_constraint,
id as create_id,
)
from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
@@ -38,8 +43,8 @@ from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
ChangeStatusRequest,
update_project_time,
deleted_prefix,
get_last_metric_updates,
)
log = config.logger(__file__)
@@ -53,30 +58,13 @@ class TaskBLL:
self.events_es = events_es or es_factory.connect("events")
self.redis: StrictRedis = redis or redman.connection("apiserver")
@staticmethod
def get_task_with_access(
task_id, company_id, only=None, allow_public=False, requires_write_access=False
) -> Task:
"""
Gets a task that has a required write access
:except errors.bad_request.InvalidTaskId: if the task is not found
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
"""
with translate_errors_context():
query = dict(id=task_id, company=company_id)
if requires_write_access:
task = Task.get_for_writing(_only=only, **query)
else:
task = Task.get(_only=only, **query, include_public=allow_public)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
return task
@staticmethod
def get_by_id(
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
company_id,
task_id,
required_status=None,
only_fields=None,
allow_public=False,
):
if only_fields:
if isinstance(only_fields, string_types):
@@ -311,7 +299,7 @@ class TaskBLL:
org_bll.update_tags(
company_id,
Tags.Task,
project=new_task.project,
projects=[new_task.project],
tags=updated_tags,
system_tags=updated_system_tags,
)
@@ -354,6 +342,7 @@ class TaskBLL:
def set_last_update(
task_ids: Collection[str],
company_id: str,
user_id: str,
last_update: datetime,
**extra_updates,
):
@@ -374,6 +363,7 @@ class TaskBLL:
upsert=False,
last_update=last_update,
last_change=last_update,
last_changed_by=user_id,
**updates,
)
return count
@@ -382,6 +372,7 @@ class TaskBLL:
def update_statistics(
task_id: str,
company_id: str,
user_id: str,
last_update: datetime = None,
last_iteration: int = None,
last_iteration_max: int = None,
@@ -412,81 +403,12 @@ class TaskBLL:
raw_updates = {}
if last_scalar_events is not None:
max_values = config.get("services.tasks.max_last_metrics", 2000)
total_metrics = set()
if max_values:
query = dict(id=task_id)
to_add = sum(len(v) for m, v in last_scalar_events.items())
if to_add <= max_values:
query[f"unique_metrics__{max_values-to_add}__exists"] = True
task = Task.objects(**query).only("unique_metrics").first()
if task and task.unique_metrics:
total_metrics = set(task.unique_metrics)
new_metrics = []
def add_last_metric_conditional_update(
metric_path: str, metric_value, iter_value: int, is_min: bool
):
"""
Build an aggregation for an atomic update of the min or max value and the corresponding iteration
"""
if is_min:
field_prefix = "min"
op = "$gt"
else:
field_prefix = "max"
op = "$lt"
value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
condition = {
"$or": [
{"$lte": [f"${value_field}", None]},
{op: [f"${value_field}", metric_value]},
]
}
raw_updates[value_field] = {
"$cond": [condition, metric_value, f"${value_field}"]
}
value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
"__", "."
)
raw_updates[value_iteration_field] = {
"$cond": [
condition,
iter_value,
f"${value_iteration_field}",
]
}
for metric_key, metric_data in last_scalar_events.items():
for variant_key, variant_data in metric_data.items():
metric = (
f"{variant_data.get('metric')}/{variant_data.get('variant')}"
)
if max_values:
if (
len(total_metrics) >= max_values
and metric not in total_metrics
):
continue
total_metrics.add(metric)
new_metrics.append(metric)
path = f"last_metrics__{metric_key}__{variant_key}"
for key, value in variant_data.items():
if key in ("min_value", "max_value"):
add_last_metric_conditional_update(
metric_path=path,
metric_value=value,
iter_value=variant_data.get(f"{key}_iter", 0),
is_min=(key == "min_value"),
)
elif key in ("metric", "variant", "value"):
extra_updates[f"set__{path}__{key}"] = value
if new_metrics:
extra_updates["add_to_set__unique_metrics"] = new_metrics
get_last_metric_updates(
task_id=task_id,
last_scalar_events=last_scalar_events,
raw_updates=raw_updates,
extra_updates=extra_updates,
)
if last_events is not None:
@@ -507,6 +429,7 @@ class TaskBLL:
ret = TaskBLL.set_last_update(
task_ids=[task_id],
company_id=company_id,
user_id=user_id,
last_update=last_update,
**extra_updates,
)
@@ -515,6 +438,12 @@ class TaskBLL:
return ret
@staticmethod
def remove_task_from_all_queues(company_id: str, task_id: str) -> int:
return Queue.objects(company=company_id, entries__task=task_id).update(
pull__entries__task=task_id, last_update=datetime.utcnow()
)
@classmethod
def dequeue_and_change_status(
cls,
@@ -523,19 +452,28 @@ class TaskBLL:
user_id: str,
status_message: str,
status_reason: str,
remove_from_all_queues=False,
new_status=None,
):
try:
cls.dequeue(task, company_id)
except errors.bad_request.InvalidQueueOrTaskNotQueued:
cls.dequeue(task, company_id, silent_fail=True)
except APIError:
# dequeue may fail if the queue was deleted
pass
if remove_from_all_queues:
cls.remove_task_from_all_queues(company_id=company_id, task_id=task.id)
if task.status not in [TaskStatus.queued, TaskStatus.in_progress]:
return {"updated": 0}
return ChangeStatusRequest(
task=task,
new_status=task.enqueue_status or TaskStatus.created,
new_status=new_status or task.enqueue_status or TaskStatus.created,
status_reason=status_reason,
status_message=status_message,
user_id=user_id,
force=True,
).execute(enqueue_status=None)
@classmethod

View File

@@ -1,10 +1,10 @@
from datetime import datetime
from itertools import chain
from operator import attrgetter
from typing import Sequence, Set, Tuple
from typing import Sequence, Set, Tuple, Union
import attr
from boltons.iterutils import partition, bucketize, first
from boltons.iterutils import partition, bucketize, first, chunked_iter
from furl import furl
from mongoengine import NotUniqueError
from pymongo.errors import DuplicateKeyError
@@ -26,7 +26,6 @@ from apiserver.database.utils import id as db_id
log = config.logger(__file__)
event_bll = EventBLL()
async_events_delete = config.get("services.tasks.async_events_delete", False)
@attr.s(auto_attribs=True)
@@ -69,52 +68,74 @@ class CleanupResult:
)
def collect_plot_image_urls(company: str, task_or_model: str) -> Set[str]:
def collect_plot_image_urls(
company: str, task_or_model: Union[str, Sequence[str]]
) -> Set[str]:
urls = set()
next_scroll_id = None
while True:
events, next_scroll_id = event_bll.get_plot_image_urls(
company_id=company, task_id=task_or_model, scroll_id=next_scroll_id
)
if not events:
break
for event in events:
event_urls = event.get(PlotFields.source_urls)
if event_urls:
urls.update(set(event_urls))
task_ids = task_or_model if isinstance(task_or_model, list) else [task_or_model]
for tasks in chunked_iter(task_ids, 100):
next_scroll_id = None
while True:
events, next_scroll_id = event_bll.get_plot_image_urls(
company_id=company, task_ids=tasks, scroll_id=next_scroll_id
)
if not events:
break
for event in events:
event_urls = event.get(PlotFields.source_urls)
if event_urls:
urls.update(set(event_urls))
return urls
def collect_debug_image_urls(company: str, task_or_model: str) -> Set[str]:
def collect_debug_image_urls(
company: str, task_or_model: Union[str, Sequence[str]]
) -> Set[str]:
"""
Return the set of unique image urls
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
"""
after_key = None
urls = set()
while True:
res, after_key = event_bll.get_debug_image_urls(
company_id=company, task_id=task_or_model, after_key=after_key,
)
urls.update(res)
if not after_key:
break
task_ids = task_or_model if isinstance(task_or_model, list) else [task_or_model]
for tasks in chunked_iter(task_ids, 100):
after_key = None
while True:
res, after_key = event_bll.get_debug_image_urls(
company_id=company,
task_ids=tasks,
after_key=after_key,
)
urls.update(res)
if not after_key:
break
return urls
supported_storage_types = {
"https://": StorageType.fileserver,
"http://": StorageType.fileserver,
"s3://": StorageType.s3,
"azure://": StorageType.azure,
"gs://": StorageType.gs,
}
supported_storage_types.update(
{
p: StorageType.fileserver
for p in config.get(
"services.async_urls_delete.fileserver.url_prefixes",
["https://", "http://"],
)
}
)
def _schedule_for_delete(
company: str, user: str, task_id: str, urls: Set[str], can_delete_folders: bool,
company: str,
user: str,
task_id: str,
urls: Set[str],
can_delete_folders: bool,
) -> Set[str]:
urls_per_storage = bucketize(
urls,
@@ -196,7 +217,7 @@ def cleanup_task(
task, force
)
delete_external_artifacts = delete_external_artifacts and config.get(
"services.async_urls_delete.enabled", False
"services.async_urls_delete.enabled", True
)
event_urls, artifact_urls, model_urls = set(), set(), set()
if return_file_urls or delete_external_artifacts:
@@ -214,8 +235,13 @@ def cleanup_task(
deleted_task_id = f"{deleted_prefix}{task.id}"
updated_children = 0
now = datetime.utcnow()
if update_children:
updated_children = Task.objects(parent=task.id).update(parent=deleted_task_id)
updated_children = Task.objects(parent=task.id).update(
parent=deleted_task_id,
last_change=now,
last_changed_by=user,
)
deleted_models = 0
updated_models = 0
@@ -223,37 +249,41 @@ def cleanup_task(
if not models:
continue
if delete_output_models and allow_delete:
model_ids = set(m.id for m in models if m.id not in in_use_model_ids)
for m_id in model_ids:
model_ids = list({m.id for m in models if m.id not in in_use_model_ids})
if model_ids:
if return_file_urls or delete_external_artifacts:
event_urls.update(collect_debug_image_urls(task.company, m_id))
event_urls.update(collect_plot_image_urls(task.company, m_id))
try:
event_bll.delete_task_events(
task.company,
m_id,
allow_locked=True,
model=True,
async_delete=async_events_delete,
)
except errors.bad_request.InvalidModelId as ex:
log.info(f"Error deleting events for the model {m_id}: {str(ex)}")
event_urls.update(collect_debug_image_urls(task.company, model_ids))
event_urls.update(collect_plot_image_urls(task.company, model_ids))
event_bll.delete_multi_task_events(
task.company,
model_ids,
model=True,
)
deleted_models += Model.objects(id__in=model_ids).delete()
deleted_models += Model.objects(id__in=list(model_ids)).delete()
if in_use_model_ids:
Model.objects(id__in=list(in_use_model_ids)).update(unset__task=1)
Model.objects(id__in=list(in_use_model_ids)).update(
unset__task=1,
set__last_change=now,
set__last_changed_by=user,
)
continue
if update_children:
updated_models += Model.objects(id__in=[m.id for m in models]).update(
task=deleted_task_id
task=deleted_task_id,
last_change=now,
last_changed_by=user,
)
else:
Model.objects(id__in=[m.id for m in models]).update(unset__task=1)
Model.objects(id__in=[m.id for m in models]).update(
unset__task=1,
set__last_change=now,
set__last_changed_by=user,
)
event_bll.delete_task_events(
task.company, task.id, allow_locked=force, async_delete=async_events_delete
)
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
if delete_external_artifacts:
scheduled = _schedule_for_delete(
@@ -296,7 +326,8 @@ def verify_task_children_and_ouptuts(
model_fields = ["id", "ready", "uri"]
published_models, draft_models = partition(
Model.objects(task=task.id).only(*model_fields), key=attrgetter("ready"),
Model.objects(task=task.id).only(*model_fields),
key=attrgetter("ready"),
)
if not force and published_models:
raise errors.bad_request.TaskCannotBeDeleted(

View File

@@ -7,9 +7,10 @@ from apiserver.bll.task import (
TaskBLL,
validate_status_change,
ChangeStatusRequest,
update_project_time,
)
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
@@ -23,6 +24,8 @@ from apiserver.database.model.task.task import (
Execution,
DEFAULT_LAST_ITERATION,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_set
log = config.logger(__file__)
@@ -32,7 +35,7 @@ queue_bll = QueueBLL()
def archive_task(
task: Union[str, Task],
company_id: str,
user_id: str,
identity: Identity,
status_message: str,
status_reason: str,
) -> int:
@@ -41,19 +44,22 @@ def archive_task(
Return 1 if successful
"""
if isinstance(task, str):
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
task,
company_id=company_id,
identity=identity,
only=(
"id",
"company",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
),
requires_write_access=True,
)
user_id = identity.user
try:
TaskBLL.dequeue_and_change_status(
task,
@@ -61,6 +67,7 @@ def archive_task(
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
)
except APIError:
# dequeue may fail if the task was not enqueued
@@ -76,34 +83,67 @@ def archive_task(
def unarchive_task(
task: str, company_id: str, user_id: str, status_message: str, status_reason: str,
task_id: str,
company_id: str,
identity: Identity,
status_message: str,
status_reason: str,
) -> int:
"""
Unarchive task. Return 1 if successful
"""
task = TaskBLL.get_task_with_access(
task, company_id=company_id, only=("id",), requires_write_access=True,
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=("id",),
)
return task.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
last_changed_by=identity.user,
)
def dequeue_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
status_message: str,
status_reason: str,
remove_from_all_queues: bool = False,
new_status=None,
) -> Tuple[int, dict]:
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
if new_status and new_status not in get_options(TaskStatus):
raise errors.bad_request.ValidationError(f"Invalid task status: {new_status}")
# get the task without write access to make sure that it actually exists
task = Task.get(
id=task_id,
company=company_id,
_only=("id",),
include_public=True,
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
return 1, {"updated": 0}
user_id = identity.user
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=(
"id",
"company",
"execution",
"status",
"project",
"enqueue_status",
),
)
res = TaskBLL.dequeue_and_change_status(
task,
@@ -111,6 +151,8 @@ def dequeue_task(
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=remove_from_all_queues,
new_status=new_status,
)
return 1, res
@@ -118,7 +160,7 @@ def dequeue_task(
def enqueue_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
queue_id: str,
status_message: str,
status_reason: str,
@@ -143,11 +185,11 @@ def enqueue_task(
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
task = get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=identity
)
user_id = identity.user
if validate:
TaskBLL.validate(task)
@@ -177,9 +219,9 @@ def enqueue_task(
# set the current queue ID in the task
if task.execution:
Task.objects(**query).update(execution__queue=queue_id, multi=False)
Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
else:
Task.objects(**query).update(execution=Execution(queue=queue_id), multi=False)
Task.objects(id=task_id).update(execution=Execution(queue=queue_id), multi=False)
nested_set(res, ("fields", "execution.queue"), queue_id)
return 1, res
@@ -212,7 +254,7 @@ def move_tasks_to_trash(tasks: Sequence[str]) -> int:
def delete_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
move_to_trash: bool,
force: bool,
return_file_urls: bool,
@@ -221,8 +263,9 @@ def delete_task(
status_reason: str,
delete_external_artifacts: bool,
) -> Tuple[int, Task, CleanupResult]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
if (
@@ -244,6 +287,7 @@ def delete_task(
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
)
except APIError:
# dequeue may fail if the task was not enqueued
@@ -262,6 +306,7 @@ def delete_task(
if move_to_trash:
# make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation
task.last_update = datetime.utcnow()
task.save()
else:
task.delete()
@@ -273,15 +318,16 @@ def delete_task(
def reset_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
clear_all: bool,
delete_external_artifacts: bool,
) -> Tuple[dict, CleanupResult, dict]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
if not force and task.status == TaskStatus.published:
@@ -296,6 +342,8 @@ def reset_task(
# dequeue may fail if the task was not enqueued
pass
TaskBLL.remove_task_from_all_queues(company_id=company_id, task_id=task.id)
cleaned_up = cleanup_task(
company=company_id,
user=user_id,
@@ -318,11 +366,17 @@ def reset_task(
unset__output__error=1,
unset__last_worker=1,
unset__last_worker_report=1,
unset__started=1,
unset__completed=1,
unset__published=1,
unset__active_duration=1,
unset__enqueue_status=1,
)
if clear_all:
updates.update(
set__execution=Execution(), unset__script=1,
set__execution=Execution(),
unset__script=1,
)
else:
updates.update(unset__execution__queue=1)
@@ -343,11 +397,6 @@ def reset_task(
status_message="reset",
user_id=user_id,
).execute(
started=None,
completed=None,
published=None,
active_duration=None,
enqueue_status=None,
**updates,
)
@@ -357,14 +406,15 @@ def reset_task(
def publish_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
force: bool,
publish_model_func: Callable[[str, str, str], Any] = None,
publish_model_func: Callable[[str, str, Identity], Any] = None,
status_message: str = "",
status_reason: str = "",
) -> dict:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
if not force:
validate_status_change(task.status, TaskStatus.published)
@@ -387,7 +437,7 @@ def publish_task(
.first()
)
if model and not model.ready:
publish_model_func(model.id, company_id, user_id)
publish_model_func(model.id, company_id, identity)
# set task status to published, and update (or set) it's new output (view and models)
return ChangeStatusRequest(
@@ -411,7 +461,7 @@ def publish_task(
def stop_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
user_name: str,
status_reason: str,
force: bool,
@@ -424,10 +474,11 @@ def stop_task(
is set to 'stopping' to allow the worker to stop the task and report by itself
:return: updated task fields
"""
task = TaskBLL.get_task_with_access(
user_id = identity.user
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=(
"status",
"project",
@@ -437,7 +488,6 @@ def stop_task(
"last_update",
"execution.queue",
),
requires_write_access=True,
)
def is_run_by_worker(t: Task) -> bool:

View File

@@ -1,14 +1,18 @@
from datetime import datetime
from typing import Sequence, Union
from typing import Sequence
import attr
import six
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.project import Project
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities.attrs import typed_attrs
valid_statuses = get_options(TaskStatus)
@@ -156,25 +160,78 @@ def get_possible_status_changes(current_status):
return possible
def update_project_time(project_ids: Union[str, Sequence[str]]):
if not project_ids:
return
def get_many_tasks_for_writing(
company_id: str,
identity: Identity,
query: Q = None,
only: Sequence = None,
throw_on_forbidden: bool = True,
) -> Sequence[Task]:
if only:
missing = [f for f in ("company", ) if f not in only]
if missing:
only = [*only, *missing]
if isinstance(project_ids, str):
project_ids = [project_ids]
result = list(
Task.get_many(
company=company_id,
query=query,
override_projection=only,
allow_public=True,
return_dicts=False,
)
)
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
if not company_id:
return result
forbidden_tasks = {task.id for task in result if not task.company}
if forbidden_tasks:
if throw_on_forbidden:
raise errors.forbidden.NoWritePermission(
f"cannot modify public task(s), ids={tuple(forbidden_tasks)}"
)
result = [task for task in result if task.id not in forbidden_tasks]
return result
def get_task_with_write_access(
task_id: str,
company_id: str,
identity: Identity,
only=None,
) -> Task:
"""
Gets a task that has a required write access
:except errors.bad_request.InvalidTaskId: if the task is not found
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
"""
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(_only=only, **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
return task
def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
company_id: str,
task_id: str,
identity: Identity,
allow_all_statuses: bool = False,
force: bool = False
) -> Task:
"""
Loads only task id and return the task only if it is updatable (status == 'created')
"""
task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
only=("id", "status"),
identity=identity,
)
if allow_all_statuses:
return task
@@ -189,9 +246,88 @@ def get_task_for_update(
return task
def update_task(task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True):
def update_task(
task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True
):
now = datetime.utcnow()
last_updates = dict(last_change=now, last_changed_by=user_id)
if set_last_update:
last_updates.update(last_update=now)
return task.update(**update_cmds, **last_updates)
def get_last_metric_updates(
task_id: str,
last_scalar_events: dict,
raw_updates: dict,
extra_updates: dict,
model_events: bool = False,
):
max_values = config.get("services.tasks.max_last_metrics", 2000)
total_metrics = set()
if max_values:
query = dict(id=task_id)
to_add = sum(len(v) for m, v in last_scalar_events.items())
if to_add <= max_values:
query[f"unique_metrics__{max_values - to_add}__exists"] = True
db_cls = Model if model_events else Task
task = db_cls.objects(**query).only("unique_metrics").first()
if task and task.unique_metrics:
total_metrics = set(task.unique_metrics)
new_metrics = []
def add_last_metric_conditional_update(
metric_path: str, metric_value, iter_value: int, is_min: bool
):
"""
Build an aggregation for an atomic update of the min or max value and the corresponding iteration
"""
if is_min:
field_prefix = "min"
op = "$gt"
else:
field_prefix = "max"
op = "$lt"
value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
condition = {
"$or": [
{"$lte": [f"${value_field}", None]},
{op: [f"${value_field}", metric_value]},
]
}
raw_updates[value_field] = {
"$cond": [condition, metric_value, f"${value_field}"]
}
value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
"__", "."
)
raw_updates[value_iteration_field] = {
"$cond": [condition, iter_value, f"${value_iteration_field}"]
}
for metric_key, metric_data in last_scalar_events.items():
for variant_key, variant_data in metric_data.items():
metric = f"{variant_data.get('metric')}/{variant_data.get('variant')}"
if max_values:
if len(total_metrics) >= max_values and metric not in total_metrics:
continue
total_metrics.add(metric)
new_metrics.append(metric)
path = f"last_metrics__{metric_key}__{variant_key}"
for key, value in variant_data.items():
if key in ("min_value", "max_value"):
add_last_metric_conditional_update(
metric_path=path,
metric_value=value,
iter_value=variant_data.get(f"{key}_iter", 0),
is_min=(key == "min_value"),
)
elif key in ("metric", "variant", "value"):
extra_updates[f"set__{path}__{key}"] = value
if new_metrics:
extra_updates["add_to_set__unique_metrics"] = new_metrics

View File

@@ -1,76 +1,24 @@
import functools
import itertools
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from typing import (
Optional,
Callable,
Dict,
Any,
Set,
Iterable,
Tuple,
Sequence,
TypeVar,
Union,
)
from boltons import iterutils
from apiserver.apierrors import APIError
from apiserver.database.model import AttributedDocument
from apiserver.database.model.project import Project
from apiserver.database.model.settings import Settings
class SetFieldsResolver:
"""
The class receives set fields dictionary
and for the set fields that require 'min' or 'max'
operation replace them with a simple set in case the
DB document does not have these fields set
"""
SET_MODIFIERS = ("min", "max")
def __init__(self, set_fields: Dict[str, Any]):
self.orig_fields = {}
self.fields = {}
self.add_fields(**set_fields)
def add_fields(self, **set_fields: Any):
self.orig_fields.update(set_fields)
self.fields.update(
{
f: fname
for f, modifier, dunder, fname in (
(f,) + f.partition("__") for f in set_fields.keys()
)
if dunder and modifier in self.SET_MODIFIERS
}
)
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
if name in self.fields and doc.get_field_value(self.fields[name]) is None:
return self.fields[name]
return name
def get_fields(self, doc: AttributedDocument):
"""
For the given document return the set fields instructions
with min/max operations replaced with a single set in case
the document does not have the field set
"""
return {
self._get_updated_name(doc, name): value
for name, value in self.orig_fields.items()
}
def get_names(self) -> Set[str]:
"""
Returns the names of the fields that had min/max modifiers
in the format suitable for projection (dot separated)
"""
return set(name.replace("__", ".") for name in self.fields.values())
@functools.lru_cache()
def get_server_uuid() -> Optional[str]:
return Settings.get_by_key("server.uuid")
@@ -132,3 +80,13 @@ def run_batch_operation(
}
)
return results, failures
def update_project_time(project_ids: Union[str, Sequence[str]]):
if not project_ids:
return
if isinstance(project_ids, str):
project_ids = [project_ids]
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())

View File

@@ -5,13 +5,13 @@ from typing import Sequence, Set, Optional
import attr
import elasticsearch.helpers
from boltons.iterutils import partition
from boltons.iterutils import partition, chunked_iter
from pyhocon import ConfigTree
from apiserver.es_factory import es_factory
from apiserver.apierrors import APIError
from apiserver.apierrors.errors import bad_request, server_error
from apiserver.apimodels.workers import (
DEFAULT_TIMEOUT,
IdNameEntry,
WorkerEntry,
StatusReportRequest,
@@ -30,12 +30,14 @@ from apiserver.redis_manager import redman
from apiserver.tools import safe_get
from .stats import WorkerStats
log = config.logger(__file__)
class WorkerBLL:
def __init__(self, es=None, redis=None):
self.es_client = es or es_factory.connect("workers")
self.config = config.get("services.workers", ConfigTree())
self.redis = redis or redman.connection("workers")
self._stats = WorkerStats(self.es_client)
@@ -68,7 +70,7 @@ class WorkerBLL:
"""
key = WorkerBLL._get_worker_key(company_id, user_id, worker)
timeout = timeout or DEFAULT_TIMEOUT
timeout = timeout or int(self.config.get("default_worker_timeout_sec", 10 * 60))
queues = queues or []
with translate_errors_context():
@@ -141,8 +143,6 @@ class WorkerBLL:
try:
entry.ip = ip
now = datetime.utcnow()
entry.last_activity_time = now
if tags is not None:
entry.tags = tags
@@ -150,15 +150,16 @@ class WorkerBLL:
entry.system_tags = system_tags
if report.machine_stats:
self._log_stats_to_es(
self.log_stats_to_es(
company_id=company_id,
company_name=entry.company.name,
worker=entry.key,
worker_id=report.worker,
timestamp=report.timestamp,
task=report.task,
machine_stats=report.machine_stats,
)
now = datetime.utcnow()
entry.last_activity_time = now
entry.queue = report.queue
if report.queues:
@@ -175,6 +176,7 @@ class WorkerBLL:
last_worker_report=now,
last_update=now,
last_change=now,
last_changed_by=user_id,
)
# modify(new=True, ...) returns the modified object
task = Task.objects(**query).modify(new=True, **update)
@@ -200,6 +202,24 @@ class WorkerBLL:
finally:
self._save_worker(entry)
def get_count(
self,
company_id: str,
last_seen: Optional[int] = None,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
):
if not last_seen:
return len(
self._get_keys(company_id, user_tags=tags, system_tags=system_tags)
)
return len(
self.get_all(
company_id, last_seen=last_seen, tags=tags, system_tags=system_tags
)
)
def get_all(
self,
company_id: str,
@@ -235,18 +255,15 @@ class WorkerBLL:
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerResponseEntry]:
helpers = list(
map(
WorkerConversionHelper.from_worker_entry,
self.get_all(
company_id=company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
),
helpers = [
WorkerConversionHelper.from_worker_entry(entry)
for entry in self.get_all(
company_id=company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
)
)
]
task_ids = set(filter(None, (helper.task_id for helper in helpers)))
all_queues = set(
@@ -265,9 +282,7 @@ class WorkerBLL:
}
},
]
queues_info = {
res["_id"]: res for res in Queue.objects.aggregate(projection)
}
queues_info = {res["_id"]: res for res in Queue.aggregate(projection)}
task_ids = task_ids.union(
filter(
None,
@@ -396,15 +411,16 @@ class WorkerBLL:
msg = "Failed saving worker entry"
log.exception(msg)
def _get(
def _get_keys(
self,
company: str,
user: str = "*",
worker_id: str = "*",
user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerEntry]:
"""Get worker entries matching the company and user, worker patterns"""
) -> Sequence[bytes]:
if not (user_tags or system_tags):
match = self._get_worker_key(company, user, "*")
return list(self.redis.scan_iter(match))
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
if user == "*":
@@ -412,67 +428,79 @@ class WorkerBLL:
user_bytes = user.encode()
return {k for k in in_keys if user_bytes in k}
if user_tags or system_tags:
worker_keys = set()
for tags, tags_field in (
(user_tags, "tags"),
(system_tags, "systemtags"),
):
if not tags:
continue
timestamp = int(time())
include, exclude = partition(tags, key=lambda x: x[0] != "-")
if include:
tagged_workers = set()
for tag in include:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
tagged_workers.update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
tagged_workers = filter_by_user(tagged_workers)
worker_keys = (
worker_keys.intersection(tagged_workers)
if worker_keys
else tagged_workers
worker_keys = set()
for tags, tags_field in (
(user_tags, "tags"),
(system_tags, "systemtags"),
):
if not tags:
continue
timestamp = int(time())
include, exclude = partition(tags, key=lambda x: x[0] != "-")
if include:
tagged_workers = set()
for tag in include:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
tagged_workers.update(self.redis.zrange(tagged_workers_key, 0, -1))
tagged_workers = filter_by_user(tagged_workers)
worker_keys = (
worker_keys.intersection(tagged_workers)
if worker_keys
else tagged_workers
)
if not worker_keys:
return []
if exclude:
if not worker_keys:
all_workers_key = self._get_all_workers_key(company)
self.redis.zremrangebyscore(all_workers_key, min=0, max=timestamp)
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
worker_keys = filter_by_user(worker_keys)
if not worker_keys:
return []
for tag in exclude:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag[1:]
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
worker_keys.difference_update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
if not worker_keys:
return []
if exclude:
if not worker_keys:
all_workers_key = self._get_all_workers_key(company)
self.redis.zremrangebyscore(
all_workers_key, min=0, max=timestamp
)
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
worker_keys = filter_by_user(worker_keys)
if not worker_keys:
return []
for tag in exclude:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag[1:]
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
worker_keys.difference_update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
if not worker_keys:
return []
else:
match = self._get_worker_key(company, user, "*")
worker_keys = self.redis.scan_iter(match)
return list(worker_keys)
def _get(
self,
company: str,
user: str = "*",
user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerEntry]:
"""Get worker entries matching the company and user, worker patterns"""
entries = []
for key in worker_keys:
data = self.redis.get(key)
for keys in chunked_iter(
self._get_keys(
company, user=user, user_tags=user_tags, system_tags=system_tags
),
1000,
):
data = self.redis.mget(keys)
if data:
entries.append(WorkerEntry.from_json(data))
entries.extend(WorkerEntry.from_json(d) for d in data if d)
return entries
@@ -481,18 +509,17 @@ class WorkerBLL:
"""Get the index name suffix for storing current month data"""
return datetime.utcnow().strftime("%Y-%m")
def _log_stats_to_es(
def log_stats_to_es(
self,
company_id: str,
company_name: str,
worker: str,
worker_id: str,
timestamp: int,
task: str,
machine_stats: MachineStats,
) -> bool:
) -> int:
"""
Actually writing the worker statistics to Elastic
:return: True if successful, False otherwise
:return: The amount of logged documents
"""
es_index = (
f"{self._stats.worker_stats_prefix_for_company(company_id)}"
@@ -504,8 +531,7 @@ class WorkerBLL:
_index=es_index,
_source=dict(
timestamp=timestamp,
worker=worker,
company=company_name,
worker=worker_id,
task=task,
category=category,
metric=metric,
@@ -530,7 +556,7 @@ class WorkerBLL:
es_res = elasticsearch.helpers.bulk(self.es_client, actions)
added, errors = es_res[:2]
return (added == len(actions)) and not errors
return added
@attr.s(auto_attribs=True)

View File

@@ -215,6 +215,10 @@ class WorkerStats:
"date_histogram": {
"field": "timestamp",
"fixed_interval": f"{interval}s",
"extended_bounds": {
"min": int(from_date) * 1000,
"max": int(to_date) * 1000,
}
},
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
}

View File

@@ -72,6 +72,7 @@
httponly: true # allow only http to access the cookies (no JS etc)
secure: false # not using HTTPS
domain: null # Limit to localhost is not supported
samesite: Lax
max_age: 99999999999
}

View File

@@ -23,6 +23,11 @@
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
revoke_in_fixed_mode: true
}
services_agent {
role: "admin"
user_key: "P4BMJA7RK3TKBXGSY8OAA1FA8TOD11"
user_secret: "9LsgSfa0SYz0zli1_c500ZcLqanre2xkWOpepyt1w-BKK3_DKPHrtoj3JSHvyy8bIi0"
}
tests {
role: "user"
display_name: "Default User"

View File

@@ -1,4 +1,4 @@
# if set to True then on task delete/reset external file urls for know storage types are scheduled for async delete
# if set to true then on task delete/reset external file urls for known storage types are scheduled for async delete
# otherwise they are returned to a client for the client side delete
enabled: true
max_retries: 3

View File

@@ -32,6 +32,8 @@ events_retrieval {
max_raw_scalars_size: 200000
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
multi_plots_batch_size: 1000
}
# if set then plot str will be checked for the valid json on plot add

View File

@@ -1,7 +1,4 @@
metadata_values {
# maximal amount of distinct model values to retrieve
max_count: 100
# cache ttl sec
cache_ttl_sec: 86400
}

View File

@@ -1,3 +1,9 @@
tags_cache {
expiration_seconds: 3600
}
download {
redis_timeout_sec: 300
batch_size: 500
max_download_items: 50000
max_project_name_length: 60
}

View File

@@ -18,8 +18,9 @@ aws {
{
# This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
host: "localhost:9000"
key: "evg_user"
secret: "evg_pass"
key: "minioadmin"
secret: "minioadmin"
# region: my-server
multipart: false
secure: false
}

View File

@@ -11,9 +11,6 @@ non_responsive_tasks_watchdog {
multi_task_histogram_limit: 100
hyperparam_values {
# maximal amount of distinct hyperparam values to retrieve
max_count: 100
# max allowed outdate time for the cashed result
cache_allowed_outdate_sec: 60
@@ -26,4 +23,6 @@ hyperparam_values {
max_last_metrics: 2000
# if set then call to tasks.delete/cleanup does not wait for ES events deletion
async_events_delete: false
async_events_delete: true
# do not use async_delete if the deleted task has amount of events lower than this threshold
async_events_delete_threshold: 100000

View File

@@ -16,7 +16,7 @@ from mongoengine.errors import (
LookUpError,
InvalidQueryError,
)
from pymongo.errors import PyMongoError, NotMasterError
from pymongo.errors import PyMongoError, NotPrimaryError
from apiserver.apierrors import errors
@@ -198,7 +198,7 @@ def translate_errors_context(message=None, **kwargs):
MongoEngineErrorsHandler.invalid_query_error(e, message, **kwargs)
except PyMongoError as e:
raise errors.server_error.InternalError(message, err=str(e))
except NotMasterError as e:
except NotPrimaryError as e:
raise errors.server_error.InternalError(message, err=str(e))
except MakeGetAllQueryError as e:
raise errors.bad_request.ValidationError(e.error, field=e.field)

View File

@@ -1,5 +1,6 @@
import re
from collections import namedtuple
from collections import namedtuple, defaultdict
from datetime import datetime
from functools import reduce, partial
from typing import (
Collection,
@@ -11,10 +12,10 @@ from typing import (
Mapping,
Any,
Callable,
Dict,
List,
Generator,
)
import attr
from boltons.iterutils import first, partition
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField, IntField, QuerySet
@@ -22,6 +23,7 @@ from pymongo.command_cursor import CommandCursor
from apiserver.apierrors import errors, APIError
from apiserver.apierrors.base import BaseError
from apiserver.apierrors.errors.bad_request import FieldsValueError
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database import Database
@@ -132,94 +134,134 @@ class GetMixin(PropsMixin):
self.range_fields = range_fields
self.pattern_fields = pattern_fields
class ListFieldBucketHelper:
class NewListFieldBucketHelper:
op_prefix = "__$"
_legacy_exclude_prefix = "-"
_legacy_exclude_mongo_op = "nin"
default_mongo_op = "in"
_ops = {
# op -> (mongo_op, sticky)
"not": ("nin", False),
"nop": (default_mongo_op, False),
"all": ("all", True),
"and": ("all", True),
"any": (default_mongo_op, True),
"or": (default_mongo_op, True),
_unary_operators = {
"__$not": False,
}
_reset_operator = "__$nop"
_operators = {
"__$all": Q.AND,
"__$and": Q.AND,
"__$any": Q.OR,
"__$or": Q.OR,
}
default_global_operator = Q.AND
default_context = Q.OR
# not_all modifier currently not supported due to the backwards compatibility
mongo_modifiers = {
Q.AND: {True: "all", False: "nin"},
Q.OR: {True: "in", False: "nin"},
}
def __init__(self, field, legacy=False):
@attr.s(auto_attribs=True)
class Term:
operator: str = None
reset: bool = False
include: bool = True
value: str = None
def __init__(self, field: str, data: Sequence[str], legacy=False):
self._field = field
self._current_op = None
self._sticky = False
self._support_legacy = legacy
self.allow_empty = False
self.global_operator = None
self.actions = defaultdict(list)
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
try:
op = (
v[len(self.op_prefix) :]
if v and v.startswith(self.op_prefix)
else None
)
if translate:
tup = self._ops.get(op, None)
return tup[0] if tup else None
return op
except AttributeError:
raise errors.bad_request.FieldsValueError(
"invalid value type, string expected",
field=self._field,
value=str(v),
)
def _key(self, v) -> Optional[Union[str, bool]]:
if v is None:
self.allow_empty = True
return None
op = self._get_op(v)
if op is not None:
# operator - set state and return None
self._current_op, self._sticky = self._ops.get(
op, (self.default_mongo_op, self._sticky)
)
return None
elif self._current_op:
current_op = self._current_op
if not self._sticky:
self._current_op = None
return current_op
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
self._current_op = None
return False
return self.default_mongo_op
def get_global_op(self, data: Sequence[str]) -> int:
op_to_res = {
"in": Q.OR,
"all": Q.AND,
}
data = (x for x in data if x is not None)
first_op = (
self._get_op(next(data, ""), translate=True) or self.default_mongo_op
)
return op_to_res.get(first_op, self.default_mongo_op)
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
actions = {}
for val in data:
key = self._key(val)
if key is None:
self._support_legacy = legacy
current_context = self.default_context
for d in self._get_next_term(data):
if d.operator is not None:
current_context = d.operator
self._support_legacy = False
if self.global_operator is None:
self.global_operator = d.operator
continue
elif self._support_legacy and key is False:
key = self._legacy_exclude_mongo_op
val = val[len(self._legacy_exclude_prefix) :]
actions.setdefault(key, []).append(val)
return actions
if self.global_operator is None:
self.global_operator = self.default_global_operator
if d.reset:
current_context = self.default_context
self._support_legacy = legacy
continue
if d.value is None:
self.allow_empty = True
continue
self.actions[self.mongo_modifiers[current_context][d.include]].append(
d.value
)
if self.global_operator is None:
self.global_operator = self.default_global_operator
def _get_next_term(self, data: Sequence[str]) -> Generator[Term, None, None]:
unary_operator = None
for value in data:
if value is None:
unary_operator = None
yield self.Term()
continue
if not isinstance(value, str):
raise FieldsValueError(
"invalid value type, string expected",
field=self._field,
value=str(value),
)
if value == self._reset_operator:
unary_operator = None
yield self.Term(reset=True)
continue
if value.startswith(self.op_prefix):
if unary_operator:
raise FieldsValueError(
"Value is expected after",
field=self._field,
operator=unary_operator,
)
if value in self._unary_operators:
unary_operator = value
continue
operator = self._operators.get(value)
if operator is None:
raise FieldsValueError(
"Unsupported operator",
field=self._field,
operator=value,
)
yield self.Term(operator=operator)
continue
if (
not unary_operator
and self._support_legacy
and value.startswith("-")
):
value = value[1:]
if not value:
raise FieldsValueError(
"Missing value after the exclude prefix -",
field=self._field,
value=value,
)
yield self.Term(value=value, include=False)
continue
term = self.Term(value=value)
if unary_operator:
term.include = self._unary_operators[unary_operator]
unary_operator = None
yield term
if unary_operator:
raise FieldsValueError(
"Value is expected after", operator=unary_operator
)
get_all_query_options = QueryParameterOptions()
@@ -364,12 +406,25 @@ class GetMixin(PropsMixin):
parameters = {
k: cls._get_fixed_field_value(k, v) for k, v in parameters.items()
}
filters = parameters.pop("filters", {})
if not isinstance(filters, dict):
raise FieldsValueError(
"invalid value type, string expected",
field=filters,
value=str(filters),
)
opts = parameters_options
for field in opts.pattern_fields:
pattern = parameters.pop(field, None)
if pattern:
dict_query[field] = RegexWrapper(pattern)
for field, data in cls._pop_matching_params(
patterns=opts.list_fields, parameters=filters
).items():
query &= cls.get_list_filter_query(field, data)
parameters.pop(field, None)
for field, data in cls._pop_matching_params(
patterns=opts.list_fields, parameters=parameters
).items():
@@ -493,6 +548,149 @@ class GetMixin(PropsMixin):
return q
@attr.s(auto_attribs=True)
class ListQueryFilter:
"""
Deserialize filters data and build db_query object that represents it with the corresponding
mongo engine operations
Each part has include and exclude lists that map to mongoengine operations as following:
"any"
- include -> 'in'
- exclude -> 'not_all'
- combined by 'or' operation
"all"
- include -> 'all'
- exclude -> 'nin'
- combined by 'and' operation
"op" optional parameter for combining "and" and "all" parts. Can be "and" or "or". The default is "and"
"""
_and_op = "and"
_or_op = "or"
_allowed_op = [_and_op, _or_op]
_db_modifiers: Mapping = {
(Q.OR, True): "in",
(Q.OR, False): "not__all",
(Q.AND, True): "all",
(Q.AND, False): "nin",
}
@attr.s(auto_attribs=True)
class ListFilter:
include: Sequence[str] = []
exclude: Sequence[str] = []
@classmethod
def from_dict(cls, d: Mapping):
if d is None:
return None
return cls(**d)
any: ListFilter = attr.ib(converter=ListFilter.from_dict, default=None)
all: ListFilter = attr.ib(converter=ListFilter.from_dict, default=None)
op: str = attr.ib(default="and")
db_query: dict = attr.ib(init=False)
# noinspection PyUnresolvedReferences
@op.validator
def op_validator(self, _, value):
if value not in self._allowed_op:
raise ValueError(
f"Invalid list query filter operator: {value}. "
f"Should be one of {str(self._allowed_op)}"
)
@property
def and_op(self) -> bool:
return self.op == self._and_op
def __attrs_post_init__(self):
self.db_query = {}
for op, conditions in ((Q.OR, self.any), (Q.AND, self.all)):
if not conditions:
continue
operations = {}
for vals, include in (
(conditions.include, True),
(conditions.exclude, False),
):
if not vals:
continue
unique = set(vals)
if None in unique:
# noinspection PyTypeChecker
unique.remove(None)
if include:
operations["size"] = 0
else:
operations["not__size"] = 0
if not unique:
continue
operations[self._db_modifiers[(op, include)]] = list(unique)
self.db_query[op] = operations
@classmethod
def from_data(cls, field, data: Mapping):
if not isinstance(data, dict):
raise errors.bad_request.ValidationError(
"invalid filter for field, dictionary expected",
field=field,
value=str(data),
)
try:
return cls(**data)
except Exception as ex:
raise errors.bad_request.ValidationError(
field=field,
value=str(ex),
)
@classmethod
def get_list_filter_query(
cls, field: str, data: Mapping
) -> Union[RegexQ, RegexQCombination]:
if not data:
return RegexQ()
filter_ = cls.ListQueryFilter.from_data(field, data)
mongoengine_field = field.replace(".", "__")
queries = []
for op, actions in filter_.db_query.items():
if not actions:
continue
ops = []
for action, vals in actions.items():
# cannot just check vals here since 0 is acceptable value
if vals is None or vals == []:
continue
ops.append(RegexQ(**{f"{mongoengine_field}__{action}": vals}))
if not ops:
continue
if len(ops) == 1:
queries.extend(ops)
continue
queries.append(RegexQCombination(operation=op, children=ops))
if not queries:
return RegexQ()
if len(queries) == 1:
return queries[0]
operation = Q.AND if filter_.and_op else Q.OR
return RegexQCombination(operation=operation, children=queries)
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
"""
@@ -507,15 +705,15 @@ class GetMixin(PropsMixin):
if not isinstance(data, (list, tuple)):
data = [data]
helper = cls.ListFieldBucketHelper(field, legacy=True)
global_op = helper.get_global_op(data)
actions = helper.get_actions(data)
helper = cls.NewListFieldBucketHelper(field, data=data, legacy=True)
global_op = helper.global_operator
actions = helper.actions
mongoengine_field = field.replace(".", "__")
queries = [
RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))})
for action in filter(None, actions)
RegexQ(**{f"{mongoengine_field}__{action}": list(set(values))})
for action, values in actions.items()
]
if not queries:
@@ -601,7 +799,7 @@ class GetMixin(PropsMixin):
@classmethod
def get_projection(cls, parameters, override_projection=None, **__):
""" Extract a projection list from the provided dictionary. Supports an override projection. """
"""Extract a projection list from the provided dictionary. Supports an override projection."""
if override_projection is not None:
return override_projection
if not parameters:
@@ -615,7 +813,8 @@ class GetMixin(PropsMixin):
"""Return include and exclude lists based on passed projection and class definition"""
if projection:
include, exclude = partition(
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
projection,
key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
)
else:
include, exclude = [], []
@@ -754,7 +953,9 @@ class GetMixin(PropsMixin):
@classmethod
def _get_collation_override(cls, field: str) -> Optional[dict]:
return first(
v for k, v in cls._field_collation_overrides.items() if field.startswith(k)
v
for k, v in cls._field_collation_overrides.items()
if field.startswith(k) or field.startswith(f"-{k}")
)
@classmethod
@@ -860,7 +1061,9 @@ class GetMixin(PropsMixin):
projection_fields=projection_fields,
)
return cls.get_data_with_scroll_support(
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
query_dict=query_dict,
data_getter=data_getter,
ret_params=ret_params,
)
return cls._get_many_no_company(
@@ -873,7 +1076,9 @@ class GetMixin(PropsMixin):
@classmethod
def get_many_public(
cls, query: Q = None, projection: Collection[str] = None,
cls,
query: Q = None,
projection: Collection[str] = None,
):
"""
Fetch all public documents matching a provided query.
@@ -1091,21 +1296,6 @@ class GetMixin(PropsMixin):
)
return result
@classmethod
def get_many_for_writing(cls, company, *args, **kwargs):
result = cls.get_many(
company=company,
*args,
**dict(return_dicts=False, **kwargs),
allow_public=True,
)
forbidden_objects = {obj.id for obj in result if not obj.company}
if forbidden_objects:
object_name = cls.__name__.lower()
raise errors.forbidden.NoWritePermission(
f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
)
return result
class UpdateMixin(object):
@@ -1166,7 +1356,7 @@ class UpdateMixin(object):
class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
""" Provide convenience methods for a subclass of mongoengine.Document """
"""Provide convenience methods for a subclass of mongoengine.Document"""
@classmethod
def aggregate(
@@ -1194,25 +1384,31 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
def set_public(
cls: Type[Document],
company_id: str,
user_id: str,
ids: Sequence[str],
invalid_cls: Type[BaseError],
enabled: bool = True,
):
if enabled:
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
update = dict(set__company_origin=company_id, set__company="")
update: dict = dict(set__company_origin=company_id, set__company="")
else:
items = list(
cls.objects(
id__in=ids, company__in=(None, ""), company_origin=company_id
).only("id")
)
update = dict(set__company=company_id, unset__company_origin=1)
update: dict = dict(set__company=company_id, unset__company_origin=1)
if len(items) < len(ids):
missing = tuple(set(ids).difference(i.id for i in items))
raise invalid_cls(ids=missing)
if hasattr(cls, "last_change"):
update["set__last_change"] = datetime.utcnow()
if hasattr(cls, "last_changed_by"):
update["set__last_changed_by"] = user_id
return {"updated": cls.objects(id__in=ids).update(**update)}

View File

@@ -3,6 +3,8 @@ from mongoengine import (
DateTimeField,
BooleanField,
EmbeddedDocumentField,
IntField,
ListField,
)
from apiserver.database import Database, strict
@@ -17,12 +19,14 @@ from apiserver.database.model.base import GetMixin
from apiserver.database.model.metadata import MetadataItem
from apiserver.database.model.model_labels import ModelLabels
from apiserver.database.model.project import Project
from apiserver.database.model.task.metrics import MetricEvent
from apiserver.database.model.task.task import Task
class Model(AttributedDocument):
_field_collation_overrides = {
"metadata.": AttributedDocument._numeric_locale,
"last_metrics.": AttributedDocument._numeric_locale,
}
meta = {
@@ -67,6 +71,7 @@ class Model(AttributedDocument):
"parent",
"metadata.*",
),
range_fields=("last_metrics.*", "last_iteration"),
datetime_fields=("last_update",),
)
@@ -85,6 +90,8 @@ class Model(AttributedDocument):
labels = ModelLabels()
ready = BooleanField(required=True)
last_update = DateTimeField()
last_change = DateTimeField()
last_changed_by = StringField()
ui_cache = SafeDictField(
default=dict, user_set_allowed=True, exclude_by_default=True
)
@@ -92,6 +99,9 @@ class Model(AttributedDocument):
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)
last_iteration = IntField(default=0)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
def get_index_company(self) -> str:
return self.company or self.company_origin or ""

View File

@@ -230,11 +230,12 @@ class Task(AttributedDocument):
"project",
"parent",
"hyperparams.*",
"execution.queue",
),
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"),
pattern_fields=("name", "comment", "report"),
fields=("execution.queue", "runtime.*", "models.input.model"),
fields=("runtime.*", "models.input.model"),
)
id = StringField(primary_key=True)
@@ -271,7 +272,7 @@ class Task(AttributedDocument):
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
company_origin = StringField(exclude_by_default=True)
duration = IntField() # task duration in seconds
duration = IntField() # obsolete, do not use
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
runtime = SafeDictField(default=dict)

View File

@@ -3,10 +3,14 @@ from concurrent.futures import ThreadPoolExecutor
from itertools import groupby, chain
from typing import Sequence, Dict, Callable
from boltons import iterutils
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.props import PropsMixin
SEP = "."
max_items_per_fetch = config.get("services._mongo.max_page_size", 500)
class _ReferenceProxy(dict):
@@ -278,10 +282,11 @@ class ProjectionHelper(object):
doc_only = list(filter(None, data["only"]))
doc_only = list({"id"} | set(doc_only)) if doc_only else None
for res in projection_func(
doc_type=doc_type, projection=doc_only, ids=ids
):
self._proxy_manager.update(res)
for ids_chunk in iterutils.chunked_iter(ids, max_items_per_fetch):
for res in projection_func(
doc_type=doc_type, projection=doc_only, ids=ids_chunk
):
self._proxy_manager.update(res)
if len(ref_projection) == 1:
do_projection(items[0])

View File

@@ -0,0 +1,19 @@
### Supported api versions
| Release | ApiVersion |
|---------|------------|
| v1.13 | 2.27 |
| v1.12 | 2.26 |
| v1.11 | 2.25 |
| v1.10 | 2.24 |
| v1.9 | 2.23 |
| v1.8 | 2.22 |
| v1.7 | 2.21 |
| v1.6 | 2.20 |
| v1.5 | 2.19 |
| v1.4 | 2.18 |
| v1.3 | 2.17 |
| v1.2 | 2.16 |
| v1.1 | 2.15 |
| v1.0 | 2.14 |
| v0.17 | 2.13 |

View File

@@ -21,7 +21,7 @@ def apply_mappings_to_cluster(
with f.open() as json_data:
data = json.load(json_data)
template_name = f.stem
res = es.indices.put_template(template_name, body=data)
res = es.indices.put_template(name=template_name, body=data)
return {"mapping": template_name, "result": res}
p = HERE / "mappings"

View File

@@ -85,7 +85,7 @@ def check_elastic_empty() -> bool:
es = Elasticsearch(
hosts=cluster_conf.get("hosts", None),
http_auth=es_factory.get_credentials("events", cluster_conf),
**cluster_conf.get("args", {})
**cluster_conf.get("args", {}),
)
return not es.indices.get_template(name="events*")
except exceptions.NotFoundError as ex:
@@ -115,5 +115,7 @@ def init_es_data():
args = cluster_conf.get("args", {})
http_auth = es_factory.get_credentials(name)
res = apply_mappings_to_cluster(hosts_config, name, es_args=args, http_auth=http_auth)
res = apply_mappings_to_cluster(
hosts_config, name, es_args=args, http_auth=http_auth
)
log.info(res)

View File

@@ -1,8 +1,8 @@
{
"index_patterns": "events-*",
"settings": {
"number_of_shards": 1,
"number_of_replicas": 0
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {

View File

@@ -1,8 +1,8 @@
{
"index_patterns": "queue_metrics_*",
"settings": {
"number_of_shards": 1,
"number_of_replicas": 0
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
@@ -20,6 +20,9 @@
},
"queue_length": {
"type": "integer"
},
"company_id": {
"type": "keyword"
}
}
}

View File

@@ -1,8 +1,8 @@
{
"index_patterns": "worker_stats_*",
"settings": {
"number_of_shards": 1,
"number_of_replicas": 0
"number_of_replicas": 0,
"number_of_shards": 1
},
"mappings": {
"_source": {
@@ -32,6 +32,9 @@
},
"task": {
"type": "keyword"
},
"company_id": {
"type": "keyword"
}
}
}

View File

@@ -450,6 +450,7 @@ class AWSStorage(Storage):
else None,
"use_ssl": cfg.secure,
"verify": cfg.verify,
"region_name": cfg.region or None,
}
name = base[len(scheme_prefix(self.scheme)) :]
bucket_name = name[len(cfg.host) + 1 :] if cfg.host else name

View File

@@ -18,9 +18,15 @@ _migration_dir = _parent_dir / _migrations
def check_mongo_empty() -> bool:
return not all(
get_db(alias).collection_names() for alias in utils.get_options(Database)
)
for alias in utils.get_options(Database):
collection_names = get_db(alias).list_collection_names()
if collection_names and any(
name in collection_names
for name in ["company", "user", "versions"]
):
return False
return True
def get_last_server_version() -> Version:

View File

@@ -44,6 +44,7 @@ from apiserver.bll.task.param_utils import (
from apiserver.config_repo import config
from apiserver.config.info import get_default_company
from apiserver.database.model import EntityVisibility, User
from apiserver.database.model.auth import Role, User as AuthUser
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import (
@@ -54,6 +55,7 @@ from apiserver.database.model.task.task import (
TaskModelNames,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
@@ -66,6 +68,7 @@ class PrePopulate:
export_tag_prefix = "Exported:"
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
metadata_filename = "metadata.json"
users_filename = "users.json"
zip_args = dict(mode="w", compression=ZIP_BZIP2)
artifacts_ext = ".artifacts"
img_source_regex = re.compile(
@@ -78,6 +81,7 @@ class PrePopulate:
project_cls: Type[Project]
model_cls: Type[Model]
user_cls: Type[User]
auth_user_cls: Type[AuthUser]
# noinspection PyTypeChecker
@classmethod
@@ -90,6 +94,8 @@ class PrePopulate:
cls.project_cls = cls._get_entity_type("database.model.project.Project")
if not hasattr(cls, "user_cls"):
cls.user_cls = cls._get_entity_type("database.model.User")
if not hasattr(cls, "auth_user_cls"):
cls.auth_user_cls = cls._get_entity_type("database.model.auth.User")
class JsonLinesWriter:
def __init__(self, file: BinaryIO):
@@ -205,6 +211,8 @@ class PrePopulate:
task_statuses: Sequence[str] = None,
tag_exported_entities: bool = False,
metadata: Mapping[str, Any] = None,
export_events: bool = True,
export_users: bool = False,
) -> Sequence[str]:
cls._init_entity_types()
@@ -240,11 +248,15 @@ class PrePopulate:
with ZipFile(file, **cls.zip_args) as zfile:
if metadata:
zfile.writestr(cls.metadata_filename, meta_str)
if export_users:
cls._export_users(zfile)
artifacts = cls._export(
zfile,
entities=entities,
hash_=hash_,
tag_entities=tag_exported_entities,
export_events=export_events,
cleanup_users=not export_users,
)
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
@@ -265,6 +277,9 @@ class PrePopulate:
metadata_hash=metadata_hash,
)
if created_files:
print("Created files:\n" + "\n".join(file for file in created_files))
return created_files
@classmethod
@@ -296,18 +311,26 @@ class PrePopulate:
except Exception:
pass
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
# Make sure we won't end up with an invalid company ID
if company_id is None:
company_id = ""
user_mapping = cls._import_users(zfile, company_id)
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
existing_user = cls.user_cls.objects(id=user_id).only("id").first()
if not existing_user:
cls.user_cls(id=user_id, name=user_name, company=company_id).save()
cls._import(zfile, company_id, user_id, metadata)
cls._import(
zfile,
company_id=company_id,
user_id=user_id,
metadata=metadata,
user_mapping=user_mapping,
)
if artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
@@ -438,7 +461,7 @@ class PrePopulate:
projects: Sequence[str] = None,
task_statuses: Sequence[str] = None,
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
entities = defaultdict(set)
entities: Dict[Any] = defaultdict(set)
if projects:
print("Reading projects...")
@@ -497,7 +520,6 @@ class PrePopulate:
@classmethod
def _cleanup_model(cls, model: Model):
model.company = ""
model.user = ""
model.tags = cls._filter_out_export_tags(model.tags)
@classmethod
@@ -505,7 +527,6 @@ class PrePopulate:
task.comment = "Auto generated by Allegro.ai"
task.status_message = ""
task.status_reason = ""
task.user = ""
task.company = ""
task.tags = cls._filter_out_export_tags(task.tags)
if task.output:
@@ -513,17 +534,32 @@ class PrePopulate:
@classmethod
def _cleanup_project(cls, project: Project):
project.user = ""
project.company = ""
project.tags = cls._filter_out_export_tags(project.tags)
@classmethod
def _cleanup_entity(cls, entity_cls, entity):
def _cleanup_auth_user(cls, user: AuthUser):
user.company = ""
for cred in user.credentials:
if getattr(cred, "company", None):
cred["company"] = ""
return user
@classmethod
def _cleanup_be_user(cls, user: User):
user.company = ""
user.preferences = None
return user
@classmethod
def _cleanup_entity(cls, entity_cls, entity, cleanup_users):
if cleanup_users:
entity.user = ""
if entity_cls == cls.task_cls:
cls._cleanup_task(entity)
elif entity_cls == cls.model_cls:
cls._cleanup_model(entity)
elif entity == cls.project_cls:
elif entity_cls == cls.project_cls:
cls._cleanup_project(entity)
@classmethod
@@ -633,6 +669,38 @@ class PrePopulate:
else:
print(f"Artifact {full_path} not found")
@classmethod
def _export_users(cls, writer: ZipFile):
auth_users = {
user.id: cls._cleanup_auth_user(user)
for user in cls.auth_user_cls.objects(role__in=(Role.admin, Role.user))
}
if not auth_users:
return
be_users = {
user.id: cls._cleanup_be_user(user)
for user in cls.user_cls.objects(id__in=list(auth_users))
}
if not be_users:
return
auth_users = {uid: data for uid, data in auth_users.items() if uid in be_users}
print(f"Writing {len(auth_users)} users into {writer.filename}")
data = {}
for field, users in (("auth", auth_users), ("backend", be_users)):
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
for user in users.values():
w.write(user.to_json())
data[field] = f.getvalue()
def get_field_bytes(k: str, v: bytes) -> bytes:
return f'"{k}": '.encode("utf-8") + v
data_str = b",\n".join(get_field_bytes(k, v) for k, v in data.items())
writer.writestr(cls.users_filename, b"{\n" + data_str + b"\n}")
@classmethod
def _get_base_filename(cls, cls_: type):
name = f"{cls_.__module__}.{cls_.__name__}"
@@ -642,7 +710,13 @@ class PrePopulate:
@classmethod
def _export(
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
cls,
writer: ZipFile,
entities: dict,
hash_,
tag_entities: bool = False,
export_events: bool = True,
cleanup_users: bool = True,
) -> Sequence[str]:
"""
Export the requested experiments, projects and models and return the list of artifact files
@@ -656,18 +730,19 @@ class PrePopulate:
if not items:
continue
base_filename = cls._get_base_filename(cls_)
for item in items:
artifacts.extend(
cls._export_entity_related_data(
cls_, item, base_filename, writer, hash_
if export_events:
for item in items:
artifacts.extend(
cls._export_entity_related_data(
cls_, item, base_filename, writer, hash_
)
)
)
filename = base_filename + ".json"
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
for item in items:
cls._cleanup_entity(cls_, item)
cls._cleanup_entity(cls_, item, cleanup_users=cleanup_users)
w.write(item.to_json())
data = f.getvalue()
hash_.update(data)
@@ -717,7 +792,10 @@ class PrePopulate:
@classmethod
def _generate_new_ids(
cls, reader: ZipFile, entity_files: Sequence, metadata: Mapping[str, Any],
cls,
reader: ZipFile,
entity_files: Sequence,
metadata: Mapping[str, Any],
) -> Mapping[str, str]:
if not metadata or not any(
metadata.get(key) for key in ("new_ids", "example_ids", "private_ids")
@@ -745,6 +823,68 @@ class PrePopulate:
)
return ids
@classmethod
def _import_users(cls, reader: ZipFile, company_id: str = "") -> dict:
"""
Import users to db and return the mapping of old user ids to the new ones
If no users were in the users file then the mapping was empty
If the user in the file has the same email as one of the existing ones then this user is skipped
and its id is mapped to the existing user with the same email
If the user with the same id exists in backend or auth db then its creation is skipped
"""
users_file = first(
fi for fi in reader.filelist if fi.orig_filename == cls.users_filename
)
if not users_file:
return {}
existing_user_ids = set(cls.user_cls.objects().scalar("id")) | set(
cls.auth_user_cls.objects().scalar("id")
)
existing_user_emails = {u.email: u.id for u in cls.auth_user_cls.objects()}
user_id_mappings = {}
with reader.open(users_file) as f:
data = json.loads(f.read())
auth_users = {u["_id"]: u for u in data["auth"]}
be_users = {u["_id"]: u for u in data["backend"]}
for uid, user in auth_users.items():
email = user.get("email")
existing_user_id = existing_user_emails.get(email)
if existing_user_id:
user_id_mappings[uid] = existing_user_id
continue
user_id_mappings[uid] = uid
if uid in existing_user_ids:
continue
credentials = user.get("credentials", [])
for c in credentials:
if c.get("company") == "":
c["company"] = company_id
if hasattr(cls.auth_user_cls, "sec_groups"):
user_role = user.get("role", Role.user)
if user_role == Role.user:
user["sec_groups"] = ["30795571-a470-4717-a80d-e8705fc776bf"]
else:
user["sec_groups"] = [
"c14a3cc6-1144-4896-8ea6-fb186ee19896",
"30795571-a470-4717-a80d-e8705fc776bf",
"30795571a4704717a80de8705897ytuyg",
]
auth_user = cls.auth_user_cls.from_json(json.dumps(user), created=True)
auth_user.company = company_id
auth_user.save()
be_user = cls.user_cls.from_json(json.dumps(be_users[uid]), created=True)
be_user.company = company_id
be_user.save()
return user_id_mappings
@classmethod
def _import(
cls,
@@ -753,6 +893,7 @@ class PrePopulate:
user_id: str = None,
metadata: Mapping[str, Any] = None,
sort_tasks_by_last_updated: bool = True,
user_mapping: Mapping[str, str] = None,
):
"""
Import entities and events from the zip file
@@ -763,7 +904,7 @@ class PrePopulate:
fi
for fi in reader.filelist
if not fi.orig_filename.endswith(event_file_ending)
and fi.orig_filename != cls.metadata_filename
and fi.orig_filename not in (cls.metadata_filename, cls.users_filename)
]
metadata = metadata or {}
old_to_new_ids = cls._generate_new_ids(reader, entity_files, metadata)
@@ -773,7 +914,13 @@ class PrePopulate:
full_name = splitext(entity_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
res = cls._import_entity(
f, full_name, company_id, user_id, metadata, old_to_new_ids
f,
full_name=full_name,
company_id=company_id,
user_id=user_id,
metadata=metadata,
old_to_new_ids=old_to_new_ids,
user_mapping=user_mapping,
)
if res:
tasks = res
@@ -794,7 +941,7 @@ class PrePopulate:
with reader.open(events_file) as f:
full_name = splitext(events_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
cls._import_events(f, company_id, user_id, task.id)
cls._import_events(f, company_id, task.user, task.id)
@classmethod
def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]:
@@ -874,7 +1021,7 @@ class PrePopulate:
):
old_path = old_field.split(".")
old_model = nested_get(task_data, old_path)
new_models = models.get(type_, [])
new_models = [m for m in models.get(type_, []) if m.get("model") is not None]
name = TaskModelNames[type_]
if old_model and not any(
m
@@ -908,7 +1055,9 @@ class PrePopulate:
user_id: str,
metadata: Mapping[str, Any],
old_to_new_ids: Mapping[str, str] = None,
user_mapping: Mapping[str, str] = None,
) -> Optional[Sequence[Task]]:
user_mapping = user_mapping or {}
cls_ = cls._get_entity_type(full_name)
print(f"Writing {cls_.__name__.lower()}s into database")
tasks = []
@@ -930,7 +1079,7 @@ class PrePopulate:
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):
doc.user = user_id
doc.user = user_mapping.get(doc.user, user_id) if doc.user else user_id
if hasattr(doc, "company"):
doc.company = company_id
if isinstance(doc, cls.project_cls):
@@ -960,13 +1109,17 @@ class PrePopulate:
return tasks
@classmethod
def _import_events(cls, f: IO[bytes], company_id: str, _, task_id: str):
def _import_events(cls, f: IO[bytes], company_id: str, user_id: str, task_id: str):
print(f"Writing events for task {task_id} into database")
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
events = [json.loads(item) for item in events_chunk]
for ev in events:
ev["task"] = task_id
ev["company_id"] = company_id
ev["allow_locked"] = True
cls.event_bll.add_events(
company_id, events=events, worker="", allow_locked=True
company_id=company_id,
identity=Identity(user_id, company=company_id, role=Role.admin),
events=events,
worker="",
)

View File

@@ -2,7 +2,7 @@ from os import getenv
from boltons.iterutils import first
from redis import StrictRedis
from rediscluster import RedisCluster
from redis.cluster import RedisCluster
from apiserver.apierrors.errors.server_error import ConfigError, GeneralError
from apiserver.config_repo import config
@@ -83,7 +83,7 @@ class RedisManager(object):
def host(self, alias):
r = self.connection(alias)
if isinstance(r, RedisCluster):
connections = first(r.connection_pool._available_connections.values())
connections = r.get_default_node().redis_connection.connection_pool._available_connections
else:
connections = r.connection_pool._available_connections

View File

@@ -1,38 +1,37 @@
attrs>=22.1.0
attrs>=22.1.0,<23
azure-storage-blob>=12.13.1
bcrypt>=3.1.4
boltons>=19.1.0
boto3==1.14.13
boto3-stubs[s3]>=1.24.35
clearml>=1.6.0,<1.7.0
boto3>=1.26
boto3-stubs[s3]>=1.26
clearml>=1.10.3
dpath>=1.4.2,<2.0
elasticsearch==7.13.3
elasticsearch==7.17.9
fastjsonschema>=2.8
flask-compress>=1.4.0
flask-cors>=3.0.5
flask>=0.12.2
funcsigs==1.0.2
flask>=2.3.3
furl>=2.0.0
google-cloud-storage==2.0.0
protobuf==3.19.5
gunicorn>=19.7.1
humanfriendly==4.18
jinja2==2.11.3
google-cloud-storage>=2.8.0
gunicorn>=20.1.0
humanfriendly>=4.17
jinja2
jsonmodels>=2.3
jsonschema>=2.6.0
luqum>=0.10.0
mongoengine==0.24.2
mongoengine==0.27.0
nested_dict>=1.61
packaging==20.3
psutil>=5.6.5
pyhocon>=0.3.35
pyhocon>=0.3.35r
pyjwt>=2.4.0
pymongo[srv]==3.12.0
pymongo==4.4.0
python-rapidjson>=0.6.3
redis==4.4.4
redis-py-cluster>=2.1.3
redis>=4.5.4,<5
requests>=2.13.0
semantic_version>=2.8.3,<3
setuptools>=65.5.1
six
tqdm
validators>=0.12.4
validators>=0.12.4
urllib3>=1.26.18
werkzeug>=3.0.1

View File

@@ -1,3 +1,43 @@
field_filter {
type: object
description: Filter on a field that includes combination of 'any' or 'all' included and excluded terms
properties {
any {
type: object
description: All the terms in 'any' condition are combined with 'or' operation
properties {
"include" {
type: array
items {type: string}
}
exclude {
type: array
items {type: string}
}
}
}
all {
type: object
description: All the terms in 'all' condition are combined with 'and' operation
properties {
"include" {
type: array
items {type: string}
}
exclude {
type: array
items {type: string}
}
}
}
op {
type: string
description: The operation between 'any' and 'all' parts of the filter if both are provided
default: and
enum: [and, or]
}
}
}
metadata_item {
type: object
properties {

View File

@@ -103,4 +103,39 @@ plots_response {
items {"$ref": "#/definitions/plots_response_task_metrics"}
}
}
}
single_value_task_metrics {
type: object
properties {
task {
type: string
description: Task ID
}
task_name {
type: string
description: Task name
}
values {
type: array
items {
type: object
properties {
metric { type: string }
variant { type: string}
value { type: number }
timestamp { type: number }
}
}
}
}
}
single_value_metrics_response {
type: object
properties {
tasks {
description: Single value metrics grouped by task
type: array
items {"$ref": "#/definitions/single_value_task_metrics"}
}
}
}

View File

@@ -414,7 +414,7 @@ task {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: [string, null] }
additionalProperties { type: string }
}
models {
description: "Task models"

View File

@@ -11,8 +11,8 @@ _definitions {
type: number
}
type {
description: "training_stats_vector"
const: "training_stats_scalar"
description: "'training_stats_scalar'"
type: string
}
task {
description: "Task ID (required)"
@@ -46,8 +46,8 @@ _definitions {
type: number
}
type {
description: "training_stats_vector"
const: "training_stats_vector"
description: "'training_stats_vector'"
type: string
}
task {
description: "Task ID (required)"
@@ -82,8 +82,8 @@ _definitions {
type: number
}
type {
description: ""
const: "training_debug_image"
description: "'training_debug_image'"
type: string
}
task {
description: "Task ID (required)"
@@ -123,7 +123,7 @@ _definitions {
}
type {
description: "'plot'"
const: "plot"
type: string
}
task {
description: "Task ID (required)"
@@ -221,7 +221,7 @@ _definitions {
}
type {
description: "'log'"
const: "log"
type: string
}
task {
description: "Task ID (required)"
@@ -754,6 +754,42 @@ get_task_metrics{
}
}
}
get_multi_task_metrics {
"2.28" {
description: """Get unique metrics and variants from the events of the specified type.
Only events reported for the passed task or model ids are analyzed."""
request {
type: object
required: [ tasks ]
properties {
tasks {
description: task ids to get metrics from
type: array
items {type: string}
}
model_events {
description: If not set or set to false then passed ids are task ids otherwise model ids
type: boolean
default: false
}
event_type {
"description": Event type. If not specified then metrics are collected from the reported events of all types
"$ref": "#/definitions/event_type_enum"
}
}
}
response {
type: object
properties {
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
}
}
get_task_log {
"1.5" {
description: "Get all 'log' events for this task"
@@ -971,10 +1007,17 @@ get_task_events {
}
}
"2.22": ${get_task_events."2.1"} {
request.properties.model_events {
type: boolean
description: If set then get retrieving model events. Otherwise task events
default: false
request.properties {
model_events {
type: boolean
description: If set then get retrieving model events. Otherwise task events
default: false
}
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
}
@@ -1149,6 +1192,20 @@ get_multi_task_plots {
default: false
}
}
"2.26": ${get_multi_task_plots."2.22"} {
request.properties.last_iters_per_task_metric {
type: boolean
description: If set to 'true' and iters passed then last iterations for each task metrics are retrieved. Otherwise last iterations for the whole task are retrieved
default: true
}
}
"2.28": ${get_multi_task_plots."2.26"} {
request.properties.metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
get_vector_metrics_and_variants {
"2.1" {
@@ -1335,6 +1392,13 @@ multi_task_scalar_metrics_iter_histogram {
default: false
}
}
"2.28": ${multi_task_scalar_metrics_iter_histogram."2.22"} {
request.properties.metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
get_task_single_value_metrics {
"2.20" {
@@ -1353,36 +1417,7 @@ get_task_single_value_metrics {
}
}
}
response {
type: object
properties {
tasks {
description: Single value metrics grouped by task
type: array
items {
type: object
properties {
task {
type: string
description: Task ID
}
values {
type: array
items {
type: object
properties {
metric { type: string }
variant { type: string}
value { type: number }
timestamp { type: number }
}
}
}
}
}
}
}
}
response {"$ref": "#/definitions/single_value_metrics_response"}
}
"2.22": ${get_task_single_value_metrics."2.20"} {
request.properties.model_events {
@@ -1391,6 +1426,13 @@ get_task_single_value_metrics {
default: false
}
}
"2.28": ${get_task_single_value_metrics."2.22"} {
request.properties.metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
get_task_latest_scalar_values {
"2.1" {
@@ -1492,6 +1534,10 @@ get_scalar_metric_data {
type: string
description: type of metric
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
}
}
}
response {
@@ -1514,7 +1560,7 @@ get_scalar_metric_data {
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
description: "Scroll ID for getting more results"
}
}
}

View File

@@ -6,21 +6,12 @@ _default {
}
supported_modes {
authorize: false
authorize: null
"2.9" {
description: """ Return supported login modes."""
request {
type: object
properties {
state {
description: "ASCII base64 encoded application state"
type: string
}
callback_url_prefix {
description: "URL prefix used to generate the callback URL for each supported SSO provider"
type: string
}
}
additionalProperties: false
}
response {
type: object
@@ -59,7 +50,7 @@ supported_modes {
description: "SSO authentication providers"
type: object
additionalProperties {
desctiprion: "Provider redirect URL"
description: "Provider redirect URL"
type: string
}
}
@@ -95,7 +86,7 @@ supported_modes {
}
logout {
authorize: false
authorize: null
allow_roles = [ "*" ]
"2.13" {
description: """ Logout (including SSO, if used)) """

View File

@@ -1,6 +1,6 @@
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
_definitions {
include "_common.conf"
include "_tasks_common.conf"
multi_field_pattern_data {
type: object
properties {
@@ -104,6 +104,17 @@ _definitions {
"$ref": "#/definitions/metadata_item"
}
}
last_iteration {
description: "Last iteration reported for this model"
type: integer
}
last_metrics {
description: "Last metric variants (hash to events), one for each metric hash"
type: object
additionalProperties {
"$ref": "#/definitions/last_metrics_variants"
}
}
stats {
description: "Model statistics"
type: object
@@ -250,6 +261,14 @@ get_all_ex {
}
}
}
"2.27": ${get_all_ex."2.23"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
}
}
}
get_all {
"2.1" {
@@ -346,9 +365,6 @@ get_all {
"$ref": "#/definitions/multi_field_pattern_data"
}
}
dependencies {
page: [ page_size ]
}
}
response {
type: object
@@ -384,6 +400,17 @@ get_all {
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
"2.26": ${get_all."2.15"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and project field is set then models from the subprojects are searched too"
type: boolean
default: false
}
}
}
}
}
get_frameworks {
"2.8" {
@@ -1064,4 +1091,38 @@ delete_metadata {
}
}
}
}
update_tags {
"2.27" {
description: Add or remove tags from multiple models
request {
type: object
properties {
ids {
type: array
description: IDs of the models to update
items {type: string}
}
add_tags {
type: array
description: User tags to add
items {type: string}
}
remove_tags {
type: array
description: User tags to remove
items {type: string}
}
}
}
response {
type: object
properties {
updated {
type: integer
description: The number of updated models
}
}
}
}
}

View File

@@ -1,5 +1,38 @@
_description: "This service provides organization level operations"
_definitions {
value_mapping {
type: object
required: [key, value]
properties {
key {
description: Original value
type: object
}
value {
description: Translated value
type: object
}
}
}
field_mapping {
type: object
required: [field]
properties {
field {
description: The source field name as specified in the only_fields
type: string
}
name {
description: The column name in the exported csv file
type: string
}
values {
type: array
items { "$ref": "#/definitions/value_mapping"}
}
}
}
}
get_tags {
"2.8" {
description: "Get all the user and system tags used for the company tasks and models"
@@ -170,7 +203,7 @@ get_entities_count {
default: false
}
active_users {
descritpion: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
description: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
type: array
items: {type: string}
}
@@ -197,3 +230,69 @@ get_entities_count {
}
}
}
prepare_download_for_get_all {
"2.26": {
description: Prepares download from get_all_ex parameters
request {
type: object
required: [ entity_type, only_fields, field_mappings]
properties {
only_fields {
description: "List of task field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
type: array
items {type: string}
}
allow_public {
description: "Allow public entities to be returned in the results"
type: boolean
default: true
}
search_hidden {
description: "If set to 'true' then hidden entities are included in the search results"
type: boolean
default: false
}
entity_type {
description: "The type of the entity to retrieve"
type: string
enum: [
task
model
]
}
field_mappings {
description: The name and value mappings for the exported fields. The fields that are not in the mappings will not be exported
type: array
items { "$ref": "#/definitions/field_mapping"}
}
}
}
response {
type: object
properties {
prepare_id {
description: "Prepare ID (use when calling 'download_for_get_all')"
type: string
}
}
}
}
}
download_for_get_all {
"2.26": {
description: Generates a file for the download
request {
type: object
required: [ prepare_id ]
properties {
prepare_id {
description: "Call ID returned by a call to prepare_download_for_get_all"
type: string
}
}
}
response {
type: string
}
}
}

View File

@@ -1,7 +1,42 @@
_description: "Provides a management API for pipelines in the system."
_definitions {
include "_common.conf"
}
delete_runs {
"2.26": ${_definitions.batch_operation} {
description: Delete pipeline runs
request {
required: [ids, project]
properties {
ids.description: "IDs of the pipeline runs to delete. Should be the ids of pipeline controller tasks"
project {
description: "Pipeline project ids. When deleting at least one run should be left"
type: string
}
}
}
response {
properties {
succeeded.items.properties.deleted {
description: "Indicates whether the task was deleted"
type: boolean
}
succeeded.items.properties.updated_children {
description: "Number of child tasks whose parent property was updated"
type: integer
}
succeeded.items.properties.updated_models {
description: "Number of models whose task property was updated"
type: integer
}
succeeded.items.properties.deleted_models {
description: "Number of deleted output models"
type: integer
}
}
}
}
}
start_pipeline {
"2.17" {
description: "Start a pipeline"
@@ -24,7 +59,7 @@ start_pipeline {
type: object
properties {
name: { type: string }
value: { type: [string, null] }
value: { type: string }
}
}
}
@@ -44,4 +79,15 @@ start_pipeline {
}
}
}
"2.28": ${start_pipeline."2.17"} {
request.properties.verify_watched_queue {
description: If passed then check wheter there are any workers watiching the queue
type: boolean
default: false
}
response.properties.queue_watched {
description: Returns true if there are workers or autscalers working with the queue
type: boolean
}
}
}

View File

@@ -1,5 +1,6 @@
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
_definitions {
include "_common.conf"
multi_field_pattern_data {
type: object
properties {
@@ -569,7 +570,7 @@ get_all_ex {
request {
properties {
active_users {
descritpion: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
description: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
type: array
items: {type: string}
}
@@ -653,6 +654,22 @@ get_all_ex {
enum: [pipeline, report, dataset]
}
}
"2.25": ${get_all_ex."2.24"} {
request.properties.children_tags {
description: "The list of tag values to filter children by. Takes effect only if children_type is set. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
type: array
items {type: string}
}
}
"2.27": ${get_all_ex."2.25"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
children_tags_filter: ${_definitions.field_filter}
}
}
}
update {
"2.1" {
@@ -801,6 +818,26 @@ validate_delete {
}
}
}
"2.26": ${validate_delete."2.14"} {
response.properties {
reports {
description: "The total number of reports under the project and all its children"
type: integer
}
non_archived_reports {
description: "The total number of non-archived reports under the project and all its children"
type: integer
}
pipelines {
description: "The total number of pipelines with active controllers under the project and all its children"
type: integer
}
datasets {
description: "The total number of non-empty datasets under the project and all its children"
type: integer
}
}
}
}
delete {
"2.1" {
@@ -861,6 +898,13 @@ delete {
}
}
}
"2.26": ${delete."2.13"} {
request.properties.delete_external_artifacts {
description: "If set to 'true' then BE will try to delete the extenal artifacts associated with the project tasks and models from the fileserver (if configured to do so)"
type: boolean
default: true
}
}
}
get_unique_metric_variants {
"2.1" {
@@ -898,6 +942,20 @@ get_unique_metric_variants {
}
}
}
"2.25": ${get_unique_metric_variants."2.13"} {
request.properties.model_metrics {
description: If set to true then bring unique metric and variant names from the project models otherwise from the project tasks
type: boolean
default: false
}
}
"2.28": ${get_unique_metric_variants."2.25"} {
request.properties.ids {
description: IDs of the tasks or models to get metrics from
type: array
items {type: string}
}
}
}
get_hyperparam_values {
"2.13" {
@@ -945,6 +1003,26 @@ get_hyperparam_values {
}
}
}
"2.26": ${get_hyperparam_values."2.13"} {
request.properties {
page {
description: "Page number"
default: 0
type: integer
}
page_size {
description: "Page size"
default: 500
type: integer
}
}
}
"2.27": ${get_hyperparam_values."2.26"} {
request.properties.pattern {
type: string
description: The search pattern regex
}
}
}
get_hyper_parameters {
"2.9" {
@@ -1042,6 +1120,20 @@ get_model_metadata_values {
}
}
}
"2.26": ${get_model_metadata_values."2.17"} {
request.properties {
page {
description: "Page number"
default: 0
type: integer
}
page_size {
description: "Page size"
default: 500
type: integer
}
}
}
}
get_model_metadata_keys {
"2.17" {
@@ -1201,13 +1293,15 @@ get_task_parents {
}
project {
type: object
id {
description: "The ID of the parent task project"
type: string
}
name {
description: "The name of the parent task project"
type: string
properties {
id {
description: "The ID of the parent task project"
type: string
}
name {
description: "The name of the parent task project"
type: string
}
}
}
}
@@ -1227,4 +1321,58 @@ get_task_parents {
}
}
}
"2.25": ${get_task_parents."2.13"} {
request.properties.task_name {
description: Task name pattern for the returned parent tasks
type: string
}
}
}
get_user_names {
"2.26" {
description: "Get names and ids of the users who created child entitites under the passed projects"
request {
type: object
properties {
projects {
description: "The list of projects. If not passed or empty then all the projects are searched"
type: array
items { type: string }
}
include_subprojects {
description: "If set to 'true' and the projects field is not empty then the result includes user name from the subprojects children"
type: boolean
default: true
}
entity {
description: The type of the child entity to look for
type: string
enum: [task, model]
default: task
}
}
}
response {
type: object
properties {
users {
description: "The list of users sorted by their names"
type: array
items {
type: object
properties {
id {
description: "The ID of the user"
type: string
}
name {
description: "The name of the user"
type: string
}
}
}
}
}
}
}
}

View File

@@ -159,6 +159,14 @@ get_all_ex {
default: false
}
}
"2.27": ${get_all_ex."2.21"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
}
}
}
get_all {
"2.4" {

View File

@@ -568,6 +568,32 @@ get_task_data {
}
}
}
"2.25": ${get_task_data."2.23"} {
request.properties {
model_events {
type: boolean
description: If set then the retrieving model events. Otherwise task events
default: false
}
single_value_metrics {
type: object
description: If passed then task single value metrics are returned
additionalProperties: false
}
}
response.properties.single_value_metrics {
type: array
description: Single value metrics grouped by task
items {"$ref": "#/definitions/single_value_task_metrics"}
}
}
"2.26": ${get_task_data."2.25"} {
request.properties.plots.properties.last_iters_per_task_metric {
type: boolean
description: If set to 'true' and iters passed then last iterations for each task metrics are retrieved. Otherwise last iterations for the whole task are retrieved
default: true
}
}
}
get_all_ex {
"2.23" {
@@ -668,9 +694,6 @@ get_all_ex {
"$ref": "#/definitions/multi_field_pattern_data"
}
}
dependencies {
page: [ page_size ]
}
}
response {
type: object
@@ -687,6 +710,21 @@ get_all_ex {
}
}
}
"2.26": ${get_all_ex."2.23"} {
request.properties.include_subprojects {
description: "If set to 'true' and project field is set then reports from the subprojects are searched too"
type: boolean
default: false
}
}
"2.27": ${get_all_ex."2.26"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
}
}
}
get_tags {
"2.23" {

View File

@@ -190,6 +190,14 @@ get_all_ex {
}
}
}
"2.27": ${get_all_ex."2.23"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
}
}
}
get_all {
"2.1" {
@@ -289,9 +297,6 @@ get_all {
"$ref": "#/definitions/multi_field_pattern_data"
}
}
dependencies {
page: [ page_size ]
}
}
response {
type: object
@@ -334,6 +339,17 @@ get_all {
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
"2.26": ${get_all."2.15"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and project field is set then tasks from the subprojects are searched too"
type: boolean
default: false
}
}
}
}
}
get_types {
"2.8" {
@@ -470,7 +486,7 @@ clone {
new_task_container {
description: "The docker container properties for the new task. If not provided then taken from the original task"
type: object
additionalProperties { type: [string, null] }
additionalProperties { type: string }
}
}
}
@@ -648,7 +664,7 @@ create {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: [string, null] }
additionalProperties { type: string }
}
}
}
@@ -737,7 +753,7 @@ validate {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: [string, null] }
additionalProperties { type: string }
}
}
}
@@ -899,7 +915,7 @@ edit {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: [string, null] }
additionalProperties { type: string }
}
runtime {
description: "Task runtime mapping"
@@ -1507,6 +1523,19 @@ dequeue {
}
}
}
"2.25": ${dequeue."1.5"} {
request.properties.remove_from_all_queues {
type: boolean
description: If set to 'true' then the task is searched and removed from all the queues. Otherwise only from the queue stored in the task execution parameters
default: false
}
}
"2.26": ${dequeue."2.25"} {
request.properties.new_status {
type: string
description: The new status to assign to the task after the dequeue instead of the default one
}
}
}
dequeue_many {
"2.13": ${_definitions.change_many_request} {
@@ -1525,6 +1554,19 @@ dequeue_many {
}
}
}
"2.25": ${dequeue_many."2.13"} {
request.properties.remove_from_all_queues {
type: boolean
description: If set to 'true' then the tasks are searched and removed from all the queues. Otherwise only from the queue stored in the task execution parameters
default: false
}
}
"2.26": ${dequeue_many."2.25"} {
request.properties.new_status {
type: string
description: The new status to assign to the task after the dequeue instead of the default one
}
}
}
set_requirements {
"2.1" {
@@ -2013,3 +2055,37 @@ move {
}
}
}
update_tags {
"2.27" {
description: Add or remove tags from multiple tasks
request {
type: object
properties {
ids {
type: array
description: IDs of the tasks to update
items {type: string}
}
add_tags {
type: array
description: User tags to add
items {type: string}
}
remove_tags {
type: array
description: User tags to remove
items {type: string}
}
}
}
response {
type: object
properties {
updated {
type: integer
description: The number of updated tasks
}
}
}
}
}

View File

@@ -155,6 +155,17 @@ get_current_user {
}
}
}
"2.26": ${get_current_user."2.20"} {
response.properties.settings {
type: object
properties {
max_download_items {
type: string
description: The maximum items downloaded for this user in csv file downloads
}
}
}
}
}
get_all_ex {

View File

@@ -311,6 +311,41 @@ get_all {
}
}
}
get_count {
"2.26": {
description: "Returns the number of registered workers."
request {
type: object
properties {
last_seen {
description: """Filter out workers not active for more than last_seen seconds.
A value or 0 or 'none' will disable the filter."""
type: integer
default: 0
}
tags {
description: The list of allowed worker tags. Prepend tag value with '-' in order to exclude
type: array
items { type: string }
}
system_tags {
description: The list of allowed worker system tags. Prepend tag value with '-' in order to exclude
type: array
items { type: string }
}
}
}
response {
type: object
properties {
count {
description: Workers count
type: integer
}
}
}
}
}
register {
"2.4" {
description: "Register a worker in the system. Called by the Worker Daemon."

View File

@@ -46,7 +46,6 @@ class AppSequence:
self._attach_request_handlers(request_handlers)
def _attach_request_handlers(self, request_handlers: RequestHandlers):
self.app.before_first_request(request_handlers.before_app_first_request)
self.app.before_request(request_handlers.before_request)
self.app.after_request(request_handlers.after_request)

View File

@@ -1,3 +1,5 @@
import unicodedata
import urllib.parse
from functools import partial
from flask import request, Response, redirect
@@ -20,9 +22,6 @@ class RequestHandlers:
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
_server_header = config.get("apiserver.response.headers.server", "clearml")
def before_app_first_request(self):
pass
def before_request(self):
if request.method == "OPTIONS":
return "", 200
@@ -43,10 +42,21 @@ class RequestHandlers:
response = redirect(call.result.redirect.url, call.result.redirect.code)
else:
headers = None
disable_cache = False
if call.result.filename:
headers = {
"Content-Disposition": f"attachment; filename={call.result.filename}"
}
# make sure that downloaded files are not cached by the client
disable_cache = True
try:
call.result.filename.encode("ascii")
except UnicodeEncodeError:
simple = unicodedata.normalize("NFKD", call.result.filename)
simple = simple.encode("ascii", "ignore").decode("ascii")
# safe = RFC 5987 attr-char
quoted = urllib.parse.quote(call.result.filename, safe="")
filenames = f"filename={simple}; filename*=UTF-8''{quoted}"
else:
filenames = f"filename={call.result.filename}"
headers = {"Content-Disposition": "attachment; " + filenames}
response = Response(
content,
@@ -54,6 +64,9 @@ class RequestHandlers:
status=call.result.code,
headers=headers,
)
if disable_cache:
response.cache_control.no_store = True
response.cache_control.max_age = 0
if call.result.cookies:
for key, value in call.result.cookies.items():

View File

@@ -655,7 +655,11 @@ class APICall(DataContainer):
}
if self.content_type.lower() == JSON_CONTENT_TYPE:
try:
func = json.dumps if self._json_flags.pop("ensure_ascii", True) else json.dumps_notascii
func = (
json.dumps
if self._json_flags.pop("ensure_ascii", True)
else json.dumps_notascii
)
res = func(res, **(self._json_flags or {}))
except Exception as ex:
# JSON serialization may fail, probably problem with data or error_data so pop it and try again
@@ -685,8 +689,12 @@ class APICall(DataContainer):
cookies=self._result.cookies,
)
def get_redacted_headers(self):
headers = self.headers.copy()
def get_redacted_headers(self, fields=None):
headers = (
{k: v for k, v in self._headers.items() if k in fields}
if fields
else self.headers
)
if not self.requires_authorization or self.auth:
# We won't log the authorization header if call shouldn't be authorized, or if it was successfully
# authorized. This means we'll only log authorization header for calls that failed to authorize (hopefully

View File

@@ -30,24 +30,35 @@ def get_auth_func(auth_type):
raise errors.unauthorized.BadAuthType()
def authorize_token(jwt_token, *_, **__):
def authorize_token(jwt_token, service, action, call):
"""Validate token against service/endpoint and requests data (dicts).
Returns a parsed token object (auth payload)
"""
call_info = {"ip": call.real_ip}
def log_error(msg):
info = ", ".join(f"{k}={v}" for k, v in call_info.items())
log.error(f"{msg} Call info: {info}")
try:
return Token.from_encoded_token(jwt_token)
except jwt.exceptions.InvalidKeyError as ex:
log_error("Failed parsing token.")
raise errors.unauthorized.InvalidToken(
"jwt invalid key error", reason=ex.args[0]
)
except jwt.InvalidTokenError as ex:
log_error("Failed parsing token.")
raise errors.unauthorized.InvalidToken("invalid jwt token", reason=ex.args[0])
except ValueError as ex:
log.exception("Failed while processing token: %s" % ex.args[0])
log_error(f"Failed while processing token: {str(ex.args[0])}.")
raise errors.unauthorized.InvalidToken(
"failed processing token", reason=ex.args[0]
)
except Exception:
log_error("Failed processing token.")
raise
def authorize_credentials(auth_data, service, action, call):

View File

@@ -90,7 +90,7 @@ class Token(Payload):
return token
except Exception as e:
raise errors.unauthorized.InvalidToken(
"failed parsing token, %s" % e.args[0]
"failed parsing token", reason=e.args[0]
)
@classmethod

View File

@@ -39,7 +39,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.24")
_max_version = PartialVersion("2.28")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (

View File

@@ -17,7 +17,7 @@ log = config.logger(__file__)
def validate_data(call: APICall, endpoint: Endpoint):
""" Perform all required call/endpoint validation, update call result appropriately """
try:
# todo: remove vaildate_required_fields once all endpoints have json schema
# todo: remove validate_required_fields once all endpoints have json schema
validate_required_fields(endpoint, call)
# set models. models will be validated automatically
@@ -50,10 +50,17 @@ def validate_role(endpoint, call):
pass
def validate_auth(endpoint, call):
""" Validate authorization for this endpoint and call.
If authentication has occurred, the call is updated with the authentication results.
def validate_auth(endpoint: Endpoint, call: "APICall"):
"""
Validate authorization for this endpoint and call.
If authentication has occurred, the call is updated with the authentication results.
For the endpoints with authorize==False the validation is not performed to improve performance
For the endpoints with authorize==True the validation should pass otherwise exception will be thrown
For the endpoints with authorize==None the validation will be tried, but it does not have to succeed
"""
if endpoint.authorize is not None and not endpoint.authorize:
return
if not call.authorization:
# No auth data. Invalid if we need to authorize and valid otherwise
if endpoint.authorize:
@@ -63,10 +70,9 @@ def validate_auth(endpoint, call):
# prepare arguments for validation
service, _, action = endpoint.name.partition(".")
# If we have auth data, we'll try to validate anyway (just so we'll have auth-based permissions whenever possible,
# even if endpoint did not require authorization)
# noinspection PyBroadException
try:
auth = call.authorization or ""
auth = call.authorization
auth_type, _, auth_data = auth.partition(" ")
authorize_func = get_auth_func(auth_type)
call.auth = authorize_func(auth_data, service, action, call)
@@ -78,7 +84,7 @@ def validate_auth(endpoint, call):
def validate_impersonation(endpoint, call):
""" Validate impersonation headers and set impersonated identity and authorization data accordingly.
:returns True if impersonating, False otherwise
:return: True if impersonating, False otherwise
"""
try:
act_as = call.act_as

View File

@@ -3,4 +3,7 @@ from apiserver.service_repo import APICall, endpoint
@endpoint("debug.ping")
def ping(call: APICall, _, __):
call.result.data = {"msg": "ClearML server"}
res = {"msg": "ClearML server"}
if call.data:
res.update(call.data)
call.result.data = res

View File

@@ -30,6 +30,16 @@ from apiserver.apimodels.events import (
GetVariantSampleRequest,
GetMetricSamplesRequest,
TaskMetric,
MultiTaskPlotsRequest,
MultiTaskMetricsRequest,
LegacyLogEventsRequest,
TaskRequest,
GetMetricsAndVariantsRequest,
ModelRequest,
LegacyMetricEventsRequest,
GetScalarMetricDataRequest,
VectorMetricsIterHistogramRequest,
LegacyMultiTaskEventsRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
@@ -37,6 +47,7 @@ from apiserver.bll.event.events_iterator import Scroll
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.bll.model import ModelBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
@@ -70,9 +81,11 @@ def _assert_task_or_model_exists(
@endpoint("events.add")
def add(call: APICall, company_id, _):
data = call.data.copy()
allow_locked = data.pop("allow_locked", False)
added, err_count, err_info = event_bll.add_events(
company_id, [data], call.worker, allow_locked=allow_locked
company_id=company_id,
identity=call.identity,
events=[data],
worker=call.worker,
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@@ -84,23 +97,23 @@ def add_batch(call: APICall, company_id, _):
raise errors.bad_request.BatchContainsNoItems()
added, err_count, err_info = event_bll.add_events(
company_id,
events,
call.worker,
allow_locked=events[0].get("allow_locked", False),
company_id=company_id,
identity=call.identity,
events=events,
worker=call.worker,
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log_v1_5(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.get_task_log")
def get_task_log_v1_5(call, company_id, request: LegacyLogEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
order = call.data.get("order") or "desc"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
order = request.order
scroll_id = request.scroll_id
batch_size = request.batch_size
events, scroll_id, total_events = event_bll.scroll_task_events(
task.get_index_company(),
task_id,
@@ -114,17 +127,17 @@ def get_task_log_v1_5(call, company_id, _):
)
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
def get_task_log_v1_7(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.get_task_log", min_version="1.7")
def get_task_log_v1_7(call, company_id, request: LegacyLogEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
order = call.data.get("order") or "desc"
order = request.order
from_ = call.data.get("from") or "head"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
scroll_id = request.scroll_id
batch_size = request.batch_size
scroll_order = "asc" if (from_ == "head") else "desc"
@@ -172,9 +185,9 @@ def get_task_log(call, company_id, request: LogEventsRequest):
)
@endpoint("events.download_task_log", required_fields=["task"])
def download_task_log(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.download_task_log")
def download_task_log(call, company_id, request: TaskRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
@@ -252,12 +265,16 @@ def download_task_log(call, company_id, _):
call.result.raw_data = generate()
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
def get_vector_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
model_events = call.data["model_events"]
@endpoint("events.get_vector_metrics_and_variants")
def get_vector_metrics_and_variants(
call, company_id, request: GetMetricsAndVariantsRequest
):
task_id = request.task
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
company_id,
task_id,
model_events=model_events,
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
@@ -266,12 +283,16 @@ def get_vector_metrics_and_variants(call, company_id, _):
)
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
def get_scalar_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
model_events = call.data["model_events"]
@endpoint("events.get_scalar_metrics_and_variants")
def get_scalar_metrics_and_variants(
call, company_id, request: GetMetricsAndVariantsRequest
):
task_id = request.task
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
company_id,
task_id,
model_events=model_events,
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
@@ -283,16 +304,19 @@ def get_scalar_metrics_and_variants(call, company_id, _):
# todo: !!! currently returning 10,000 records. should decide on a better way to control it
@endpoint(
"events.vector_metrics_iter_histogram",
required_fields=["task", "metric", "variant"],
)
def vector_metrics_iter_histogram(call, company_id, _):
task_id = call.data["task"]
model_events = call.data["model_events"]
def vector_metrics_iter_histogram(
call, company_id, request: VectorMetricsIterHistogramRequest
):
task_id = request.task
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
company_id,
task_id,
model_events=model_events,
)[0]
metric = call.data["metric"]
variant = call.data["variant"]
metric = request.metric
variant = request.variant
iterations, vectors = event_bll.get_vector_metrics_per_iter(
task_or_model.get_index_company(), task_id, metric, variant
)
@@ -322,7 +346,9 @@ def make_response(
def get_task_events(_, company_id, request: TaskEventsRequest):
task_id = request.task
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=request.model_events,
company_id,
task_id,
model_events=request.model_events,
)[0]
key = ScalarKeyEnum.iter
@@ -355,7 +381,7 @@ def get_task_events(_, company_id, request: TaskEventsRequest):
total = event_bll.events_iterator.count_task_events(
event_type=request.event_type,
company_id=task_or_model.get_index_company(),
task_id=task_id,
task_ids=[task_id],
metric_variants=metric_variants,
)
@@ -391,16 +417,18 @@ def get_task_events(_, company_id, request: TaskEventsRequest):
)
@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
def get_scalar_metric_data(call, company_id, _):
task_id = call.data["task"]
metric = call.data["metric"]
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
model_events = call.data.get("model_events", False)
@endpoint("events.get_scalar_metric_data")
def get_scalar_metric_data(call, company_id, request: GetScalarMetricDataRequest):
task_id = request.task
metric = request.metric
scroll_id = request.scroll_id
no_scroll = request.no_scroll
model_events = request.model_events
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
company_id,
task_id,
model_events=model_events,
)[0]
result = event_bll.get_task_events(
task_or_model.get_index_company(),
@@ -420,9 +448,9 @@ def get_scalar_metric_data(call, company_id, _):
)
@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
def get_task_latest_scalar_values(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.get_task_latest_scalar_values")
def get_task_latest_scalar_values(call, company_id, request: TaskRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
@@ -464,13 +492,17 @@ def scalar_metrics_iter_histogram(
def _get_task_or_model_index_companies(
company_id: str, task_ids: Sequence[str], model_events=False,
company_id: str,
task_ids: Sequence[str],
model_events=False,
) -> TaskCompanies:
"""
Returns lists of tasks grouped by company
"""
tasks_or_models = _assert_task_or_model_exists(
company_id, task_ids, model_events=model_events,
company_id,
task_ids,
model_events=model_events,
)
unique_ids = set(task_ids)
@@ -504,36 +536,53 @@ def multi_task_scalar_metrics_iter_histogram(
),
samples=request.samples,
key=request.key,
metric_variants=_get_metric_variants_from_request(request.metrics),
)
)
def _get_single_value_metrics_response(
companies: TaskCompanies, value_metrics: Mapping[str, dict]
) -> Sequence[dict]:
task_names = {
task.id: task.name for task in itertools.chain.from_iterable(companies.values())
}
return [
{"task": task_id, "task_name": task_names.get(task_id), "values": values}
for task_id, values in value_metrics.items()
]
@endpoint("events.get_task_single_value_metrics")
def get_task_single_value_metrics(
call, company_id: str, request: SingleValueMetricsRequest
):
res = event_bll.metrics.get_task_single_value_metrics(
companies=_get_task_or_model_index_companies(
company_id, request.tasks, request.model_events
),
companies = _get_task_or_model_index_companies(
company_id, request.tasks, request.model_events
)
call.result.data = dict(
tasks=[{"task": task, "values": values} for task, values in res.items()]
tasks=_get_single_value_metrics_response(
companies=companies,
value_metrics=event_bll.metrics.get_task_single_value_metrics(
companies=companies,
metric_variants=_get_metric_variants_from_request(request.metrics),
),
)
)
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
def get_multi_task_plots_v1_7(call, company_id, _):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
@endpoint("events.get_multi_task_plots")
def get_multi_task_plots_v1_7(call, company_id, request: LegacyMultiTaskEventsRequest):
task_ids = request.tasks
iters = request.iters
scroll_id = request.scroll_id
companies = _get_task_or_model_index_companies(company_id, task_ids)
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
list(companies),
task_ids,
company_id=list(companies),
task_id=task_ids,
event_type=EventType.metrics_plot,
sort=[{"iter": {"order": "desc"}}],
size=10000,
@@ -558,11 +607,12 @@ def get_multi_task_plots_v1_7(call, company_id, _):
def _get_multitask_plots(
companies: TaskCompanies,
last_iters: int,
metrics: MetricVariants = None,
last_iters_per_task_metric: bool,
request_metrics: Sequence[ApiMetrics] = None,
scroll_id=None,
no_scroll=True,
model_events=False,
) -> Tuple[dict, int, str]:
metrics = _get_metric_variants_from_request(request_metrics)
task_names = {
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
}
@@ -575,6 +625,10 @@ def _get_multitask_plots(
sort=[{"iter": {"order": "desc"}}],
scroll_id=scroll_id,
no_scroll=no_scroll,
size=config.get(
"services.events.events_retrieval.multi_plots_batch_size", 1000
),
last_iters_per_task_metric=last_iters_per_task_metric,
)
return_events = _get_top_iter_unique_events_per_task(
result.events, max_iters=last_iters, task_names=task_names
@@ -582,23 +636,18 @@ def _get_multitask_plots(
return return_events, result.total_events, result.next_scroll_id
@endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"])
def get_multi_task_plots(call, company_id, _):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
model_events = call.data.get("model_events", False)
@endpoint("events.get_multi_task_plots", min_version="1.8")
def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
companies = _get_task_or_model_index_companies(
company_id, task_ids, model_events=model_events
company_id, request.tasks, model_events=request.model_events
)
return_events, total_events, next_scroll_id = _get_multitask_plots(
companies=companies,
last_iters=iters,
scroll_id=scroll_id,
no_scroll=no_scroll,
model_events=model_events,
last_iters=request.iters,
scroll_id=request.scroll_id,
no_scroll=request.no_scroll,
last_iters_per_task_metric=request.last_iters_per_task_metric,
request_metrics=request.metrics,
)
call.result.data = dict(
plots=return_events,
@@ -608,11 +657,11 @@ def get_multi_task_plots(call, company_id, _):
)
@endpoint("events.get_task_plots", required_fields=["task"])
def get_task_plots_v1_7(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
@endpoint("events.get_task_plots")
def get_task_plots_v1_7(call, company_id, request: LegacyMetricEventsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
@@ -730,11 +779,11 @@ def task_plots(call, company_id, request: MetricEventsRequest):
)
@endpoint("events.debug_images", required_fields=["task"])
def get_debug_images_v1_7(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
@endpoint("events.debug_images")
def get_debug_images_v1_7(call, company_id, request: LegacyMetricEventsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
@@ -767,15 +816,17 @@ def get_debug_images_v1_7(call, company_id, _):
)
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
def get_debug_images_v1_8(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
model_events = call.data.get("model_events", False)
@endpoint("events.debug_images", min_version="1.8")
def get_debug_images_v1_8(call, company_id, request: LegacyMetricEventsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
model_events = request.model_events
tasks_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
company_id,
task_id,
model_events=model_events,
)[0]
result = event_bll.get_task_events(
tasks_or_model.get_index_company(),
@@ -831,7 +882,9 @@ def get_debug_images(call, company_id, request: MetricEventsRequest):
)
def get_debug_image_sample(call, company_id, request: GetVariantSampleRequest):
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events,
company_id,
request.task,
model_events=request.model_events,
)[0]
res = event_bll.debug_image_sample_history.get_sample_for_variant(
company_id=task_or_model.get_index_company(),
@@ -853,7 +906,9 @@ def get_debug_image_sample(call, company_id, request: GetVariantSampleRequest):
)
def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest):
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events,
company_id,
request.task,
model_events=request.model_events,
)[0]
res = event_bll.debug_image_sample_history.get_next_sample(
company_id=task_or_model.get_index_company(),
@@ -866,11 +921,14 @@ def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest)
@endpoint(
"events.get_plot_sample", request_data_model=GetMetricSamplesRequest,
"events.get_plot_sample",
request_data_model=GetMetricSamplesRequest,
)
def get_plot_sample(call, company_id, request: GetMetricSamplesRequest):
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events,
company_id,
request.task,
model_events=request.model_events,
)[0]
res = event_bll.plot_sample_history.get_samples_for_metric(
company_id=task_or_model.get_index_company(),
@@ -885,11 +943,14 @@ def get_plot_sample(call, company_id, request: GetMetricSamplesRequest):
@endpoint(
"events.next_plot_sample", request_data_model=NextHistorySampleRequest,
"events.next_plot_sample",
request_data_model=NextHistorySampleRequest,
)
def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events,
company_id,
request.task,
model_events=request.model_events,
)[0]
res = event_bll.plot_sample_history.get_next_sample(
company_id=task_or_model.get_index_company(),
@@ -904,7 +965,9 @@ def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
task_or_models = _assert_task_or_model_exists(
company_id, request.tasks, model_events=request.model_events,
company_id,
request.tasks,
model_events=request.model_events,
)
res = event_bll.metrics.get_task_metrics(
task_or_models[0].get_index_company(),
@@ -916,12 +979,35 @@ def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
}
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, _):
task_id = call.data["task"]
@endpoint("events.get_multi_task_metrics")
def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsRequest):
companies = _get_task_or_model_index_companies(
company_id, request.tasks, model_events=request.model_events
)
if not companies:
return {"metrics": []}
metrics = event_bll.metrics.get_multi_task_metrics(
companies=companies, event_type=request.event_type
)
res = [
{
"metric": m,
"variants": sorted(vars_),
}
for m, vars_ in metrics.items()
]
call.result.data = {"metrics": sorted(res, key=itemgetter("metric"))}
@endpoint("events.delete_for_task")
def delete_for_task(call, company_id, request: TaskRequest):
task_id = request.task
allow_locked = call.data.get("allow_locked", False)
task_bll.assert_exists(company_id, task_id, return_tasks=False)
get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
)
call.result.data = dict(
deleted=event_bll.delete_task_events(
company_id, task_id, allow_locked=allow_locked
@@ -929,9 +1015,9 @@ def delete_for_task(call, company_id, _):
)
@endpoint("events.delete_for_model", required_fields=["model"])
def delete_for_model(call: APICall, company_id: str, _):
model_id = call.data["model"]
@endpoint("events.delete_for_model")
def delete_for_model(call: APICall, company_id: str, request: ModelRequest):
model_id = request.model
allow_locked = call.data.get("allow_locked", False)
model_bll.assert_exists(company_id, model_id, return_models=False)
@@ -946,7 +1032,9 @@ def delete_for_model(call: APICall, company_id: str, _):
def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
task_id = request.task
task_bll.assert_exists(company_id, task_id, return_tasks=False)
get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
)
call.result.data = dict(
deleted=event_bll.clear_task_log(
company_id=company_id,
@@ -960,12 +1048,12 @@ def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest)
def _get_top_iter_unique_events_per_task(
events, max_iters: int, task_names: Mapping[str, str]
):
key = itemgetter("metric", "variant", "task", "iter")
key_fields = ("metric", "variant", "task")
unique_events = itertools.chain.from_iterable(
itertools.islice(group, max_iters)
for _, group in itertools.groupby(
sorted(events, key=key, reverse=True), key=key
sorted(events, key=itemgetter(*(key_fields + ("iter",))), reverse=True),
key=itemgetter(*key_fields),
)
)
@@ -1047,7 +1135,7 @@ def scalar_metrics_iter_raw(
total = event_bll.events_iterator.count_task_events(
event_type=EventType.metrics_scalar,
company_id=task_or_model.get_index_company(),
task_id=task_id,
task_ids=[task_id],
metric_variants=metric_variants,
)

View File

@@ -1,6 +1,6 @@
from datetime import datetime
from functools import partial
from typing import Sequence
from typing import Sequence, Union
from mongoengine import Q, EmbeddedDocument
@@ -21,12 +21,18 @@ from apiserver.apimodels.models import (
ModelsPublishManyRequest,
ModelsDeleteManyRequest,
ModelsGetRequest,
ModelRequest,
TaskRequest,
UpdateForTaskRequest,
UpdateModelRequest,
)
from apiserver.apimodels.tasks import UpdateTagsRequest
from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import publish_task
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.bll.util import run_batch_operation
from apiserver.config_repo import config
from apiserver.database.model import validate_id
@@ -45,6 +51,7 @@ from apiserver.database.utils import (
filter_fields,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.service_repo.auth import Identity
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
@@ -59,32 +66,37 @@ org_bll = OrgBLL()
project_bll = ProjectBLL()
@endpoint("models.get_by_id", required_fields=["model"])
def get_by_id(call: APICall, company_id, _):
model_id = call.data["model"]
def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
conform_output_tags(call, model_data)
unescape_metadata(call, model_data)
Metadata.escape_query_parameters(call)
@endpoint("models.get_by_id")
def get_by_id(call: APICall, company_id, request: ModelRequest):
model_id = request.model
call_data = Metadata.escape_query_parameters(call.data)
models = Model.get_many(
company=company_id,
query_dict=call.data,
query_dict=call_data,
query=Q(id=model_id),
allow_public=True,
)
if not models:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
"no such public or company model",
id=model_id,
company=company_id,
)
conform_output_tags(call, models[0])
unescape_metadata(call, models[0])
conform_model_data(call, models[0])
call.result.data = {"model": models[0]}
@endpoint("models.get_by_task_id", required_fields=["task"])
def get_by_task_id(call: APICall, company_id, _):
@endpoint("models.get_by_task_id")
def get_by_task_id(call: APICall, company_id, request: TaskRequest):
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis")
task_id = call.data["task"]
task_id = request.task
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["models"], **query)
@@ -99,35 +111,38 @@ def get_by_task_id(call: APICall, company_id, _):
).first()
if not model:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
"no such public or company model",
id=model_id,
company=company_id,
)
model_dict = model.to_proper_dict()
conform_output_tags(call, model_dict)
unescape_metadata(call, model_dict)
conform_model_data(call, model_dict)
call.result.data = {"model": model_dict}
@endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
conform_tag_fields(call, call.data)
process_include_subprojects(call.data)
Metadata.escape_query_parameters(call)
call_data = Metadata.escape_query_parameters(call.data)
process_include_subprojects(call_data)
ret_params = {}
models = Model.get_many_with_join(
company=company_id,
query_dict=call.data,
query_dict=call_data,
allow_public=request.allow_public,
ret_params=ret_params,
)
conform_output_tags(call, models)
unescape_metadata(call, models)
conform_model_data(call, models)
if not request.include_stats:
call.result.data = {"models": models, **ret_params}
return
model_ids = {model["id"] for model in models}
stats = ModelBLL.get_model_stats(company=company_id, model_ids=list(model_ids),)
stats = ModelBLL.get_model_stats(
company=company_id,
model_ids=list(model_ids),
)
for model in models:
model["stats"] = stats.get(model["id"])
@@ -138,29 +153,28 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
@endpoint("models.get_by_id_ex", required_fields=["id"])
def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
Metadata.escape_query_parameters(call)
call_data = Metadata.escape_query_parameters(call.data)
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
company=company_id, query_dict=call_data, allow_public=True
)
conform_output_tags(call, models)
unescape_metadata(call, models)
conform_model_data(call, models)
call.result.data = {"models": models}
@endpoint("models.get_all", required_fields=[])
@endpoint("models.get_all")
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
Metadata.escape_query_parameters(call)
call_data = Metadata.escape_query_parameters(call.data)
process_include_subprojects(call_data)
ret_params = {}
models = Model.get_many(
company=company_id,
parameters=call.data,
query_dict=call.data,
parameters=call_data,
query_dict=call_data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
unescape_metadata(call, models)
conform_model_data(call, models)
call.result.data = {"models": models, **ret_params}
@@ -183,7 +197,7 @@ create_fields = {
"project": Project,
"parent": Model,
"framework": None,
"design": None,
"design": dict,
"labels": dict,
"ready": None,
"metadata": list,
@@ -212,7 +226,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
org_bll.update_tags(
company,
Tags.Model,
project=project,
projects=[project],
tags=fields.get("tags"),
system_tags=fields.get("system_tags"),
)
@@ -220,32 +234,33 @@ def _update_cached_tags(company: str, project: str, fields: dict):
def _reset_cached_tags(company: str, projects: Sequence[str]):
org_bll.reset_tags(
company, Tags.Model, projects=projects,
company,
Tags.Model,
projects=projects,
)
@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call: APICall, company_id, _):
@endpoint("models.update_for_task")
def update_for_task(call: APICall, company_id, request: UpdateForTaskRequest):
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")
task_id = call.data["task"]
uri = call.data.get("uri")
iteration = call.data.get("iteration")
override_model_id = call.data.get("override_model_id")
task_id = request.task
uri = request.uri
iteration = request.iteration
override_model_id = request.override_model_id
if not (uri or override_model_id) or (uri and override_model_id):
raise errors.bad_request.MissingRequiredFields(
"exactly one field is required", fields=("uri", "override_model_id")
)
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(
id=task_id,
company=company_id,
_only=["models", "execution", "name", "status", "project"],
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=call.identity,
only=("models", "execution", "name", "status", "project"),
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
if task.status not in allowed_states:
@@ -283,6 +298,8 @@ def update_for_task(call: APICall, company_id, _):
id=database.utils.id(),
created=now,
last_update=now,
last_change=now,
last_changed_by=call.identity.user,
user=call.identity.user,
company=company_id,
project=task.project,
@@ -301,6 +318,7 @@ def update_for_task(call: APICall, company_id, _):
TaskBLL.update_statistics(
task_id=task_id,
company_id=company_id,
user_id=call.identity.user,
last_iteration_max=iteration,
models__output=[
ModelItem(
@@ -320,7 +338,6 @@ def update_for_task(call: APICall, company_id, _):
response_data_model=CreateModelResponse,
)
def create(call: APICall, company_id, req_model: CreateModelRequest):
if req_model.public:
company_id = ""
@@ -331,7 +348,7 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(company_id, req_data)
validate_task(company_id, call.identity, req_data)
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields, validate=True)
@@ -345,6 +362,8 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
company=company_id,
created=now,
last_update=now,
last_change=now,
last_changed_by=call.identity.user,
**fields,
)
model.save()
@@ -359,7 +378,7 @@ def prepare_update_fields(call, company_id, fields: dict):
# clear UI cache if URI is provided (model updated)
fields["ui_cache"] = fields.pop("ui_cache", {})
if "task" in fields:
validate_task(company_id, fields)
validate_task(company_id, call.identity, fields)
if "labels" in fields:
labels = fields["labels"]
@@ -389,13 +408,16 @@ def prepare_update_fields(call, company_id, fields: dict):
return fields
def validate_task(company_id, fields: dict):
Task.get_for_writing(company=company_id, id=fields["task"], _only=["id"])
def validate_task(company_id: str, identity: Identity, fields: dict):
task_id = fields["task"]
get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=identity, only=("id",)
)
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
def edit(call: APICall, company_id, _):
model_id = call.data["model"]
@endpoint("models.edit", response_data_model=UpdateResponse)
def edit(call: APICall, company_id, request: UpdateModelRequest):
model_id = request.model
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
@@ -410,16 +432,24 @@ def edit(call: APICall, company_id, _):
d.update(value)
fields[key] = d
iteration = call.data.get("iteration")
iteration = request.iteration
task_id = model.task or fields.get("task")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
task_id=task_id,
company_id=company_id,
user_id=call.identity.user,
last_iteration_max=iteration,
)
if fields:
now = datetime.utcnow()
fields.update(
last_change=now,
last_changed_by=call.identity.user,
)
if any(uf in fields for uf in last_update_fields):
fields.update(last_update=datetime.utcnow())
fields.update(last_update=now)
updated = model.update(upsert=False, **fields)
if updated:
@@ -428,31 +458,38 @@ def edit(call: APICall, company_id, _):
_reset_cached_tags(company_id, projects=[new_project, model.project])
else:
_update_cached_tags(company_id, project=model.project, fields=fields)
conform_output_tags(call, fields)
unescape_metadata(call, fields)
conform_model_data(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
def _update_model(call: APICall, company_id, model_id=None):
model_id = model_id or call.data["model"]
def _update_model(call: APICall, company_id, model_id):
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
data = prepare_update_fields(call, company_id, call.data)
task_id = data.get("task")
iteration = data.get("iteration")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
task_id=task_id,
company_id=company_id,
user_id=call.identity.user,
last_iteration_max=iteration,
)
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
now = datetime.utcnow()
updated_count, updated_fields = Model.safe_update(
company_id,
model.id,
data,
injected_update=dict(
last_change=now,
last_changed_by=call.identity.user,
),
)
if updated_count:
if any(uf in updated_fields for uf in last_update_fields):
model.update(upsert=False, last_update=datetime.utcnow())
model.update(upsert=False, last_update=now)
new_project = updated_fields.get("project", model.project)
if new_project != model.project:
@@ -461,16 +498,13 @@ def _update_model(call: APICall, company_id, model_id=None):
_update_cached_tags(
company_id, project=model.project, fields=updated_fields
)
conform_output_tags(call, updated_fields)
unescape_metadata(call, updated_fields)
conform_model_data(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@endpoint(
"models.update", required_fields=["model"], response_data_model=UpdateResponse
)
def update(call, company_id, _):
call.result.data_model = _update_model(call, company_id)
@endpoint("models.update", response_data_model=UpdateResponse)
def update(call, company_id, request: UpdateModelRequest):
call.result.data_model = _update_model(call, company_id, model_id=request.model)
@endpoint(
@@ -482,7 +516,7 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
updated, published_task = ModelBLL.publish_model(
model_id=request.model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None,
)
@@ -501,7 +535,7 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
func=partial(
ModelBLL.publish_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None,
),
@@ -524,7 +558,11 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
@endpoint("models.delete", request_data_model=DeleteModelRequest)
def delete(call: APICall, company_id, request: DeleteModelRequest):
del_count, model = ModelBLL.delete_model(
model_id=request.model, company_id=company_id, force=request.force
model_id=request.model,
company_id=company_id,
user_id=call.identity.user,
force=request.force,
delete_external_artifacts=request.delete_external_artifacts,
)
if del_count:
_reset_cached_tags(
@@ -541,7 +579,13 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
)
def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
results, failures = run_batch_operation(
func=partial(ModelBLL.delete_model, company_id=company_id, force=request.force),
func=partial(
ModelBLL.delete_model,
company_id=company_id,
user_id=call.identity.user,
force=request.force,
delete_external_artifacts=request.delete_external_artifacts,
),
ids=request.ids,
)
@@ -565,7 +609,10 @@ def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
)
def archive_many(call: APICall, company_id, request: BatchRequest):
results, failures = run_batch_operation(
func=partial(ModelBLL.archive_model, company_id=company_id), ids=request.ids,
func=partial(
ModelBLL.archive_model, company_id=company_id, user_id=call.identity.user
),
ids=request.ids,
)
call.result.data_model = BatchResponse(
succeeded=[dict(id=_id, archived=bool(archived)) for _id, archived in results],
@@ -580,7 +627,10 @@ def archive_many(call: APICall, company_id, request: BatchRequest):
)
def unarchive_many(call: APICall, company_id, request: BatchRequest):
results, failures = run_batch_operation(
func=partial(ModelBLL.unarchive_model, company_id=company_id), ids=request.ids,
func=partial(
ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user
),
ids=request.ids,
)
call.result.data_model = BatchResponse(
succeeded=[
@@ -593,7 +643,11 @@ def unarchive_many(call: APICall, company_id, request: BatchRequest):
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Model.set_public(
company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidModelId,
enabled=True,
)
@@ -602,7 +656,11 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Model.set_public(
company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidModelId,
enabled=False,
)
@@ -625,30 +683,51 @@ def move(call: APICall, company_id: str, request: MoveRequest):
}
@endpoint("models.update_tags")
def update_tags(_, company_id: str, request: UpdateTagsRequest):
return {
"updated": org_bll.edit_entity_tags(
company_id=company_id,
entity_cls=Model,
entity_ids=request.ids,
add_tags=request.add_tags,
remove_tags=request.remove_tags,
)
}
@endpoint("models.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
call: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
model_id = request.model
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
now = datetime.utcnow()
return {
"updated": Metadata.edit_metadata(
model,
items=request.metadata,
replace_metadata=request.replace_metadata,
last_update=datetime.utcnow(),
last_update=now,
last_change=now,
last_changed_by=call.identity.user,
)
}
@endpoint("models.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataRequest):
model_id = request.model
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
now = datetime.utcnow()
return {
"updated": Metadata.delete_metadata(
model, keys=request.keys, last_update=datetime.utcnow()
model,
keys=request.keys,
last_update=now,
last_change=now,
last_changed_by=call.identity.user,
)
}

View File

@@ -1,21 +1,45 @@
import csv
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from io import StringIO
from operator import itemgetter
from typing import Mapping, Type
from typing import Mapping, Type, Sequence, Optional, Callable, Hashable
from flask import stream_with_context
from mongoengine import Q
from apiserver.apimodels.organization import TagsRequest, EntitiesCountRequest
from apiserver.apierrors import errors
from apiserver.apimodels.organization import (
TagsRequest,
EntitiesCountRequest,
DownloadForGetAllRequest,
EntityType,
PrepareDownloadForGetAllRequest,
)
from apiserver.bll.model import Metadata
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.config_repo import config
from apiserver.database.model import User, AttributedDocument, EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskType
from apiserver.redis_manager import redman
from apiserver.service_repo import endpoint, APICall
from apiserver.services.models import conform_model_data
from apiserver.services.tasks import (
escape_execution_parameters,
_hidden_query,
conform_task_data,
)
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get
org_bll = OrgBLL()
project_bll = ProjectBLL()
redis = redman.connection("apiserver")
conf = config.get("services.organization")
@endpoint("organization.get_tags", request_data_model=TagsRequest)
@@ -24,7 +48,10 @@ def get_tags(call: APICall, company, request: TagsRequest):
ret = defaultdict(set)
for entity in Tags.Model, Tags.Task:
tags = org_bll.get_tags(
company, entity, include_system=request.include_system, filter_=filter_dict,
company,
entity,
include_system=request.include_system,
filter_=filter_dict,
)
for field, vals in tags.items():
ret[field] |= vals
@@ -105,3 +132,203 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
)
call.result.data = ret
def _get_download_getter_fn(
company: str,
call: APICall,
call_data: dict,
allow_public: bool,
entity_type: EntityType,
) -> Optional[Callable[[int, int], Sequence[dict]]]:
def get_task_data() -> Sequence[dict]:
tasks = Task.get_many_with_join(
company=company,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=allow_public,
)
conform_task_data(call, tasks)
return tasks
def get_model_data() -> Sequence[dict]:
models = Model.get_many_with_join(
company=company,
query_dict=call_data,
allow_public=allow_public,
)
conform_model_data(call, models)
return models
if entity_type == EntityType.task:
call_data = escape_execution_parameters(call_data)
get_fn = get_task_data
elif entity_type == EntityType.model:
call_data = Metadata.escape_query_parameters(call_data)
get_fn = get_model_data
else:
raise errors.bad_request.ValidationError(
f"Unsupported entity type: {str(entity_type)}"
)
def getter(page: int, page_size: int) -> Sequence[dict]:
call_data.pop("scroll_id", None)
call_data.pop("start", None)
call_data.pop("size", None)
call_data.pop("refresh_scroll", None)
call_data["page"] = page
call_data["page_size"] = page_size
return get_fn()
return getter
@endpoint("organization.prepare_download_for_get_all")
def prepare_download_for_get_all(
call: APICall, company: str, request: PrepareDownloadForGetAllRequest
):
# validate input params
field_names = set()
for fm in request.field_mappings:
name = fm.name or fm.field
if name in field_names:
raise errors.bad_request.ValidationError(
f"Field_name appears more than once in field_mappings: {str(name)}"
)
field_names.add(name)
if fm.values:
value_keys = set()
for v in fm.values:
if v.key in value_keys:
raise errors.bad_request.ValidationError(
f"Value key appears more than once in field_mappings: {str(v.key)}"
)
value_keys.add(v.key)
getter = _get_download_getter_fn(
company,
call,
call_data=call.data.copy(),
allow_public=request.allow_public,
entity_type=request.entity_type,
)
# retrieve one element just to make sure that there are no issues with the call parameters
if getter:
getter(0, 1)
redis.setex(
f"get_all_download_{call.id}",
int(conf.get("download.redis_timeout_sec", 300)),
json.dumps(call.data),
)
call.result.data = dict(prepare_id=call.id)
@endpoint("organization.download_for_get_all")
def download_for_get_all(call: APICall, company, request: DownloadForGetAllRequest):
request_data = redis.get(f"get_all_download_{request.prepare_id}")
if not request_data:
raise errors.bad_request.InvalidId(
f"prepare ID not found", prepare_id=request.prepare_id
)
try:
call_data = json.loads(request_data)
request = PrepareDownloadForGetAllRequest(**call_data)
except Exception as ex:
raise errors.server_error.DataError("failed parsing prepared data", ex=ex)
class SingleLine:
@staticmethod
def write(line: str) -> str:
return line
def generate():
field_mappings = {
mapping.get("name", mapping["field"]): {
"field_path": mapping["field"].split("."),
"values": {
v.get("key"): v.get("value")
for v in (mapping.get("values") or [])
},
}
for mapping in call_data.get("field_mappings", [])
}
get_fn = _get_download_getter_fn(
company,
call,
call_data=call_data,
allow_public=request.allow_public,
entity_type=request.entity_type,
)
if not get_fn:
yield csv.writer(SingleLine()).writerow(field_mappings)
return
def get_entity_field_as_str(
data: dict, field_path: Sequence[str], values: Mapping
) -> str:
val = nested_get(data, field_path, "")
if isinstance(val, dict):
val = val.get("id", "")
if values and isinstance(val, Hashable):
val = values.get(val, val)
return str(val)
def get_projected_fields(data: dict) -> Sequence[str]:
return [
get_entity_field_as_str(
data, field_path=m["field_path"], values=m["values"]
)
for m in field_mappings.values()
]
with ThreadPoolExecutor(1) as pool:
page = 0
page_size = int(conf.get("download.batch_size", 500))
items_left = int(conf.get("download.max_download_items", 1000))
future = pool.submit(get_fn, page, min(page_size, items_left))
while items_left > 0:
result = future.result()
if not result:
break
items_left -= len(result)
page += 1
if items_left > 0:
future = pool.submit(get_fn, page, min(page_size, items_left))
with StringIO() as fp:
writer = csv.writer(fp)
if page == 1:
fp.write("\ufeff") # utf-8 signature
writer.writerow(field_mappings)
writer.writerows(get_projected_fields(r) for r in result)
yield fp.getvalue()
if page == 0:
yield csv.writer(SingleLine()).writerow(field_mappings)
def get_project_name() -> Optional[str]:
projects = call_data.get("project")
if not projects or not isinstance(projects, (list, str)):
return
if isinstance(projects, list):
if len(projects) > 1:
return
projects = projects[0]
if projects is None:
return "root"
project: Project = Project.objects(id=projects).only("basename").first()
if not project:
return
return project.basename[: conf.get("download.max_project_name_length", 60)]
call.result.filename = "-".join(
filter(None, ("clearml", get_project_name(), f"{request.entity_type}s.csv"))
)
call.result.content_type = "text/csv"
call.result.raw_data = stream_with_context(generate())

View File

@@ -1,17 +1,28 @@
import re
from functools import partial
from apiserver.apimodels.pipelines import StartPipelineResponse, StartPipelineRequest
import attr
from apiserver.apierrors.errors.bad_request import CannotRemoveAllRuns
from apiserver.apimodels.pipelines import (
StartPipelineRequest,
DeleteRunsRequest,
)
from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL
from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import enqueue_task
from apiserver.bll.task.task_operations import enqueue_task, delete_task
from apiserver.bll.util import run_batch_operation
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.database.model.task.task import Task, TaskType
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities.dicts import nested_get
org_bll = OrgBLL()
project_bll = ProjectBLL()
task_bll = TaskBLL()
queue_bll = QueueBLL()
def _update_task_name(task: Task):
@@ -31,9 +42,46 @@ def _update_task_name(task: Task):
task.update(name=new_name)
@endpoint(
"pipelines.start_pipeline", response_data_model=StartPipelineResponse,
)
@endpoint("pipelines.delete_runs")
def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
existing_runs = set(
Task.objects(project=request.project, type=TaskType.controller).scalar("id")
)
if not existing_runs.difference(request.ids):
raise CannotRemoveAllRuns(project=request.project)
# make sure that only controller tasks are deleted
ids = existing_runs.intersection(request.ids)
if not ids:
return dict(succeeded=[], failed=[])
results, failures = run_batch_operation(
func=partial(
delete_task,
company_id=company_id,
identity=call.identity,
move_to_trash=False,
force=True,
return_file_urls=False,
delete_output_models=True,
status_message="",
status_reason="Pipeline run deleted",
delete_external_artifacts=True,
),
ids=list(ids),
)
succeeded = []
if results:
for _id, (deleted, task, cleanup_res) in results:
succeeded.append(
dict(id=_id, deleted=bool(deleted), **attr.asdict(cleanup_res))
)
call.result.data = dict(succeeded=succeeded, failed=failures)
@endpoint("pipelines.start_pipeline")
def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
hyperparams = None
if request.args:
@@ -60,10 +108,19 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
queued, res = enqueue_task(
task_id=task.id,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
queue_id=request.queue,
status_message="Starting pipeline",
status_reason="",
)
extra = {}
if request.verify_watched_queue and queued:
res_queue = nested_get(res, ("fields", "execution.queue"))
if res_queue:
extra["queue_watched"] = queue_bll.check_for_workers(company_id, res_queue)
return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))
call.result.data = dict(
pipeline=task.id,
enqueued=bool(queued),
**extra,
)

View File

@@ -7,18 +7,20 @@ from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import (
DeleteRequest,
GetParamsRequest,
ProjectTagsRequest,
ProjectTaskParentsRequest,
ProjectHyperparamValuesRequest,
ProjectsGetRequest,
DeleteRequest,
MoveRequest,
MergeRequest,
ProjectOrNoneRequest,
ProjectRequest,
ProjectModelMetadataValuesRequest,
ProjectChildrenType,
GetUniqueMetricsRequest,
ProjectUserNamesRequest,
EntityTypeEnum,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, ProjectQueries
@@ -29,8 +31,9 @@ from apiserver.bll.project.project_cleanup import (
)
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import TaskType
from apiserver.database.model.task.task import TaskType, Task
from apiserver.database.utils import (
parse_from_call,
get_company_or_none_constraint,
@@ -56,13 +59,12 @@ create_fields = {
}
@endpoint("projects.get_by_id", required_fields=["project"])
def get_by_id(call):
assert isinstance(call, APICall)
project_id = call.data["project"]
@endpoint("projects.get_by_id")
def get_by_id(call: APICall, company: str, request: ProjectRequest):
project_id = request.project
with translate_errors_context():
query = Q(id=project_id) & get_company_or_none_constraint(call.identity.company)
query = Q(id=project_id) & get_company_or_none_constraint(company)
project = Project.objects(query).first()
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
@@ -99,19 +101,37 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
data["parent"] = [None]
def _get_project_stats_filter(request: ProjectsGetRequest) -> Tuple[Optional[dict], bool]:
def _get_project_stats_filter(
request: ProjectsGetRequest,
) -> Tuple[Optional[dict], bool]:
if request.include_stats_filter or not request.children_type:
return request.include_stats_filter, request.search_hidden
if request.children_tags_filter:
stats_filter = {"tags": request.children_tags_filter}
elif request.children_tags:
stats_filter = {"tags": request.children_tags}
else:
stats_filter = {}
if request.children_type == ProjectChildrenType.pipeline:
return {"system_tags": [pipeline_tag], "type": [TaskType.controller]}, True
return (
{
**stats_filter,
"system_tags": [pipeline_tag],
"type": [TaskType.controller],
},
True,
)
if request.children_type == ProjectChildrenType.report:
return {"system_tags": [reports_tag], "type": [TaskType.report]}, True
return request.include_stats_filter, request.search_hidden
return (
{**stats_filter, "system_tags": [reports_tag], "type": [TaskType.report]},
True,
)
return stats_filter, request.search_hidden
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
@endpoint("projects.get_all_ex")
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
data = call.data
conform_tag_fields(call, data)
@@ -126,8 +146,10 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
requested_ids = data.get("id")
if isinstance(requested_ids, str):
requested_ids = [requested_ids]
_adjust_search_parameters(
data, shallow_search=request.shallow_search,
data,
shallow_search=request.shallow_search,
)
selected_project_ids = None
if request.active_users or request.children_type:
@@ -137,6 +159,8 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
project_ids=requested_ids,
allow_public=allow_public,
children_type=request.children_type,
children_tags=request.children_tags,
children_tags_filter=request.children_tags_filter,
)
if not ids:
return {"projects": []}
@@ -174,19 +198,21 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
conform_output_tags(call, projects)
project_ids = list({project["id"] for project in projects})
stats_filter, stats_search_hidden = _get_project_stats_filter(request)
if request.check_own_contents:
if request.children_type == ProjectChildrenType.dataset:
contents = project_bll.calc_own_datasets(
company=company_id,
project_ids=project_ids,
filter_=request.include_stats_filter,
filter_=stats_filter,
users=request.active_users,
)
else:
contents = project_bll.calc_own_contents(
company=company_id,
project_ids=project_ids,
filter_=_get_project_stats_filter(request)[0],
filter_=stats_filter,
specific_state=request.stats_for_state,
users=request.active_users,
)
@@ -199,19 +225,18 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
company=company_id,
project_ids=project_ids,
include_children=request.stats_with_children,
filter_=request.include_stats_filter,
filter_=stats_filter,
users=request.active_users,
selected_project_ids=selected_project_ids,
)
else:
filter_, search_hidden = _get_project_stats_filter(request)
stats, children = project_bll.get_project_stats(
company=company_id,
project_ids=project_ids,
specific_state=request.stats_for_state,
include_children=request.stats_with_children,
search_hidden=search_hidden,
filter_=filter_,
search_hidden=stats_search_hidden,
filter_=stats_filter,
users=request.active_users,
selected_project_ids=selected_project_ids,
)
@@ -222,7 +247,9 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
if request.include_dataset_stats:
dataset_stats = project_bll.get_dataset_stats(
company=company_id, project_ids=project_ids, users=request.active_users,
company=company_id,
project_ids=project_ids,
users=request.active_users,
)
for project in projects:
project["dataset_stats"] = dataset_stats.get(project["id"])
@@ -231,15 +258,16 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
@endpoint("projects.get_all")
def get_all(call: APICall):
def get_all(call: APICall, company: str, _):
data = call.data
conform_tag_fields(call, data)
_adjust_search_parameters(
data, shallow_search=data.get("shallow_search", False),
data,
shallow_search=data.get("shallow_search", False),
)
ret_params = {}
projects = Project.get_many(
company=call.identity.company,
company=company,
query_dict=data,
query=_hidden_query(
search_hidden=data.get("search_hidden"), ids=data.get("id")
@@ -253,9 +281,11 @@ def get_all(call: APICall):
@endpoint(
"projects.create", required_fields=["name"], response_data_model=IdResponse,
"projects.create",
required_fields=["name"],
response_data_model=IdResponse,
)
def create(call: APICall):
def create(call: APICall, company: str, _):
identity = call.identity
with translate_errors_context():
@@ -264,15 +294,15 @@ def create(call: APICall):
return IdResponse(
id=ProjectBLL.create(
user=identity.user, company=identity.company, **fields,
user=identity.user,
company=company,
**fields,
)
)
@endpoint(
"projects.update", required_fields=["project"], response_data_model=UpdateResponse
)
def update(call: APICall):
@endpoint("projects.update", response_data_model=UpdateResponse)
def update(call: APICall, company: str, request: ProjectRequest):
"""
update
@@ -285,9 +315,7 @@ def update(call: APICall):
call.data, create_fields, Project.get_fields(), discard_none_values=False
)
conform_tag_fields(call, fields, validate=True)
updated = ProjectBLL.update(
company=call.identity.company, project_id=call.data["project"], **fields
)
updated = ProjectBLL.update(company=company, project_id=request.project, **fields)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@@ -339,28 +367,30 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
project_id=request.project,
force=request.force,
delete_contents=request.delete_contents,
delete_external_artifacts=request.delete_external_artifacts,
)
_reset_cached_tags(company_id, projects=list(affected_projects))
call.result.data = {**attr.asdict(res)}
@endpoint(
"projects.get_unique_metric_variants", request_data_model=ProjectOrNoneRequest
"projects.get_unique_metric_variants", request_data_model=GetUniqueMetricsRequest
)
def get_unique_metric_variants(
call: APICall, company_id: str, request: ProjectOrNoneRequest
call: APICall, company_id: str, request: GetUniqueMetricsRequest
):
metrics = project_queries.get_unique_metric_variants(
company_id,
[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
ids=request.ids,
model_metrics=request.model_metrics,
)
call.result.data = {"metrics": metrics}
@endpoint("projects.get_model_metadata_keys",)
@endpoint("projects.get_model_metadata_keys")
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, keys = project_queries.get_model_metadata_keys(
company_id,
@@ -387,6 +417,8 @@ def get_model_metadata_values(
key=request.key,
include_subprojects=request.include_subprojects,
allow_public=request.allow_public,
page=request.page,
page_size=request.page_size,
)
call.result.data = {
"total": total,
@@ -400,7 +432,6 @@ def get_model_metadata_values(
request_data_model=GetParamsRequest,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, parameters = project_queries.get_aggregated_project_parameters(
company_id,
project_ids=[request.project] if request.project else None,
@@ -431,6 +462,9 @@ def get_hyperparam_values(
name=request.name,
include_subprojects=request.include_subprojects,
allow_public=request.allow_public,
pattern=request.pattern,
page=request.page,
page_size=request.page_size,
)
call.result.data = {
"total": total,
@@ -482,7 +516,11 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Project.set_public(
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=True
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidProjectId,
enabled=True,
)
@@ -491,7 +529,11 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Project.set_public(
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidProjectId,
enabled=False,
)
@@ -504,10 +546,23 @@ def get_task_parents(
call: APICall, company_id: str, request: ProjectTaskParentsRequest
):
call.result.data = {
"parents": project_bll.get_task_parents(
"parents": ProjectBLL.get_task_parents(
company_id,
projects=request.projects,
include_subprojects=request.include_subprojects,
state=request.tasks_state,
name=request.task_name,
)
}
@endpoint("projects.get_user_names")
def get_user_names(call: APICall, company_id: str, request: ProjectUserNamesRequest):
call.result.data = {
"users": ProjectBLL.get_entity_users(
company_id,
entity_cls=Model if request.entity == EntityTypeEnum.model else Task,
projects=request.projects,
include_subprojects=request.include_subprojects,
)
}

View File

@@ -1,3 +1,5 @@
from typing import Union, Sequence
from mongoengine import Q
from apiserver.apimodels.base import UpdateResponse
@@ -21,6 +23,7 @@ from apiserver.apimodels.queues import (
)
from apiserver.bll.model import Metadata
from apiserver.bll.queue import QueueBLL
from apiserver.bll.queue.queue_bll import MOVE_FIRST, MOVE_LAST
from apiserver.bll.workers import WorkerBLL
from apiserver.config_repo import config
from apiserver.database.model.task.task import Task
@@ -38,14 +41,18 @@ worker_bll = WorkerBLL()
queue_bll = QueueBLL(worker_bll)
def conform_queue_data(call: APICall, queue_data: Union[Sequence[dict], dict]):
conform_output_tags(call, queue_data)
unescape_metadata(call, queue_data)
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest)
def get_by_id(call: APICall, company_id, request: GetByIdRequest):
queue = queue_bll.get_by_id(
company_id, request.queue, max_task_entries=request.max_task_entries
)
queue_dict = queue.to_proper_dict()
conform_output_tags(call, queue_dict)
unescape_metadata(call, queue_dict)
conform_queue_data(call, queue_dict)
call.result.data = {"queue": queue_dict}
@@ -76,16 +83,15 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest):
conform_tag_fields(call, call.data)
ret_params = {}
Metadata.escape_query_parameters(call)
call_data = Metadata.escape_query_parameters(call.data)
queues = queue_bll.get_queue_infos(
company_id=company,
query_dict=call.data,
query=_hidden_query(call.data),
query_dict=call_data,
query=_hidden_query(call_data),
max_task_entries=request.max_task_entries,
ret_params=ret_params,
)
conform_output_tags(call, queues)
unescape_metadata(call, queues)
conform_queue_data(call, queues)
call.result.data = {"queues": queues, **ret_params}
@@ -93,16 +99,15 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest):
def get_all(call: APICall, company: str, request: GetAllRequest):
conform_tag_fields(call, call.data)
ret_params = {}
Metadata.escape_query_parameters(call)
call_data = Metadata.escape_query_parameters(call.data)
queues = queue_bll.get_all(
company_id=company,
query_dict=call.data,
query=_hidden_query(call.data),
query_dict=call_data,
query=_hidden_query(call_data),
max_task_entries=request.max_task_entries,
ret_params=ret_params,
)
conform_output_tags(call, queues)
unescape_metadata(call, queues)
conform_queue_data(call, queues)
call.result.data = {"queues": queues, **ret_params}
@@ -134,8 +139,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
updated, fields = queue_bll.update(
company_id=company_id, queue_id=req_model.queue, **data
)
conform_output_tags(call, fields)
unescape_metadata(call, fields)
conform_queue_data(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@@ -167,7 +171,7 @@ def get_next_task(call: APICall, company_id, request: GetNextTaskRequest):
if entry:
data = {"entry": entry.to_proper_dict()}
if request.get_task_info:
task = Task.objects(id=entry.task).first()
task = Task.objects(id=entry.task).only("company", "user").first()
if task:
data["task_info"] = {"company": task.company, "user": task.user}
@@ -195,7 +199,7 @@ def move_task_forward(call: APICall, company_id, req_model: MoveTaskRequest):
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
pos_func=lambda p: max(0, p - req_model.count),
move_count=-req_model.count,
)
)
@@ -212,7 +216,7 @@ def move_task_backward(call: APICall, company_id, req_model: MoveTaskRequest):
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
pos_func=lambda p: max(0, p + req_model.count),
move_count=req_model.count,
)
)
@@ -229,7 +233,7 @@ def move_task_to_front(call: APICall, company_id, req_model: TaskRequest):
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
pos_func=lambda p: 0,
move_count=MOVE_FIRST,
)
)
@@ -246,7 +250,7 @@ def move_task_to_back(call: APICall, company_id, req_model: TaskRequest):
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
pos_func=lambda p: -1,
move_count=MOVE_LAST,
)
)

View File

@@ -3,6 +3,8 @@ from datetime import datetime
from itertools import chain
from typing import Sequence
from mongoengine import Q
from apiserver.apimodels.reports import (
CreateReportRequest,
UpdateReportRequest,
@@ -17,6 +19,10 @@ from apiserver.apimodels.reports import (
from apiserver.apierrors import errors
from apiserver.apimodels.base import UpdateResponse
from apiserver.bll.project.project_bll import reports_project_name, reports_tag
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.database.model.model import Model
from apiserver.service_repo.auth import Identity
from apiserver.services.models import conform_model_data
from apiserver.services.utils import process_include_subprojects, sort_tags_response
from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL
@@ -31,11 +37,12 @@ from apiserver.services.events import (
_get_metrics_response,
_get_metric_variants_from_request,
_get_multitask_plots,
_get_single_value_metrics_response,
)
from apiserver.services.tasks import (
escape_execution_parameters,
_hidden_query,
unprepare_from_saved,
conform_task_data,
)
org_bll = OrgBLL()
@@ -52,15 +59,15 @@ update_fields = {
}
def _assert_report(company_id, task_id, only_fields=None, requires_write_access=True):
def _assert_report(company_id: str, task_id: str, identity: Identity, only_fields=None):
if only_fields and "type" not in only_fields:
only_fields += ("type",)
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=identity,
only=only_fields,
requires_write_access=requires_write_access,
)
if task.type != TaskType.report:
raise errors.bad_request.OperationSupportedOnReportsOnly(id=task_id)
@@ -71,7 +78,10 @@ def _assert_report(company_id, task_id, only_fields=None, requires_write_access=
@endpoint("reports.update", response_data_model=UpdateResponse)
def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
task = _assert_report(
task_id=request.task, company_id=company_id, only_fields=("status",),
task_id=request.task,
company_id=company_id,
identity=call.identity,
only_fields=("status",),
)
partial_update_dict = {
@@ -168,9 +178,21 @@ def _delete_reports_project_if_empty(project_id):
def get_all_ex(call: APICall, company_id, request: GetAllRequest):
call_data = call.data
call_data["type"] = TaskType.report
call_data["include_subprojects"] = True
process_include_subprojects(call_data)
# bring projects one level down in case not the .reports project was passed
if "project" in call_data:
project_ids = call_data["project"]
if not isinstance(project_ids, list):
project_ids = [project_ids]
query = Q(parent__in=project_ids) | Q(id__in=project_ids)
project_ids = Project.objects(query & Q(basename=reports_project_name)).scalar(
"id"
)
if not project_ids:
return {"tasks": []}
call_data["project"] = list(project_ids)
ret_params = {}
tasks = Task.get_many_with_join(
company=company_id,
@@ -178,7 +200,7 @@ def get_all_ex(call: APICall, company_id, request: GetAllRequest):
allow_public=request.allow_public,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
conform_task_data(call, tasks)
call.result.data = {"tasks": tasks, **ret_params}
@@ -198,26 +220,38 @@ def _get_task_metrics_from_request(
@endpoint("reports.get_task_data")
def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
call_data = escape_execution_parameters(call)
if request.model_events:
entity_cls = Model
conform_data = conform_model_data
else:
entity_cls = Task
conform_data = conform_task_data
call_data = escape_execution_parameters(call.data)
process_include_subprojects(call_data)
ret_params = {}
tasks = Task.get_many_with_join(
tasks = entity_cls.get_many_with_join(
company=company_id,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=request.allow_public,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
conform_data(call, tasks)
res = {"tasks": tasks, **ret_params}
if not (
request.debug_images or request.plots or request.scalar_metrics_iter_histogram
request.debug_images
or request.plots
or request.scalar_metrics_iter_histogram
or request.single_value_metrics
):
return res
task_ids = [task["id"] for task in tasks]
companies = _get_task_or_model_index_companies(company_id, task_ids=task_ids)
companies = _get_task_or_model_index_companies(
company_id, task_ids=task_ids, model_events=request.model_events
)
if request.debug_images:
result = event_bll.debug_images_iterator.get_task_events(
companies={
@@ -234,7 +268,8 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
res["plots"] = _get_multitask_plots(
companies=companies,
last_iters=request.plots.iters,
metrics=_get_metric_variants_from_request(request.plots.metrics),
request_metrics=request.plots.metrics,
last_iters_per_task_metric=request.plots.last_iters_per_task_metric,
)[0]
if request.scalar_metrics_iter_histogram:
@@ -249,6 +284,14 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
),
)
if request.single_value_metrics:
res["single_value_metrics"] = _get_single_value_metrics_response(
companies=companies,
value_metrics=event_bll.metrics.get_task_single_value_metrics(
companies=companies
),
)
call.result.data = res
@@ -260,7 +303,10 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
)
task = _assert_report(
company_id=company_id, task_id=request.task, only_fields=("project",),
company_id=company_id,
task_id=request.task,
identity=call.identity,
only_fields=("project",),
)
user_id = call.identity.user
project_name = request.project_name
@@ -291,10 +337,13 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
@endpoint(
"reports.publish", response_data_model=UpdateResponse,
"reports.publish",
response_data_model=UpdateResponse,
)
def publish(call: APICall, company_id, request: PublishReportRequest):
task = _assert_report(company_id=company_id, task_id=request.task)
task = _assert_report(
company_id=company_id, task_id=request.task, identity=call.identity
)
updates = ChangeStatusRequest(
task=task,
new_status=TaskStatus.published,
@@ -309,7 +358,9 @@ def publish(call: APICall, company_id, request: PublishReportRequest):
@endpoint("reports.archive")
def archive(call: APICall, company_id, request: ArchiveReportRequest):
task = _assert_report(company_id=company_id, task_id=request.task)
task = _assert_report(
company_id=company_id, task_id=request.task, identity=call.identity
)
archived = task.update(
status_message=request.message,
status_reason="",
@@ -323,7 +374,9 @@ def archive(call: APICall, company_id, request: ArchiveReportRequest):
@endpoint("reports.unarchive")
def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
task = _assert_report(company_id=company_id, task_id=request.task)
task = _assert_report(
company_id=company_id, task_id=request.task, identity=call.identity
)
unarchived = task.update(
status_message=request.message,
status_reason="",
@@ -349,7 +402,10 @@ def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
@endpoint("reports.delete")
def delete(call: APICall, company_id, request: DeleteReportRequest):
task = _assert_report(
company_id=company_id, task_id=request.task, only_fields=("project",),
company_id=company_id,
task_id=request.task,
identity=call.identity,
only_fields=("project",),
)
if (
task.status != TaskStatus.created

View File

@@ -3,7 +3,11 @@ from datetime import datetime
from pyhocon.config_tree import NoneValue
from apiserver.apierrors import errors
from apiserver.apimodels.server import ReportStatsOptionRequest, ReportStatsOptionResponse
from apiserver.apimodels.server import (
ReportStatsOptionRequest,
ReportStatsOptionResponse,
GetConfigRequest,
)
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
from apiserver.config_repo import config
from apiserver.config.info import get_version, get_build_number, get_commit_number
@@ -22,8 +26,8 @@ def get_stats(call: APICall):
@endpoint("server.config")
def get_config(call: APICall):
path = call.data.get("path")
def get_config(call: APICall, _, request: GetConfigRequest):
path = request.path
if path:
c = dict(config.get(path))
else:

View File

@@ -65,6 +65,9 @@ from apiserver.apimodels.tasks import (
CompletedRequest,
CompletedResponse,
GetAllReq,
DequeueRequest,
DequeueManyRequest,
UpdateTagsRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
@@ -74,7 +77,6 @@ from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import (
TaskBLL,
ChangeStatusRequest,
update_project_time,
)
from apiserver.bll.task.artifacts import (
artifacts_prepare_for_save,
@@ -98,10 +100,17 @@ from apiserver.bll.task.task_operations import (
unarchive_task,
move_tasks_to_trash,
)
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
from apiserver.bll.util import SetFieldsResolver, run_batch_operation
from apiserver.bll.task.utils import (
update_task,
get_task_for_update,
deleted_prefix,
get_many_tasks_for_writing,
get_task_with_write_access,
)
from apiserver.bll.util import run_batch_operation, update_project_time
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
Task,
@@ -110,8 +119,13 @@ from apiserver.database.model.task.task import (
ModelItem,
TaskModelTypes,
)
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
from apiserver.database.utils import (
get_fields_attr,
parse_from_call,
get_options,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.service_repo.auth import Identity
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
@@ -136,33 +150,36 @@ org_bll = OrgBLL()
project_bll = ProjectBLL()
def _assert_writable_tasks(
company_id: str, identity: Identity, ids: Sequence[str], only=("id",)
) -> Sequence[Task]:
tasks = get_many_tasks_for_writing(
company_id=company_id,
identity=identity,
query=Q(id__in=ids),
only=only,
)
missing_ids = set(ids) - {t.id for t in tasks}
if missing_ids:
raise errors.bad_request.InvalidTaskId(ids=list(missing_ids))
return tasks
def set_task_status_from_call(
request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields
request: UpdateRequest,
company_id: str,
identity: Identity,
new_status=None,
**set_fields,
) -> dict:
fields_resolver = SetFieldsResolver(set_fields)
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
request.task,
company_id=company_id,
only=tuple(
{"status", "project", "started", "duration"} | fields_resolver.get_names()
),
requires_write_access=True,
identity=identity,
only=("id", "status", "project"),
)
if "duration" not in fields_resolver.get_names():
if new_status == Task.started:
fields_resolver.add_fields(min__duration=max(0, task.duration or 0))
elif new_status in (
TaskStatus.completed,
TaskStatus.failed,
TaskStatus.stopped,
):
fields_resolver.add_fields(
duration=int((task.started - datetime.utcnow()).total_seconds())
if task.started
else 0
)
status_reason = request.status_reason
status_message = request.status_message
force = request.force
@@ -172,27 +189,29 @@ def set_task_status_from_call(
status_reason=status_reason,
status_message=status_message,
force=force,
user_id=user_id,
).execute(**fields_resolver.get_fields(task))
user_id=identity.user,
).execute(**set_fields)
@endpoint("tasks.get_by_id", request_data_model=TaskRequest)
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
task = TaskBLL.get_task_with_access(
req_model.task, company_id=company_id, allow_public=True
)
def get_by_id(call: APICall, company_id, request: TaskRequest):
task = TaskBLL.assert_exists(
company_id,
task_ids=request.task,
allow_public=True,
)[0]
task_dict = task.to_proper_dict()
unprepare_from_saved(call, task_dict)
conform_task_data(call, task_dict)
call.result.data = {"task": task_dict}
def escape_execution_parameters(call: APICall) -> dict:
if not call.data:
return call.data
def escape_execution_parameters(call_data: dict) -> dict:
if not call_data:
return call_data
keys = list(call.data)
keys = list(call_data)
call_data = {
safe_key: call.data[key] for key, safe_key in zip(keys, escape_paths(keys))
safe_key: call_data[key] for key, safe_key in zip(keys, escape_paths(keys))
}
projection = Task.get_projection(call_data)
@@ -219,9 +238,7 @@ def _hidden_query(data: dict) -> Q:
@endpoint("tasks.get_all_ex")
def get_all_ex(call: APICall, company_id, request: GetAllReq):
conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call)
call_data = escape_execution_parameters(call.data)
process_include_subprojects(call_data)
ret_params = {}
tasks = Task.get_many_with_join(
@@ -231,29 +248,29 @@ def get_all_ex(call: APICall, company_id, request: GetAllReq):
allow_public=request.allow_public,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
conform_task_data(call, tasks)
call.result.data = {"tasks": tasks, **ret_params}
@endpoint("tasks.get_by_id_ex", required_fields=["id"])
def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call)
call_data = escape_execution_parameters(call.data)
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
company=company_id,
query_dict=call_data,
allow_public=True,
)
unprepare_from_saved(call, tasks)
conform_task_data(call, tasks)
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_all", required_fields=[])
@endpoint("tasks.get_all")
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call)
call_data = escape_execution_parameters(call.data)
process_include_subprojects(call_data)
ret_params = {}
tasks = Task.get_many(
@@ -264,7 +281,7 @@ def get_all(call: APICall, company_id, _):
allow_public=True,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
conform_task_data(call, tasks)
call.result.data = {"tasks": tasks, **ret_params}
@@ -293,7 +310,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
**stop_task(
task_id=req_model.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
user_name=call.identity.user_name,
status_reason=req_model.status_reason,
force=req_model.force,
@@ -311,7 +328,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
func=partial(
stop_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
user_name=call.identity.user_name,
status_reason=request.status_reason,
force=request.force,
@@ -334,7 +351,7 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.stopped,
completed=datetime.utcnow(),
)
@@ -347,13 +364,21 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
response_data_model=StartedResponse,
)
def started(call: APICall, company_id, req_model: UpdateRequest):
started_update = {}
if Task.objects(id=req_model.task, started=None).only("id"):
# this is the fix for older versions putting started to None on reset
started_update["started"] = datetime.utcnow()
else:
# don't override a previous, smaller "started" field value
started_update["min__started"] = datetime.utcnow()
res = StartedResponse(
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.in_progress,
min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value
**started_update,
)
)
res.started = res.updated
@@ -368,7 +393,7 @@ def failed(call: APICall, company_id, req_model: UpdateRequest):
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.failed,
)
)
@@ -382,8 +407,9 @@ def close(call: APICall, company_id, req_model: UpdateRequest):
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.closed)
identity=call.identity,
new_status=TaskStatus.closed,
)
)
@@ -395,18 +421,19 @@ create_fields = {
"error": None,
"comment": None,
"parent": Task,
"project": None,
"project": Project,
"input": None,
"models": None,
"container": None,
"container": dict,
"output_dest": None,
"execution": None,
"hyperparams": None,
"configuration": None,
"hyperparams": dict,
"configuration": dict,
"script": None,
"runtime": None,
"runtime": dict,
}
dict_fields_paths = [("execution", "model_labels"), "container"]
@@ -430,7 +457,7 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
return fields
def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]):
def conform_task_data(call: APICall, tasks_data: Union[Sequence[dict], dict]):
if isinstance(tasks_data, dict):
tasks_data = [tasks_data]
@@ -447,13 +474,17 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict])
for data in tasks_data:
params_unprepare_from_saved(
fields=data, copy_to_legacy=need_legacy_params,
fields=data,
copy_to_legacy=need_legacy_params,
)
artifacts_unprepare_from_saved(fields=data)
def prepare_create_fields(
call: APICall, valid_fields=None, output=None, previous_task: Task = None,
call: APICall,
valid_fields=None,
output=None,
previous_task: Task = None,
):
valid_fields = valid_fields if valid_fields is not None else create_fields
t_fields = task_fields
@@ -511,7 +542,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
org_bll.update_tags(
company,
Tags.Task,
project=project,
projects=[project],
tags=fields.get("tags"),
system_tags=fields.get("system_tags"),
)
@@ -580,11 +611,12 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
task_id = req_model.task
with translate_errors_context():
task = Task.get_for_writing(
id=task_id, company=company_id, _only=["id", "project"]
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=call.identity,
only=("id", "project"),
)
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
partial_update_dict, valid_fields = prepare_update_fields(call, call.data)
@@ -596,7 +628,8 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
id=task_id,
partial_update_dict=partial_update_dict,
injected_update=dict(
last_change=datetime.utcnow(), last_changed_by=call.identity.user,
last_change=datetime.utcnow(),
last_changed_by=call.identity.user,
),
)
if updated_count:
@@ -608,7 +641,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
company_id, project=task.project, fields=updated_fields
)
update_project_time(updated_fields.get("project"))
unprepare_from_saved(call, updated_fields)
conform_task_data(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@@ -620,11 +653,11 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
def set_requirements(call: APICall, company_id, req_model: SetRequirementsRequest):
requirements = req_model.requirements
with translate_errors_context():
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
req_model.task,
company_id=company_id,
identity=call.identity,
only=("status", "script"),
requires_write_access=True,
)
if not task.script:
raise errors.bad_request.MissingTaskFields(
@@ -650,8 +683,11 @@ def update_batch(call: APICall, company_id, _):
items = {i["task"]: i for i in items}
tasks = {
t.id: t
for t in Task.get_many_for_writing(
company=company_id, query=Q(id__in=list(items))
for t in _assert_writable_tasks(
identity=call.identity,
company_id=company_id,
ids=list(items),
only=("id", "project"),
)
}
@@ -670,7 +706,8 @@ def update_batch(call: APICall, company_id, _):
if not partial_update_dict:
continue
partial_update_dict.update(
last_change=now, last_changed_by=call.identity.user,
last_change=now,
last_changed_by=call.identity.user,
)
update_op = UpdateOne(
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
@@ -704,9 +741,11 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
force = req_model.force
with translate_errors_context():
task = Task.get_for_writing(id=task_id, company=company_id)
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=call.identity,
)
if not force and task.status != TaskStatus.created:
raise errors.bad_request.InvalidTaskStatus(
@@ -763,14 +802,15 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
company_id, project=task.project, fields=fixed_fields
)
update_project_time(fields.get("project"))
unprepare_from_saved(call, fields)
conform_task_data(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
@endpoint(
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
"tasks.get_hyper_params",
request_data_model=GetHyperParamsRequest,
)
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
@@ -785,7 +825,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest
call.result.data = {
"updated": HyperParams.edit_params(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
@@ -799,7 +839,7 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
call.result.data = {
"deleted": HyperParams.delete_params(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
hyperparams=request.hyperparams,
force=request.force,
@@ -808,7 +848,8 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
@endpoint(
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
"tasks.get_configurations",
request_data_model=GetConfigurationsRequest,
)
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
tasks_params = HyperParams.get_configurations(
@@ -823,7 +864,8 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ
@endpoint(
"tasks.get_configuration_names", request_data_model=GetConfigurationNamesRequest,
"tasks.get_configuration_names",
request_data_model=GetConfigurationNamesRequest,
)
def get_configuration_names(
call: APICall, company_id, request: GetConfigurationNamesRequest
@@ -844,7 +886,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ
call.result.data = {
"updated": HyperParams.edit_configuration(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
@@ -860,7 +902,7 @@ def delete_configuration(
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
configuration=request.configuration,
force=request.force,
@@ -877,7 +919,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
queued, res = enqueue_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
@@ -902,7 +944,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
func=partial(
enqueue_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
@@ -930,33 +972,35 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
@endpoint(
"tasks.dequeue",
request_data_model=UpdateRequest,
response_data_model=DequeueResponse,
)
def dequeue(call: APICall, company_id, request: UpdateRequest):
def dequeue(call: APICall, company_id, request: DequeueRequest):
dequeued, res = dequeue_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
remove_from_all_queues=request.remove_from_all_queues,
new_status=request.new_status,
)
call.result.data_model = DequeueResponse(dequeued=dequeued, **res)
@endpoint(
"tasks.dequeue_many",
request_data_model=TaskBatchRequest,
response_data_model=DequeueManyResponse,
)
def dequeue_many(call: APICall, company_id, request: TaskBatchRequest):
def dequeue_many(call: APICall, company_id, request: DequeueManyRequest):
results, failures = run_batch_operation(
func=partial(
dequeue_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
remove_from_all_queues=request.remove_from_all_queues,
new_status=request.new_status,
),
ids=request.ids,
)
@@ -976,7 +1020,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
dequeued, cleanup_res, updates = reset_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
@@ -1004,7 +1048,7 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
func=partial(
reset_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
@@ -1041,16 +1085,25 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
response_data_model=ArchiveResponse,
)
def archive(call: APICall, company_id, request: ArchiveRequest):
tasks = TaskBLL.assert_exists(
company_id,
task_ids=request.tasks,
only=("id", "execution", "status", "project", "system_tags", "enqueue_status"),
)
archived = 0
tasks = _assert_writable_tasks(
company_id,
call.identity,
ids=request.tasks,
only=(
"id",
"company",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
),
)
for task in tasks:
archived += archive_task(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
task=task,
status_message=request.status_message,
status_reason=request.status_reason,
@@ -1069,7 +1122,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial(
archive_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
),
@@ -1091,7 +1144,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial(
unarchive_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
),
@@ -1110,7 +1163,7 @@ def delete(call: APICall, company_id, request: DeleteRequest):
deleted, task, cleanup_res = delete_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
@@ -1132,7 +1185,7 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
func=partial(
delete_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
@@ -1170,7 +1223,7 @@ def publish(call: APICall, company_id, request: PublishRequest):
updates = publish_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
publish_model_func=ModelBLL.publish_model if request.publish_model else None,
status_reason=request.status_reason,
@@ -1189,7 +1242,7 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
func=partial(
publish_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
publish_model_func=ModelBLL.publish_model
if request.publish_model
@@ -1217,7 +1270,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
**set_task_status_from_call(
request,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.completed,
completed=datetime.utcnow(),
)
@@ -1227,7 +1280,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
publish_res = publish_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
publish_model_func=ModelBLL.publish_model,
status_reason=request.status_reason,
@@ -1242,9 +1295,12 @@ def completed(call: APICall, company_id, request: CompletedRequest):
@endpoint("tasks.ping", request_data_model=PingRequest)
def ping(_, company_id, request: PingRequest):
def ping(call: APICall, company_id, request: PingRequest):
TaskBLL.set_last_update(
task_ids=[request.task], company_id=company_id, last_update=datetime.utcnow()
task_ids=[request.task],
company_id=company_id,
user_id=call.identity.user,
last_update=datetime.utcnow(),
)
@@ -1259,7 +1315,7 @@ def add_or_update_artifacts(
call.result.data = {
"updated": Artifacts.add_or_update_artifacts(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
artifacts=request.artifacts,
force=True,
@@ -1276,7 +1332,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
call.result.data = {
"deleted": Artifacts.delete_artifacts(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
artifact_ids=request.artifacts,
force=True,
@@ -1287,14 +1343,22 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidTaskId,
enabled=True,
)
@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidTaskId,
enabled=False,
)
@@ -1305,6 +1369,7 @@ def move(call: APICall, company_id: str, request: MoveRequest):
"project or project_name is required"
)
_assert_writable_tasks(company_id, call.identity, request.ids)
updated_projects = set(
t.project for t in Task.objects(id__in=request.ids).only("project") if t.project
)
@@ -1324,9 +1389,25 @@ def move(call: APICall, company_id: str, request: MoveRequest):
return {"project_id": project_id}
@endpoint("tasks.update_tags")
def update_tags(call: APICall, company_id: str, request: UpdateTagsRequest):
_assert_writable_tasks(company_id, call.identity, request.ids)
return {
"updated": org_bll.edit_entity_tags(
company_id=company_id,
entity_cls=Task,
entity_ids=request.ids,
add_tags=request.add_tags,
remove_tags=request.remove_tags,
)
}
@endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
get_task_for_update(company_id=company_id, task_id=request.task, force=True)
def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelRequest):
get_task_for_update(
company_id=company_id, task_id=request.task, force=True, identity=call.identity
)
models_field = f"models__{request.type}"
model = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())
@@ -1336,6 +1417,7 @@ def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequ
updated = TaskBLL.update_statistics(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
last_iteration_max=request.iteration,
**({f"push__{models_field}": model} if not updated else {}),
)
@@ -1345,7 +1427,9 @@ def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequ
@endpoint("tasks.delete_models", min_version="2.13")
def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
task = get_task_for_update(
company_id=company_id, task_id=request.task, force=True, identity=call.identity
)
delete_names = {
type_: [m.name for m in request.models if m.type == type_]
@@ -1357,5 +1441,9 @@ def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
if names
}
updated = task.update(last_change=datetime.utcnow(), last_changed_by=call.identity.user, **commands,)
updated = task.update(
last_change=datetime.utcnow(),
last_changed_by=call.identity.user,
**commands,
)
return {"updated": updated}

View File

@@ -7,7 +7,11 @@ from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.users import CreateRequest, SetPreferencesRequest
from apiserver.apimodels.users import (
CreateRequest,
SetPreferencesRequest,
UserRequest,
)
from apiserver.bll.project import ProjectBLL
from apiserver.bll.user import UserBLL
from apiserver.config_repo import config
@@ -48,13 +52,13 @@ def get_user(call, company_id, user_id, only=None):
return res.to_proper_dict()
@endpoint("users.get_by_id", required_fields=["user"])
def get_by_id(call: APICall, company_id, _):
user_id = call.data["user"]
@endpoint("users.get_by_id")
def get_by_id(call: APICall, company_id, request: UserRequest):
user_id = request.user
call.result.data = {"user": get_user(call, company_id, user_id)}
@endpoint("users.get_all_ex", required_fields=[])
@endpoint("users.get_all_ex")
def get_all_ex(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
res = User.get_many_with_join(company=company_id, query_dict=call.data)
@@ -62,7 +66,7 @@ def get_all_ex(call: APICall, company_id, _):
call.result.data = {"users": res}
@endpoint("users.get_all_ex", min_version="2.8", required_fields=[])
@endpoint("users.get_all_ex", min_version="2.8")
def get_all_ex2_8(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
data = call.data
@@ -83,7 +87,7 @@ def get_all_ex2_8(call: APICall, company_id, _):
call.result.data = {"users": res}
@endpoint("users.get_all", required_fields=[])
@endpoint("users.get_all")
def get_all(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
res = User.get_many(
@@ -98,9 +102,7 @@ def get_current_user(call: APICall, company_id, _):
user_id = call.identity.user
projection = (
{"company.name"}
.union(User.get_fields())
.difference(User.get_exclude_fields())
{"company.name"}.union(User.get_fields()).difference(User.get_exclude_fields())
)
res = User.get_many_with_join(
query=Q(id=user_id),
@@ -114,9 +116,13 @@ def get_current_user(call: APICall, company_id, _):
user = res[0]
user["role"] = call.identity.role
resp = {
"user": user,
"getting_started": config.get("apiserver.getting_started_info", None),
resp = dict(
user=user, getting_started=config.get("apiserver.getting_started_info", None)
)
resp["settings"] = {
"max_download_items": int(
config.get("services.organization.download.max_download_items", 1000)
)
}
call.result.data = resp
@@ -136,9 +142,9 @@ def create(call: APICall):
UserBLL.create(call.data_model)
@endpoint("users.delete", required_fields=["user"])
def delete(call: APICall):
UserBLL.delete(call.data["user"])
@endpoint("users.delete")
def delete(_: APICall, __, request: UserRequest):
UserBLL.delete(request.user)
def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
@@ -157,9 +163,9 @@ def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
return User.safe_update(company_id, user_id, partial_update_dict)
@endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse)
def update(call, company_id, _):
user_id = call.data["user"]
@endpoint("users.update", response_data_model=UpdateResponse)
def update(call, company_id, request: UserRequest):
user_id = request.user
update_count, updated_fields = update_user(user_id, company_id, call.data)
call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields)

View File

@@ -101,7 +101,7 @@ def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
for values in filter(None, (tags, system_tags)):
unsupported = [
t for t in values if t.startswith(GetMixin.ListFieldBucketHelper.op_prefix)
t for t in values if t.startswith(GetMixin.NewListFieldBucketHelper.op_prefix)
]
if unsupported:
raise errors.bad_request.FieldsValueError(

View File

@@ -22,6 +22,7 @@ from apiserver.apimodels.workers import (
GetActivityReportRequest,
GetActivityReportResponse,
ActivityReportSeries,
GetCountRequest,
)
from apiserver.bll.workers import WorkerBLL
from apiserver.config_repo import config
@@ -50,6 +51,20 @@ def get_all(call: APICall, company_id: str, request: GetAllRequest):
)
@endpoint(
"workers.get_count", request_data_model=GetCountRequest,
)
def get_all(call: APICall, company_id: str, request: GetCountRequest):
call.result.data = {
"count": worker_bll.get_count(
company_id,
request.last_seen,
tags=request.tags,
system_tags=request.system_tags,
)
}
@endpoint("workers.register", min_version="2.4", request_data_model=RegisterRequest)
def register(call: APICall, company_id, request: RegisterRequest):
worker = request.worker

View File

@@ -9,9 +9,8 @@ import requests
import six
from boltons.iterutils import remap
from boltons.typeutils import issubclass
from requests.adapters import HTTPAdapter
from requests.adapters import HTTPAdapter, Retry
from requests.auth import HTTPBasicAuth
from requests.packages.urllib3.util.retry import Retry
from apiserver.apierrors.base import BaseError
from apiserver.config_repo import config

View File

@@ -127,7 +127,7 @@ class TestBatchOperations(TestService):
def _temp_task(self):
return self.create_temp(
service="tasks", type="testing", name=self.name, input=dict(view={}),
service="tasks", type="testing", name=self.name,
)
def _temp_task_model(self, task, **kwargs) -> str:

Some files were not shown because too many files have changed in this diff Show More