Compare commits

96 Commits

Author SHA1 Message Date
allegroai
7b6f24b24d Version bump to 1.8.0 2022-11-29 17:50:32 +02:00
allegroai
d03a931d84 Remove buggy debug code 2022-11-29 17:50:17 +02:00
allegroai
5cc7199661 Fix clearing ES scroll 2022-11-29 17:44:31 +02:00
allegroai
6537e9ef69 Add active_users and search_hidden options to get_entities_count endpoint 2022-11-29 17:44:19 +02:00
allegroai
930aaff791 Fix test 2022-11-29 17:43:43 +02:00
allegroai
1999fb2479 Add URL async delete prefixes setting 2022-11-29 17:43:08 +02:00
allegroai
9db14cc31d Update endpoint API versions 2022-11-29 17:42:19 +02:00
allegroai
e3cc689528 Enhance async_urls_delete feature with max_async_deleted_events_per_sec setting and fileserver timeout and prefixes 2022-11-29 17:41:49 +02:00
allegroai
9e0adc77dd Make sure id field is included in ordering 2022-11-29 17:40:09 +02:00
allegroai
58d9a64537 Fix endpoint name in schema 2022-11-29 17:39:35 +02:00
allegroai
d397d2ae20 Support optional async events deletion when deleting tasks 2022-11-29 17:38:41 +02:00
allegroai
2d711e1500 Collect model event URLs during task and project cleanup 2022-11-29 17:38:03 +02:00
allegroai
97992b0d9e Support returning multiple plots during history navigation 2022-11-29 17:37:30 +02:00
allegroai
bc23f1b0cf Add "queue watched" indication for tasks.enqueue and tasks.enqueue_many 2022-11-29 17:36:41 +02:00
allegroai
6b3eff1426 Support workers system tags 2022-11-29 17:35:25 +02:00
allegroai
caaf801cd0 Support plots navigation by iteration 2022-11-29 17:34:57 +02:00
allegroai
c23e8a90d0 Support model events 2022-11-29 17:34:06 +02:00
allegroai
fa5b28ca0e Support HTTP URLs when deleting fileserver files 2022-11-29 17:33:18 +02:00
allegroai
bfb55a9463 Support deleting external artifacts when deleting projects 2022-11-29 17:32:41 +02:00
allegroai
37e485e1f2 Upgrade API version to 2.22 2022-11-29 17:29:57 +02:00
allegroai
3451ff441f Fix projects.get_all_ex with active_users filtering did not work if the project id was passed as a string and not list 2022-11-29 17:27:54 +02:00
allegroai
53c9b5525e Add preliminary support for datasets under projects 2022-11-29 17:27:02 +02:00
PSL
e5230edac3 Add mongo username and password authentication (#162)
Co-authored-by: pangshaoliang <pangshaoliang@megvii.com>
2022-10-08 15:56:02 +03:00
allegroai
a54dd8030c Add async-delete to docker-compose 2022-09-29 19:44:05 +03:00
allegroai
482a5c34bc Version bump to 1.7.0 2022-09-29 19:40:20 +03:00
allegroai
ee2a72c70f Add company/uri index 2022-09-29 19:40:05 +03:00
allegroai
a0d8aaf3b9 Fix urls are not unquoted in batch_delete 2022-09-29 19:39:02 +03:00
allegroai
de1f823213 Removed stub timing context 2022-09-29 19:37:15 +03:00
allegroai
0c9e2f92ee Add server-side support for deleting files from fileserver on task delete 2022-09-29 19:34:24 +03:00
allegroai
6c49e96ff0 Update API version to 2.21 2022-09-29 19:31:42 +03:00
allegroai
81e3fc6577 Improve utilities 2022-09-29 19:30:57 +03:00
allegroai
e6dc4b7557 Set cloned task parent to original task if original task has no parent 2022-09-29 19:30:13 +03:00
allegroai
238a47a197 Add created field to backend.user 2022-09-29 19:29:36 +03:00
allegroai
04e7076628 Add support for specifying a specific task ID in queues.get_next_task 2022-09-29 19:27:42 +03:00
allegroai
0531612bf4 Fix deleting a queue should dequeue all enqueued tasks 2022-09-29 19:27:09 +03:00
allegroai
3ae410a1e9 Remove the ThreadsManager.terminating flag 2022-09-29 19:23:26 +03:00
allegroai
98ed3075dd Added exclude support when converting mongo objects to dictionary 2022-09-29 19:21:28 +03:00
allegroai
b871bf4224 Fix projects.get_all_ex 2022-09-29 19:20:50 +03:00
allegroai
8d4c02fc3c Add support for hidden internal queues 2022-09-29 19:20:24 +03:00
allegroai
b986980c75 Use correct attrs version, remove related dependency 2022-09-29 19:18:38 +03:00
allegroai
a4fa567be2 Fix task stats update 2022-09-29 19:18:22 +03:00
allegroai
ddb91f226a Add Task Unique Metrics to task object 2022-09-29 19:16:56 +03:00
allegroai
7772f47773 Support datetime ranges in field queries 2022-09-29 19:15:50 +03:00
allegroai
9c118d14e0 Add missing last_update field to models.get_all APIs 2022-09-29 19:14:13 +03:00
allegroai
efd56e085e Fix threaded jobs management (invoke only from AppSequence) 2022-09-29 19:13:22 +03:00
allegroai
4dff163af4 Improve examples pre-population code 2022-09-29 19:11:21 +03:00
allegroai
242a78a0fe Add request logging using the CLEARML_SERVER_DEBUG_REQUESTS env var 2022-09-29 19:10:55 +03:00
allegroai
78989fea91 Add better mongodb connection string verbosity 2022-09-29 19:10:13 +03:00
allegroai
5de7c12062 Version bump to v1.6 2022-07-08 18:05:43 +03:00
allegroai
3f79c19079 Add v1.6.0 mongodb migration 2022-07-08 18:05:33 +03:00
allegroai
fe29743c54 Add support for new IDs generation when importing projects 2022-07-08 18:04:40 +03:00
allegroai
d760cf5835 Remove use of dpath in query projection 2022-07-08 18:04:02 +03:00
allegroai
3695f25a5f Fix internal error returned to clients 2022-07-08 18:03:38 +03:00
allegroai
c6f1beafdd Update API version to 2.20 2022-07-08 18:02:44 +03:00
allegroai
68a54c34f3 Add user creation time to users.get_current_user 2022-07-08 17:59:45 +03:00
allegroai
ab495ae586 Fix archived projects handling 2022-07-08 17:55:02 +03:00
allegroai
b058770af1 Fix handling of empty hyperparam/configuration keys 2022-07-08 17:54:19 +03:00
allegroai
f7e833bf6f Fix loading expired worker entries from Redis 2022-07-08 17:53:12 +03:00
allegroai
36b9ab0453 Fix handling of empty keys in query 2022-07-08 17:51:49 +03:00
allegroai
ec0436d0da Fix task cleanup 2022-07-08 17:50:49 +03:00
allegroai
0f6c4e75b7 Fix debug images URL handling and task routing 2022-07-08 17:50:26 +03:00
allegroai
a41ae112a1 Fix backward compatibility when importing old projects 2022-07-08 17:49:36 +03:00
allegroai
c28f478ea8 Fix worker Id is used instead of worker key when processing report 2022-07-08 17:48:17 +03:00
allegroai
c18eb99d06 Return getting_started_info in users.get_current_user 2022-07-08 17:45:33 +03:00
allegroai
3a60f00d93 Add support for Dataset projects 2022-07-08 17:45:03 +03:00
allegroai
ee87778548 Better support for PyJWT 2.4 2022-07-08 17:44:17 +03:00
allegroai
52c0c4d438 Add model task cleanup on tasks.reset 2022-07-08 17:43:54 +03:00
allegroai
d117a4f022 Support max_task_entries option in queues.get_by_id/get_all
Add queues.peek_task and queues.get_num_entries
2022-07-08 17:42:20 +03:00
allegroai
6683d2d7a9 Fix task cleanup 2022-07-08 17:40:55 +03:00
allegroai
05357fe25e Support publish option in tasks.completed 2022-07-08 17:40:43 +03:00
allegroai
adc1825843 Add support for model statistics 2022-07-08 17:39:41 +03:00
allegroai
0c15169668 Improve tests 2022-07-08 17:38:31 +03:00
allegroai
123dc1dcfb Improve query error handling 2022-07-08 17:38:22 +03:00
allegroai
b2feafac09 Support workers filtering with tags 2022-07-08 17:37:33 +03:00
allegroai
b41ab8c550 Better support for queue metrics and queue metrics refresh on sample 2022-07-08 17:36:46 +03:00
allegroai
62d5779bd5 Count own tasks/models for projects 2022-07-08 17:35:01 +03:00
allegroai
f8b9d9802e Add support for organization.get_entities_count 2022-07-08 17:32:56 +03:00
allegroai
dd8a1503b0 Add support for navigate_current_metric in events.get_debug_image_sample 2022-07-08 17:31:44 +03:00
allegroai
cff98ae900 Add support for events.get_task_single_value_metrics, events.plots, events.get_plot_sample and events.next_plot_sample 2022-07-08 17:29:39 +03:00
allegroai
9b108740da Bump PyJWT version due to "Key confusion through non-blocklisted public key formats" vulnerability 2022-05-25 16:50:19 +03:00
allegroai
08a7bc7c9f Fix not all the event logs are returned from sharded ES 2022-05-20 15:11:05 +03:00
allegroai
fb256d7e5b Version bump to v1.5 2022-05-18 15:29:45 +03:00
allegroai
710443b078 Fix move task to trash is not thread-safe 2022-05-18 10:31:20 +03:00
allegroai
e0cde2f7c9 Add support for deleting pipeline projects 2022-05-18 10:30:21 +03:00
allegroai
60b9c8de14 Allow arbitrary task fields in project statistics filter 2022-05-18 10:29:36 +03:00
allegroai
ecffe26be4 Fix auth.edit_credentials 2022-05-18 10:28:58 +03:00
allegroai
2570bd9e26 Fix ES issue with capital letters in index name 2022-05-18 10:18:23 +03:00
allegroai
174f84514a Fix no destination when merging projects 2022-05-18 10:17:34 +03:00
allegroai
65cb8d7b43 Refactor method name 2022-05-18 10:16:41 +03:00
allegroai
5f8ef808a3 Fix ES issue with capital letters in index name 2022-05-18 10:16:19 +03:00
allegroai
4941ac70e0 Add events.clear_task_log 2022-05-17 16:09:23 +03:00
allegroai
67cd461145 Add auth.edit_credentials 2022-05-17 16:08:12 +03:00
allegroai
92b5fc6f9a Fix handling hidden sub-projects 2022-05-17 16:06:34 +03:00
allegroai
b90165b4e4 Support queue_name in tasks enqueue 2022-05-17 16:04:34 +03:00
allegroai
6c2dcb5c8a Improve error message 2022-05-17 15:56:18 +03:00
allegroai
3efed32934 Add X-Jwt-Payload to redacted headers 2022-05-17 15:55:41 +03:00
111 changed files with 6492 additions and 2567 deletions

View File

@@ -26,6 +26,9 @@
23: ["invalid_domain_name", "malformed domain name"]
24: ["not_public_object", "object is not public"]
# Auth / Login
75: ["invalid_access_key", "access key not found for user"]
# Tasks
100: ["task_error", "general task error"]
101: ["invalid_task_id", "invalid task id"]
@@ -86,7 +89,7 @@
# Database
800: ["data_validation_error", "data validation error"]
801: ["expected_unique_data", "value combination already exists"]
801: ["expected_unique_data", "value combination already exists (unique field already contains this value)"]
# Workers
1001: ["invalid_worker_id", "invalid worker id"]

View File

@@ -61,12 +61,6 @@ class ListField(fields.ListField):
item.validate()
# since there is no distinction between None and empty DictField
# this value can be used as sentinel in order to distinguish
# between not set and empty DictField
DictFieldNotSet = {}
class DictField(fields.BaseField):
types = (dict,)

View File

@@ -96,6 +96,11 @@ class GetCredentialsResponse(Base):
credentials = ListField(CredentialsResponse)
class EditCredentialsRequest(Base):
access_key = StringField(required=True)
label = StringField()
class RevokeCredentialsRequest(Base):
access_key = StringField(required=True)

View File

@@ -26,6 +26,7 @@ class MetricVariants(Base):
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
task: str = StringField(required=True)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
model_events: bool = BoolField(default=False)
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
@@ -40,6 +41,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
)
],
)
model_events: bool = BoolField(default=False)
class TaskMetric(Base):
@@ -48,7 +50,7 @@ class TaskMetric(Base):
variants: Sequence[str] = ListField(items_types=str)
class DebugImagesRequest(Base):
class MetricEventsRequest(Base):
metrics: Sequence[TaskMetric] = ListField(
items_types=TaskMetric, validators=[Length(minimum_value=1)]
)
@@ -56,24 +58,36 @@ class DebugImagesRequest(Base):
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
model_events: bool = BoolField()
class TaskMetricVariant(Base):
class GetVariantSampleRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
variant: str = StringField(required=True)
class GetDebugImageSampleRequest(TaskMetricVariant):
iteration: Optional[int] = IntField()
refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()
navigate_current_metric: bool = BoolField(default=True)
model_events: bool = BoolField(default=False)
class NextDebugImageSampleRequest(Base):
class GetMetricSamplesRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
iteration: Optional[int] = IntField()
refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()
navigate_current_metric: bool = BoolField(default=True)
model_events: bool = BoolField(default=False)
class NextHistorySampleRequest(Base):
task: str = StringField(required=True)
scroll_id: Optional[str] = StringField()
navigate_earlier: bool = BoolField(default=True)
next_iteration: bool = BoolField(default=False)
model_events: bool = BoolField(default=False)
class LogOrderEnum(StringEnum):
@@ -92,6 +106,7 @@ class TaskEventsRequest(TaskEventsRequestBase):
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
scroll_id: str = StringField()
count_total: bool = BoolField(default=True)
model_events: bool = BoolField(default=False)
class LogEventsRequest(TaskEventsRequestBase):
@@ -107,6 +122,7 @@ class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
count_total: bool = BoolField(default=False)
scroll_id: str = StringField()
model_events: bool = BoolField(default=False)
class IterationEvents(Base):
@@ -119,15 +135,23 @@ class MetricEvents(Base):
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
class DebugImageResponse(Base):
class MetricEventsResponse(Base):
metrics: Sequence[MetricEvents] = ListField(items_types=MetricEvents)
scroll_id: str = StringField()
class TaskMetricsRequest(Base):
class MultiTasksRequestBase(Base):
tasks: Sequence[str] = ListField(
items_types=str, validators=[Length(minimum_value=1)]
)
model_events: bool = BoolField(default=False)
class SingleValueMetricsRequest(MultiTasksRequestBase):
pass
class TaskMetricsRequest(MultiTasksRequestBase):
event_type: EventType = ActualEnumField(EventType, required=True)
@@ -137,7 +161,14 @@ class TaskPlotsRequest(Base):
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
model_events: bool = BoolField(default=False)
class ClearScrollRequest(Base):
scroll_id: str = StringField()
class ClearTaskLogRequest(Base):
task: str = StringField(required=True)
threshold_sec = IntField()
allow_locked = BoolField(default=False)

View File

@@ -75,3 +75,7 @@ class DeleteMetadataRequest(DeleteMetadata):
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
model = fields.StringField(required=True)
class ModelsGetRequest(models.Base):
include_stats = fields.BoolField(default=False)

View File

@@ -1,5 +1,7 @@
from jsonmodels import fields, models
from apiserver.apimodels import DictField
class Filter(models.Base):
tags = fields.ListField([str])
@@ -9,3 +11,13 @@ class Filter(models.Base):
class TagsRequest(models.Base):
include_system = fields.BoolField(default=False)
filter = fields.EmbeddedField(Filter)
class EntitiesCountRequest(models.Base):
projects = DictField()
tasks = DictField()
models = DictField()
pipelines = DictField()
datasets = DictField()
active_users = fields.ListField(str)
search_hidden = fields.BoolField(default=False)

View File

@@ -57,6 +57,7 @@ class ProjectModelMetadataValuesRequest(MultiProjectRequest):
class ProjectsGetRequest(models.Base):
include_dataset_stats = fields.BoolField(default=False)
include_stats = fields.BoolField(default=False)
include_stats_filter = DictField()
stats_with_children = fields.BoolField(default=True)

View File

@@ -26,9 +26,19 @@ class QueueRequest(Base):
queue = StringField(required=True)
class GetByIdRequest(QueueRequest):
max_task_entries = IntField()
class GetAllRequest(Base):
max_task_entries = IntField()
search_hidden = BoolField(default=False)
class GetNextTaskRequest(QueueRequest):
queue = StringField(required=True)
get_task_info = BoolField(default=False)
task = StringField()
class DeleteRequest(QueueRequest):
@@ -59,6 +69,7 @@ class GetMetricsRequest(Base):
from_date = FloatField(required=True, validators=validators.Min(0))
to_date = FloatField(required=True, validators=validators.Min(0))
interval = IntField(required=True, validators=validators.Min(1))
refresh = BoolField(default=False)
class QueueMetrics(Base):

View File

@@ -42,6 +42,7 @@ class StartedResponse(UpdateResponse):
class EnqueueResponse(UpdateResponse):
queued = IntField()
queue_watched = BoolField()
class EnqueueBatchItem(UpdateBatchItem):
@@ -50,6 +51,7 @@ class EnqueueBatchItem(UpdateBatchItem):
class EnqueueManyResponse(BatchResponse):
succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem)
queue_watched = BoolField()
class DequeueResponse(UpdateResponse):
@@ -96,18 +98,29 @@ class UpdateRequest(TaskUpdateRequest):
class EnqueueRequest(UpdateRequest):
queue = StringField()
queue_name = StringField()
verify_watched_queue = BoolField(default=False)
class DeleteRequest(UpdateRequest):
move_to_trash = BoolField(default=True)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
delete_external_artifacts = BoolField(default=True)
class SetRequirementsRequest(TaskRequest):
requirements = DictField(required=True)
class CompletedRequest(UpdateRequest):
publish = BoolField(default=False)
class CompletedResponse(UpdateResponse):
published = IntField(default=0)
class PublishRequest(UpdateRequest):
publish_model = BoolField(default=True)
@@ -171,6 +184,7 @@ class ResetRequest(UpdateRequest):
clear_all = BoolField(default=False)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
delete_external_artifacts = BoolField(default=True)
class MultiTaskRequest(models.Base):
@@ -262,7 +276,9 @@ class StopManyRequest(TaskBatchRequest):
class EnqueueManyRequest(TaskBatchRequest):
queue = StringField()
queue_name = StringField()
validate_tasks = BoolField(default=False)
verify_watched_queue = BoolField(default=False)
class DeleteManyRequest(TaskBatchRequest):
@@ -270,6 +286,7 @@ class DeleteManyRequest(TaskBatchRequest):
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
force = BoolField(default=False)
delete_external_artifacts = BoolField(default=True)
class ResetManyRequest(TaskBatchRequest):
@@ -277,6 +294,7 @@ class ResetManyRequest(TaskBatchRequest):
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
force = BoolField(default=False)
delete_external_artifacts = BoolField(default=True)
class PublishManyRequest(TaskBatchRequest):

View File

@@ -20,6 +20,7 @@ DEFAULT_TIMEOUT = 10 * 60
class WorkerRequest(Base):
worker = StringField(required=True)
tags = ListField(str)
system_tags = ListField(str)
class RegisterRequest(WorkerRequest):
@@ -76,6 +77,7 @@ class WorkerEntry(Base, JsonSerializableMixin):
last_activity_time = DateTimeField(required=True)
last_report_time = DateTimeField()
tags = ListField(str)
system_tags = ListField(str)
class CurrentTaskEntry(IdNameEntry):
@@ -96,6 +98,8 @@ class WorkerResponseEntry(WorkerEntry):
class GetAllRequest(Base):
last_seen = IntField(default=3600)
tags = ListField(str)
system_tags = ListField(str)
class GetAllResponse(Base):

View File

@@ -64,7 +64,7 @@ class AuthBLL:
feature_set="basic",
)
return GetTokenResponse(token=token.decode("ascii"))
return GetTokenResponse(token=token)
@staticmethod
def create_user(request: CreateUserRequest, call: APICall = None) -> str:

View File

@@ -1,375 +0,0 @@
import operator
from typing import Sequence, Tuple, Optional
import attr
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField, BoolField
from jsonmodels.models import Base
from redis import StrictRedis
from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
EventSettings,
EventType,
check_empty_data,
search_company_events,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
class VariantState(Base):
name: str = StringField(required=True)
min_iteration: int = IntField()
max_iteration: int = IntField()
class DebugSampleHistoryState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
iteration: int = IntField()
variant: str = StringField()
task: str = StringField()
metric: str = StringField()
reached_first: bool = BoolField()
reached_last: bool = BoolField()
variant_states: Sequence[VariantState] = ListField([VariantState])
warning: str = StringField()
@attr.s(auto_attribs=True)
class DebugSampleHistoryResult(object):
scroll_id: str = None
event: dict = None
min_iteration: int = None
max_iteration: int = None
class DebugSampleHistory:
EVENT_TYPE = EventType.metrics_image
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugSampleHistoryState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def get_next_debug_image(
self, company_id: str, task: str, state_id: str, navigate_earlier: bool
) -> DebugSampleHistoryResult:
"""
Get the debug image for next/prev variant on the current iteration
If does not exist then try getting image for the first/last variant from next/prev iteration
"""
res = DebugSampleHistoryResult(scroll_id=state_id)
state = self.cache_manager.get_state(state_id)
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
return res
image = self._get_next_for_current_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
) or self._get_next_for_another_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
)
if not image:
return res
self._fill_res_and_update_state(image=image, res=res, state=state)
self.cache_manager.set_state(state=state)
return res
def _fill_res_and_update_state(
self, image: dict, res: DebugSampleHistoryResult, state: DebugSampleHistoryState
):
state.variant = image["variant"]
state.iteration = image["iter"]
res.event = image
var_state = first(s for s in state.variant_states if s.name == state.variant)
if var_state:
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
def _get_next_for_current_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
) -> Optional[dict]:
"""
Get the image for next (if navigated earlier is False) or previous variant sorted by name for the same iteration
Only variants for which the iteration falls into their valid range are considered
Return None if no such variant or image is found
"""
cmp = operator.lt if navigate_earlier else operator.gt
variants = [
var_state
for var_state in state.variant_states
if cmp(var_state.name, state.variant)
and var_state.min_iteration <= state.iteration
]
if not variants:
return
must_conditions = [
{"term": {"task": state.task}},
{"term": {"metric": state.metric}},
{"terms": {"variant": [v.name for v in variants]}},
{"term": {"iter": state.iteration}},
{"exists": {"field": "url"}},
]
es_req = {
"size": 1,
"sort": {"variant": "desc" if navigate_earlier else "asc"},
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_next_for_current_iteration"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def _get_next_for_another_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
) -> Optional[dict]:
"""
Get the image for the first variant for the next iteration (if navigate_earlier is set to False)
or from the last variant for the previous iteration (otherwise)
The variants for which the image falls in invalid range are discarded
If no suitable image is found then None is returned
"""
must_conditions = [
{"term": {"task": state.task}},
{"term": {"metric": state.metric}},
{"exists": {"field": "url"}},
]
if navigate_earlier:
range_operator = "lt"
order = "desc"
variants = [
var_state
for var_state in state.variant_states
if var_state.min_iteration < state.iteration
]
else:
range_operator = "gt"
order = "asc"
variants = state.variant_states
if not variants:
return
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"gte": v.min_iteration}}},
]
}
}
for v in variants
]
must_conditions.append({"bool": {"should": variants_conditions}})
must_conditions.append({"range": {"iter": {range_operator: state.iteration}}},)
es_req = {
"size": 1,
"sort": [{"iter": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_next_for_another_iteration"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def get_debug_image_for_variant(
self,
company_id: str,
task: str,
metric: str,
variant: str,
iteration: Optional[int] = None,
refresh: bool = False,
state_id: str = None,
) -> DebugSampleHistoryResult:
"""
Get the debug image for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
res = DebugSampleHistoryResult()
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
return res
def init_state(state_: DebugSampleHistoryState):
state_.task = task
state_.metric = metric
self._reset_variant_states(company_id=company_id, state=state_)
def validate_state(state_: DebugSampleHistoryState):
if state_.task != task or state_.metric != metric:
raise errors.bad_request.InvalidScrollId(
"Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reset_variant_states(company_id=company_id, state=state_)
state: DebugSampleHistoryState
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res.scroll_id = state.id
var_state = first(s for s in state.variant_states if s.name == variant)
if not var_state:
return res
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
must_conditions = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"exists": {"field": "url"}},
]
if iteration is not None:
must_conditions.append(
{
"range": {
"iter": {"lte": iteration, "gte": var_state.min_iteration}
}
}
)
else:
must_conditions.append(
{"range": {"iter": {"gte": var_state.min_iteration}}}
)
es_req = {
"size": 1,
"sort": {"iter": "desc"},
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_debug_image_for_variant"
):
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return res
self._fill_res_and_update_state(
image=hits[0]["_source"], res=res, state=state
)
return res
def _reset_variant_states(self, company_id: str, state: DebugSampleHistoryState):
variant_iterations = self._get_variant_iterations(
company_id=company_id, task=state.task, metric=state.metric
)
state.variant_states = [
VariantState(name=var_name, min_iteration=min_iter, max_iteration=max_iter)
for var_name, min_iter, max_iter in variant_iterations
]
def _get_variant_iterations(
self,
company_id: str,
task: str,
metric: str,
variants: Optional[Sequence[str]] = None,
) -> Sequence[Tuple[str, int, int]]:
"""
Return valid min and max iterations that the task reported images
The min iteration is the lowest iteration that contains non-recycled image url
"""
must = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"exists": {"field": "url"}},
]
if variants:
must.append({"terms": {"variant": variants}})
es_req: dict = {
"size": 0,
"query": {"bool": {"must": must}},
"aggs": {
"variants": {
# all variants that sent debug images
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": {
"last_iter": {"max": {"field": "iter"}},
"urls": {
# group by urls and choose the minimal iteration
# from all the maximal iterations per url
"terms": {
"field": "url",
"order": {"max_iter": "asc"},
"size": 1,
},
"aggs": {
# find max iteration for each url
"max_iter": {"max": {"field": "iter"}}
},
},
},
}
},
}
with translate_errors_context(), TimingContext(
"es", "get_debug_image_iterations"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
)
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
variant = variant_bucket["key"]
urls = nested_get(variant_bucket, ("urls", "buckets"))
min_iter = int(urls[0]["max_iter"]["value"])
max_iter = int(variant_bucket["last_iter"]["value"])
return variant, min_iter, max_iter
return [
get_variant_data(variant_bucket)
for variant_bucket in nested_get(
es_res, ("aggregations", "variants", "buckets")
)
]

View File

@@ -13,32 +13,35 @@ from elasticsearch.helpers import BulkIndexError
from mongoengine import Q
from nested_dict import nested_dict
from apiserver.bll.event.debug_sample_history import DebugSampleHistory
from apiserver.bll.event.event_common import (
EventType,
EventSettings,
get_index_name,
check_empty_data,
search_company_events,
delete_company_events,
MetricVariants,
get_metric_variants_condition,
uncompress_plot,
get_max_metric_and_variant_counts,
)
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIterator
from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.database.model.model import Model
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.bll.event.debug_images_iterator import DebugImagesIterator
from apiserver.bll.event.event_metrics import EventMetrics
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.redis_manager import redman
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
from apiserver.utilities.dicts import flatten_nested_items
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.json import loads
# noinspection PyTypeChecker
@@ -66,16 +69,31 @@ class EventBLL(object):
r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]",
flags=re.IGNORECASE,
)
_task_event_query = {
"bool": {
"should": [
{"term": {"model_event": False}},
{"bool": {"must_not": [{"exists": {"field": "model_event"}}]}},
]
}
}
_model_event_query = {"term": {"model_event": True}}
def __init__(self, events_es=None, redis=None):
self.es = events_es or es_factory.connect("events")
self.redis = redis or redman.connection("apiserver")
self._metrics = EventMetrics(self.es)
self._skip_iteration_for_metric = set(
config.get("services.events.ignore_iteration.metrics", [])
)
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
self.debug_images_iterator = MetricDebugImagesIterator(
es=self.es, redis=self.redis
)
self.debug_image_sample_history = HistoryDebugImageIterator(
es=self.es, redis=self.redis
)
self.plots_iterator = MetricPlotsIterator(es=self.es, redis=self.redis)
self.plot_sample_history = HistoryPlotsIterator(es=self.es, redis=self.redis)
self.events_iterator = EventsIterator(es=self.es)
@property
@@ -88,18 +106,42 @@ class EventBLL(object):
if not task_ids:
return set()
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
with translate_errors_context():
query = Q(id__in=task_ids, company=company_id)
if not allow_locked_tasks:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
res = Task.objects(query).only("id")
return {r.id for r in res}
@staticmethod
def _get_valid_models(company_id, model_ids: Set, allow_locked_models=False) -> Set:
"""Verify that task exists and can be updated"""
if not model_ids:
return set()
with translate_errors_context():
query = Q(id__in=model_ids, company=company_id)
if not allow_locked_models:
query &= Q(ready__ne=True)
res = Model.objects(query).only("id")
return {r.id for r in res}
def add_events(
self, company_id, events, worker, allow_locked_tasks=False
self, company_id, events, worker, allow_locked=False
) -> Tuple[int, int, dict]:
model_events = events[0].get("model_event", False)
for event in events:
if event.get("model_event", model_events) != model_events:
raise errors.bad_request.ValidationError(
"Inconsistent model_event setting in the passed events"
)
if event.pop("allow_locked", allow_locked) != allow_locked:
raise errors.bad_request.ValidationError(
"Inconsistent allow_locked setting in the passed events"
)
actions: List[dict] = []
task_ids = set()
task_or_model_ids = set()
task_iteration = defaultdict(lambda: 0)
task_last_scalar_events = nested_dict(
3, dict
@@ -109,13 +151,28 @@ class EventBLL(object):
) # task_id -> metric_hash -> event_type -> MetricEvent
errors_per_type = defaultdict(int)
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
valid_tasks = self._get_valid_tasks(
company_id,
task_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_tasks=allow_locked_tasks,
)
if model_events:
for event in events:
model = event.pop("model", None)
if model is not None:
event["task"] = model
valid_entities = self._get_valid_models(
company_id,
model_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_models=allow_locked,
)
entity_name = "model"
else:
valid_entities = self._get_valid_tasks(
company_id,
task_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_tasks=allow_locked,
)
entity_name = "task"
for event in events:
# remove spaces from event type
@@ -129,13 +186,17 @@ class EventBLL(object):
errors_per_type[f"Invalid event type {event_type}"] += 1
continue
task_id = event.get("task")
if task_id is None:
if model_events and event_type == EventType.task_log.value:
errors_per_type[f"Task log events are not supported for models"] += 1
continue
task_or_model_id = event.get("task")
if task_or_model_id is None:
errors_per_type["Event must have a 'task' field"] += 1
continue
if task_id not in valid_tasks:
errors_per_type["Invalid task id"] += 1
if task_or_model_id not in valid_entities:
errors_per_type[f"Invalid {entity_name} id {task_or_model_id}"] += 1
continue
event["type"] = event_type
@@ -157,10 +218,13 @@ class EventBLL(object):
# force iter to be a long int
iter = event.get("iter")
if iter is not None:
iter = int(iter)
if iter > MAX_LONG or iter < MIN_LONG:
errors_per_type[invalid_iteration_error] += 1
continue
if model_events:
iter = 0
else:
iter = int(iter)
if iter > MAX_LONG or iter < MIN_LONG:
errors_per_type[invalid_iteration_error] += 1
continue
event["iter"] = iter
# used to have "values" to indicate array. no need anymore
@@ -170,6 +234,7 @@ class EventBLL(object):
event["metric"] = event.get("metric") or ""
event["variant"] = event.get("variant") or ""
event["model_event"] = model_events
index_name = get_index_name(company_id, event_type)
es_action = {
@@ -184,21 +249,26 @@ class EventBLL(object):
else:
es_action["_id"] = dbutils.id()
task_ids.add(task_id)
task_or_model_ids.add(task_or_model_id)
if (
iter is not None
and not model_events
and event.get("metric") not in self._skip_iteration_for_metric
):
task_iteration[task_id] = max(iter, task_iteration[task_id])
self._update_last_metric_events_for_task(
last_events=task_last_events[task_id], event=event,
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_id], event=event
task_iteration[task_or_model_id] = max(
iter, task_iteration[task_or_model_id]
)
if not model_events:
self._update_last_metric_events_for_task(
last_events=task_last_events[task_or_model_id], event=event,
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_or_model_id],
event=event,
)
actions.append(es_action)
plot_actions = [
@@ -219,39 +289,41 @@ class EventBLL(object):
with translate_errors_context():
if actions:
chunk_size = 500
with TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
elasticsearch.helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += 1
else:
errors_per_type["Error when indexing events batch"] += 1
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
elasticsearch.helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += 1
else:
errors_per_type["Error when indexing events batch"] += 1
if not model_events:
remaining_tasks = set()
now = datetime.utcnow()
for task_id in task_ids:
for task_or_model_id in task_or_model_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those who's events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
task_id=task_or_model_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
iter_max=task_iteration.get(task_or_model_id),
last_scalar_events=task_last_scalar_events.get(
task_or_model_id
),
last_events=task_last_events.get(task_or_model_id),
)
if not updated:
remaining_tasks.add(task_id)
remaining_tasks.add(task_or_model_id)
continue
if remaining_tasks:
@@ -307,11 +379,7 @@ class EventBLL(object):
@parallel_chunked_decorator(chunk_size=10)
def uncompress_plots(self, plot_events: Sequence[dict]):
for event in plot_events:
plot_data = event.pop(PlotFields.plot_data, None)
if plot_data and event.get(PlotFields.plot_str) is None:
event[PlotFields.plot_str] = zlib.decompress(
base64.b64decode(plot_data)
).decode()
uncompress_plot(event)
@staticmethod
def _is_valid_json(text: str) -> bool:
@@ -391,33 +459,17 @@ class EventBLL(object):
as the latest metric/variant scalar values reported (according to the report timestamp) and the task's last
update time.
"""
fields = {}
if iter_max is not None:
fields["last_iteration_max"] = iter_max
if last_scalar_events:
fields["last_scalar_values"] = list(
flatten_nested_items(
last_scalar_events,
nesting=2,
include_leaves=[
"value",
"min_value",
"max_value",
"metric",
"variant",
],
)
)
if last_events:
fields["last_events"] = last_events
if not fields:
if iter_max is None and not last_events and not last_scalar_events:
return False
return TaskBLL.update_statistics(task_id, company_id, last_update=now, **fields)
return TaskBLL.update_statistics(
task_id,
company_id,
last_update=now,
last_iteration_max=iter_max,
last_scalar_events=last_scalar_events,
last_events=last_events,
)
def _get_event_id(self, event):
id_values = (str(event[field]) for field in self.id_fields if field in event)
@@ -436,7 +488,7 @@ class EventBLL(object):
return [], scroll_id, 0
if scroll_id:
with translate_errors_context(), TimingContext("es", "task_log_events"):
with translate_errors_context():
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
size = min(batch_size, 10000)
@@ -449,7 +501,7 @@ class EventBLL(object):
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
}
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
@@ -479,6 +531,11 @@ class EventBLL(object):
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
max_variants = int(max_variants // num_last_iterations)
es_req: dict = {
"size": 0,
@@ -486,14 +543,14 @@ class EventBLL(object):
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
@@ -512,12 +569,8 @@ class EventBLL(object):
"query": query,
}
with translate_errors_context(), TimingContext(
"es", "task_last_iter_metric_variant"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
if "aggregations" not in es_res:
return []
@@ -539,12 +592,13 @@ class EventBLL(object):
scroll_id: str = None,
no_scroll: bool = False,
metric_variants: MetricVariants = None,
model_events: bool = False,
):
if scroll_id == self.empty_scroll:
return TaskEventsResult()
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
with translate_errors_context():
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
event_type = EventType.metrics_plot
@@ -565,7 +619,7 @@ class EventBLL(object):
}
must = [plot_valid_condition]
if last_iterations_per_plot is None:
if last_iterations_per_plot is None or model_events:
must.append({"terms": {"task": tasks}})
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
@@ -608,7 +662,7 @@ class EventBLL(object):
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext("es", "get_task_plots"):
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
@@ -633,11 +687,47 @@ class EventBLL(object):
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
next_scroll_id = es_res.get("_scroll_id")
if next_scroll_id and not events:
self.es.clear_scroll(scroll_id=next_scroll_id)
self.clear_scroll(next_scroll_id)
next_scroll_id = self.empty_scroll
return events, total_events, next_scroll_id
def get_debug_image_urls(
self, company_id: str, task_id: str, after_key: dict = None
) -> Tuple[Sequence[str], Optional[dict]]:
if check_empty_data(self.es, company_id, EventType.metrics_image):
return [], None
es_req = {
"size": 0,
"aggs": {
"debug_images": {
"composite": {
"size": 1000,
**({"after": after_key} if after_key else {}),
"sources": [{"url": {"terms": {"field": "url"}}}],
}
}
},
"query": {
"bool": {
"must": [{"term": {"task": task_id}}, {"exists": {"field": "url"}}]
}
},
}
es_response = search_company_events(
self.es,
company_id=company_id,
event_type=EventType.metrics_image,
body=es_req,
)
res = nested_get(es_response, ("aggregations", "debug_images"))
if not res:
return [], None
return [bucket["key"]["url"] for bucket in res["buckets"]], res.get("after_key")
def get_plot_image_urls(
self, company_id: str, task_id: str, scroll_id: Optional[str]
) -> Tuple[Sequence[dict], Optional[str]]:
@@ -685,12 +775,13 @@ class EventBLL(object):
size=500,
scroll_id=None,
no_scroll=False,
model_events=False,
) -> TaskEventsResult:
if scroll_id == self.empty_scroll:
return TaskEventsResult()
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
with translate_errors_context():
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
@@ -704,7 +795,7 @@ class EventBLL(object):
if variant:
must.append({"term": {"variant": variant}})
if last_iter_count is None:
if last_iter_count is None or model_events:
must.append({"terms": {"task": task_ids}})
else:
tasks_iters = self.get_last_iters(
@@ -738,7 +829,7 @@ class EventBLL(object):
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext("es", "get_task_events"):
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
@@ -763,20 +854,24 @@ class EventBLL(object):
return {}
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
es_req = {
"size": 0,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"size": max_variants,
"order": {"_key": "asc"},
}
}
@@ -786,12 +881,8 @@ class EventBLL(object):
"query": query,
}
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
metrics = {}
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
@@ -817,6 +908,10 @@ class EventBLL(object):
]
}
}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
es_req = {
"size": 0,
"query": query,
@@ -824,14 +919,14 @@ class EventBLL(object):
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
@@ -859,12 +954,8 @@ class EventBLL(object):
},
"_source": {"excludes": []},
}
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
metrics = []
max_timestamp = 0
@@ -909,7 +1000,7 @@ class EventBLL(object):
"_source": ["iter", "value"],
"sort": ["iter"],
}
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
with translate_errors_context():
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
@@ -952,7 +1043,7 @@ class EventBLL(object):
"query": {"bool": {"must": [{"terms": {"task": task_ids}}]}},
}
with translate_errors_context(), TimingContext("es", "task_last_iter"):
with translate_errors_context():
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
@@ -965,49 +1056,133 @@ class EventBLL(object):
for tb in es_res["aggregations"]["tasks"]["buckets"]
}
def delete_task_events(self, company_id, task_id, allow_locked=False):
with translate_errors_context():
extra_msg = None
query = Q(id=task_id, company=company_id)
if not allow_locked:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
extra_msg = "or task published"
res = Task.objects(query).only("id").first()
if not res:
raise errors.bad_request.InvalidTaskId(
extra_msg, company=company_id, id=task_id
)
@staticmethod
def _validate_model_state(
company_id: str, model_id: str, allow_locked: bool = False
):
extra_msg = None
query = Q(id=model_id, company=company_id)
if not allow_locked:
query &= Q(ready__ne=True)
extra_msg = "or model published"
res = Model.objects(query).only("id").first()
if not res:
raise errors.bad_request.InvalidModelId(
extra_msg, company=company_id, id=model_id
)
@staticmethod
def _validate_task_state(company_id: str, task_id: str, allow_locked: bool = False):
extra_msg = None
query = Q(id=task_id, company=company_id)
if not allow_locked:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
extra_msg = "or task published"
res = Task.objects(query).only("id").first()
if not res:
raise errors.bad_request.InvalidTaskId(
extra_msg, company=company_id, id=task_id
)
@staticmethod
def _get_events_deletion_params(async_delete: bool) -> dict:
if async_delete:
return {
"wait_for_completion": False,
"requests_per_second": config.get(
"services.events.max_async_deleted_events_per_sec", 1000
),
}
return {"refresh": True}
def delete_task_events(
self, company_id, task_id, allow_locked=False, model=False, async_delete=False,
):
if model:
self._validate_model_state(
company_id=company_id, model_id=task_id, allow_locked=allow_locked,
)
else:
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
)
es_req = {"query": {"term": {"task": task_id}}}
with translate_errors_context(), TimingContext("es", "delete_task_events"):
with translate_errors_context():
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.all,
body=es_req,
refresh=True,
**self._get_events_deletion_params(async_delete),
)
return es_res.get("deleted", 0)
if not async_delete:
return es_res.get("deleted", 0)
def delete_multi_task_events(self, company_id: str, task_ids: Sequence[str]):
def clear_task_log(
self,
company_id: str,
task_id: str,
allow_locked: bool = False,
threshold_sec: int = None,
):
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
)
if check_empty_data(
self.es, company_id=company_id, event_type=EventType.task_log
):
return 0
with translate_errors_context():
must = [{"term": {"task": task_id}}]
sort = None
if threshold_sec:
timestamp_ms = int(threshold_sec * 1000)
must.append(
{
"range": {
"timestamp": {
"lt": (es_factory.get_timestamp_millis() - timestamp_ms)
}
}
}
)
sort = {"timestamp": {"order": "desc"}}
es_req = {
"query": {"bool": {"must": must}},
**({"sort": sort} if sort else {}),
}
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.task_log,
body=es_req,
refresh=True,
)
return es_res.get("deleted", 0)
def delete_multi_task_events(
self, company_id: str, task_ids: Sequence[str], async_delete=False
):
"""
Delete mutliple task events. No check is done for tasks write access
so it should be checked by the calling code
"""
es_req = {"query": {"terms": {"task": task_ids}}}
with translate_errors_context(), TimingContext(
"es", "delete_multi_tasks_events"
):
with translate_errors_context():
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.all,
body=es_req,
refresh=True,
**self._get_events_deletion_params(async_delete),
)
return es_res.get("deleted", 0)
if not async_delete:
return es_res.get("deleted", 0)
def clear_scroll(self, scroll_id: str):
if scroll_id == self.empty_scroll:

View File

@@ -1,10 +1,14 @@
import base64
import zlib
from enum import Enum
from typing import Union, Sequence, Mapping
from typing import Union, Sequence, Mapping, Tuple
from boltons.typeutils import classproperty
from elasticsearch import Elasticsearch
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.tools import safe_get
class EventType(Enum):
@@ -16,10 +20,13 @@ class EventType(Enum):
all = "*"
SINGLE_SCALAR_ITERATION = -(2 ** 31)
MetricVariants = Mapping[str, Sequence[str]]
class EventSettings:
_max_es_allowed_aggregation_buckets = 10000
@classproperty
def max_workers(self):
return config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
@@ -31,17 +38,23 @@ class EventSettings:
)
@classproperty
def max_metrics_count(self):
return config.get("services.events.events_retrieval.max_metrics_count", 100)
@classproperty
def max_variants_count(self):
return config.get("services.events.events_retrieval.max_variants_count", 100)
def max_es_buckets(self):
percentage = (
min(
100,
config.get(
"services.events.events_retrieval.dynamic_metrics_count_threshold",
80,
),
)
/ 100
)
return int(self._max_es_allowed_aggregation_buckets * percentage)
def get_index_name(company_id: str, event_type: str):
event_type = event_type.lower().replace(" ", "_")
return f"events-{event_type}-{company_id}"
return f"events-{event_type}-{company_id.lower()}"
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
@@ -66,9 +79,7 @@ def delete_company_events(
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
) -> dict:
es_index = get_index_name(company_id, event_type.value)
return es.delete_by_query(
index=es_index, body=body, conflicts="proceed", **kwargs
)
return es.delete_by_query(index=es_index, body=body, conflicts="proceed", **kwargs)
def count_company_events(
@@ -78,6 +89,44 @@ def count_company_events(
return es.count(index=es_index, body=body, **kwargs)
def get_max_metric_and_variant_counts(
es: Elasticsearch,
company_id: Union[str, Sequence[str]],
event_type: EventType,
query: dict,
**kwargs,
) -> Tuple[int, int]:
dynamic = config.get(
"services.events.events_retrieval.dynamic_metrics_count", False
)
max_metrics_count = config.get(
"services.events.events_retrieval.max_metrics_count", 100
)
max_variants_count = config.get(
"services.events.events_retrieval.max_variants_count", 100
)
if not dynamic:
return max_metrics_count, max_variants_count
es_req: dict = {
"size": 0,
"query": query,
"aggs": {"metrics_count": {"cardinality": {"field": "metric"}}},
}
with translate_errors_context():
es_res = search_company_events(
es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
)
metrics_count = safe_get(
es_res, "aggregations/metrics_count/value", max_metrics_count
)
if not metrics_count:
return max_metrics_count, max_variants_count
return metrics_count, int(EventSettings.max_es_buckets / metrics_count)
def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
conditions = [
{
@@ -94,3 +143,19 @@ def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
]
return {"bool": {"should": conditions}}
class PlotFields:
valid_plot = "valid_plot"
plot_len = "plot_len"
plot_str = "plot_str"
plot_data = "plot_data"
source_urls = "source_urls"
def uncompress_plot(event: dict):
plot_data = event.pop(PlotFields.plot_data, None)
if plot_data and event.get(PlotFields.plot_str) is None:
event[PlotFields.plot_str] = zlib.decompress(
base64.b64decode(plot_data)
).decode()

View File

@@ -4,12 +4,11 @@ from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple
from typing import Sequence, Tuple, Mapping
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.bll.event.event_common import (
EventType,
EventSettings,
@@ -17,12 +16,13 @@ from apiserver.bll.event.event_common import (
check_empty_data,
MetricVariants,
get_metric_variants_condition,
get_max_metric_and_variant_counts,
SINGLE_SCALAR_ITERATION,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
log = config.logger(__file__)
@@ -109,40 +109,19 @@ class EventMetrics:
def compare_scalar_metrics_average_per_iter(
self,
company_id,
task_ids: Sequence[str],
tasks: Sequence[Task],
samples,
key: ScalarKeyEnum,
allow_public=True,
):
"""
Compare scalar metrics for different tasks per metric and variant
The amount of points in each histogram should not exceed the requested samples
"""
task_name_by_id = {}
with translate_errors_context():
task_objs = Task.get_many(
company=company_id,
query=Q(id__in=task_ids),
allow_public=allow_public,
override_projection=("id", "name", "company", "company_origin"),
return_dicts=False,
)
if len(task_objs) < len(task_ids):
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
task_name_by_id = {t.id: t.name for t in task_objs}
companies = {t.get_index_company() for t in task_objs}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
event_type = EventType.metrics_scalar
company_id = next(iter(companies))
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {}
task_name_by_id = {t.id: t.name for t in tasks}
get_scalar_average_per_iter = partial(
self._get_scalar_average_per_iter_core,
company_id=company_id,
@@ -151,6 +130,7 @@ class EventMetrics:
key=ScalarKey.resolve(key),
run_parallel=False,
)
task_ids = [t.id for t in tasks]
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
task_metrics = zip(
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
@@ -166,6 +146,57 @@ class EventMetrics:
return res
def get_task_single_value_metrics(
self, company_id: str, tasks: Sequence[Task]
) -> Mapping[str, dict]:
"""
For the requested tasks return all the events delivered for the single iteration (-2**31)
"""
if check_empty_data(
self.es, company_id=company_id, event_type=EventType.metrics_scalar
):
return {}
task_ids = [t.id for t in tasks]
task_events = self._get_task_single_value_metrics(company_id, task_ids)
def _get_value(event: dict):
return {
field: event.get(field)
for field in ("metric", "variant", "value", "timestamp")
}
return {
task: [_get_value(e) for e in events]
for task, events in bucketize(task_events, itemgetter("task")).items()
}
def _get_task_single_value_metrics(
self, company_id: str, task_ids: Sequence[str]
) -> Sequence[dict]:
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"terms": {"task": task_ids}},
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
]
}
},
}
with translate_errors_context():
es_res = search_company_events(
body=es_req,
es=self.es,
company_id=company_id,
event_type=EventType.metrics_scalar,
)
if not es_res["hits"]["total"]["value"]:
return []
return [hit["_source"] for hit in es_res["hits"]["hits"]]
MetricInterval = Tuple[str, str, int, int]
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
@@ -219,11 +250,15 @@ class EventMetrics:
Return the list og metric variant intervals as the following tuple:
(metric, variant, interval, samples)
"""
must = [{"term": {"task": task_id}}]
must = self._task_conditions(task_id)
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
max_variants = int(max_variants // 2)
es_req = {
"size": 0,
"query": query,
@@ -231,14 +266,14 @@ class EventMetrics:
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
@@ -252,10 +287,7 @@ class EventMetrics:
},
}
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
es_res = search_company_events(body=es_req, **search_args)
aggs_result = es_res.get("aggregations")
if not aggs_result:
@@ -307,33 +339,40 @@ class EventMetrics:
"""
interval, metrics = metrics_interval
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
aggs = {
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": aggregation,
}
},
}
}
aggs_result = self._query_aggregation_for_task_metrics(
company_id=company_id,
event_type=event_type,
aggs=aggs,
task_id=task_id,
metrics=metrics,
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
max_variants = int(max_variants // 2)
es_req = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": aggregation,
}
},
}
},
}
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
aggs_result = es_res.get("aggregations")
if not aggs_result:
return {}
@@ -360,19 +399,18 @@ class EventMetrics:
for key, value in aggregation.items()
}
def _query_aggregation_for_task_metrics(
self,
company_id: str,
event_type: EventType,
aggs: dict,
task_id: str,
metrics: Sequence[Tuple[str, str]],
) -> dict:
"""
Return the result of elastic search query for the given aggregation filtered
by the given task_ids and metrics
"""
must = [{"term": {"task": task_id}}]
@staticmethod
def _task_conditions(task_id: str) -> list:
return [
{"term": {"task": task_id}},
{"range": {"iter": {"gt": SINGLE_SCALAR_ITERATION}}},
]
@classmethod
def _get_task_metrics_query(
cls, task_id: str, metrics: Sequence[Tuple[str, str]],
):
must = cls._task_conditions(task_id)
if metrics:
should = [
{
@@ -387,20 +425,9 @@ class EventMetrics:
]
must.append({"bool": {"should": should}})
es_req = {
"size": 0,
"query": {"bool": {"must": must}},
"aggs": aggs,
}
return {"bool": {"must": must}}
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
return es_res.get("aggregations")
def get_tasks_metrics(
def get_task_metrics(
self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence:
"""
@@ -426,22 +453,21 @@ class EventMetrics:
) -> Sequence:
es_req = {
"size": 0,
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": {"bool": {"must": self._task_conditions(task_id)}},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": EventSettings.max_es_buckets,
"order": {"_key": "asc"},
}
}
},
}
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
return [
metric["key"]

View File

@@ -4,6 +4,7 @@ import attr
import jsonmodels.models
import jwt
from elasticsearch import Elasticsearch
from jwt.algorithms import get_default_algorithms
from apiserver.bll.event.event_common import (
check_empty_data,
@@ -16,7 +17,6 @@ from apiserver.bll.event.event_common import (
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@attr.s(auto_attribs=True)
@@ -75,13 +75,9 @@ class EventsIterator:
"query": query,
}
with translate_errors_context(), TimingContext("es", "count_task_events"):
with translate_errors_context():
es_result = count_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
return es_result["count"]
@@ -116,13 +112,9 @@ class EventsIterator:
if from_key_value:
es_req["search_after"] = [from_key_value]
with translate_errors_context(), TimingContext("es", "get_task_events"):
with translate_errors_context():
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
@@ -142,11 +134,7 @@ class EventsIterator:
},
}
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:
@@ -191,7 +179,7 @@ class Scroll(jsonmodels.models.Base):
key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890"
),
).decode()
)
@classmethod
def from_scroll_id(cls, scroll_id: str):
@@ -202,6 +190,7 @@ class Scroll(jsonmodels.models.Base):
key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890"
),
algorithms=get_default_algorithms(),
)
)
except jwt.PyJWTError:

View File

@@ -0,0 +1,455 @@
import operator
from operator import attrgetter
from typing import Sequence, Tuple, Optional, Mapping
import attr
from boltons.iterutils import first, bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, IntField, BoolField, ListField
from jsonmodels.models import Base
from redis.client import StrictRedis
from apiserver.utilities.dicts import nested_get
from .event_common import (
EventType,
EventSettings,
check_empty_data,
search_company_events,
get_max_metric_and_variant_counts,
)
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.apierrors import errors
class VariantState(Base):
name: str = StringField(required=True)
metric: str = StringField(default=None)
min_iteration: int = IntField()
max_iteration: int = IntField()
class DebugImageSampleState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
iteration: int = IntField()
variant: str = StringField()
task: str = StringField()
metric: str = StringField()
variant_states: Sequence[VariantState] = ListField([VariantState])
warning: str = StringField()
navigate_current_metric = BoolField(default=True)
@attr.s(auto_attribs=True)
class VariantSampleResult(object):
scroll_id: str = None
event: dict = None
min_iteration: int = None
max_iteration: int = None
class HistoryDebugImageIterator:
event_type = EventType.metrics_image
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugImageSampleState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def get_next_sample(
self,
company_id: str,
task: str,
state_id: str,
navigate_earlier: bool,
next_iteration: bool,
) -> VariantSampleResult:
"""
Get the sample for next/prev variant on the current iteration
If does not exist then try getting sample for the first/last variant from next/prev iteration
"""
res = VariantSampleResult(scroll_id=state_id)
state = self.cache_manager.get_state(state_id)
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
if next_iteration:
event = self._get_next_for_another_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
)
else:
# noinspection PyArgumentList
event = first(
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
for f in (
self._get_next_for_current_iteration,
self._get_next_for_another_iteration,
)
)
if not event:
return res
self._fill_res_and_update_state(event=event, res=res, state=state)
self.cache_manager.set_state(state=state)
return res
@staticmethod
def _fill_res_and_update_state(
event: dict, res: VariantSampleResult, state: DebugImageSampleState
):
state.variant = event["variant"]
state.metric = event["metric"]
state.iteration = event["iter"]
res.event = event
var_state = first(
vs
for vs in state.variant_states
if vs.name == state.variant and vs.metric == state.metric
)
if var_state:
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
@staticmethod
def _get_metric_conditions(variants: Sequence[VariantState]) -> dict:
metrics = bucketize(variants, key=attrgetter("metric"))
def _get_variants_conditions(metric_variants: Sequence[VariantState]) -> dict:
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"gte": v.min_iteration}}},
]
}
}
for v in metric_variants
]
return {"bool": {"should": variants_conditions}}
metrics_conditions = [
{
"bool": {
"must": [
{"term": {"metric": metric}},
_get_variants_conditions(metric_variants),
]
}
}
for metric, metric_variants in metrics.items()
]
return {"bool": {"should": metrics_conditions}}
def _get_next_for_current_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
) -> Optional[dict]:
"""
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
Only variants for which the iteration falls into their valid range are considered
Return None if no such variant or sample is found
"""
if state.navigate_current_metric:
variants = [
var_state
for var_state in state.variant_states
if var_state.metric == state.metric
]
else:
variants = state.variant_states
cmp = operator.lt if navigate_earlier else operator.gt
variants = [
var_state
for var_state in variants
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
and var_state.min_iteration <= state.iteration
]
if not variants:
return
must_conditions = [
{"term": {"task": state.task}},
{"term": {"iter": state.iteration}},
self._get_metric_conditions(variants),
{"exists": {"field": "url"}},
]
order = "desc" if navigate_earlier else "asc"
es_req = {
"size": 1,
"sort": [{"metric": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def _get_next_for_another_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
) -> Optional[dict]:
"""
Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
or from the last variant for the previous iteration (otherwise)
The variants for which the sample falls in invalid range are discarded
If no suitable sample is found then None is returned
"""
if state.navigate_current_metric:
variants = [
var_state
for var_state in state.variant_states
if var_state.metric == state.metric
]
else:
variants = state.variant_states
if navigate_earlier:
range_operator = "lt"
order = "desc"
variants = [
var_state
for var_state in variants
if var_state.min_iteration < state.iteration
]
else:
range_operator = "gt"
order = "asc"
variants = variants
if not variants:
return
must_conditions = [
{"term": {"task": state.task}},
self._get_metric_conditions(variants),
{"range": {"iter": {range_operator: state.iteration}}},
{"exists": {"field": "url"}},
]
es_req = {
"size": 1,
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def get_sample_for_variant(
self,
company_id: str,
task: str,
metric: str,
variant: str,
iteration: Optional[int] = None,
refresh: bool = False,
state_id: str = None,
navigate_current_metric: bool = True,
) -> VariantSampleResult:
"""
Get the sample for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
res = VariantSampleResult()
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
def init_state(state_: DebugImageSampleState):
state_.task = task
state_.metric = metric
state_.navigate_current_metric = navigate_current_metric
self._reset_variant_states(company_id=company_id, state=state_)
def validate_state(state_: DebugImageSampleState):
if (
state_.task != task
or state_.navigate_current_metric != navigate_current_metric
or (state_.navigate_current_metric and state_.metric != metric)
):
raise errors.bad_request.InvalidScrollId(
"Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id,
)
# fix old variant states:
for vs in state_.variant_states:
if vs.metric is None:
vs.metric = metric
if refresh:
self._reset_variant_states(company_id=company_id, state=state_)
state: DebugImageSampleState
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res.scroll_id = state.id
var_state = first(
vs
for vs in state.variant_states
if vs.name == variant and vs.metric == metric
)
if not var_state:
return res
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
must_conditions = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"exists": {"field": "url"}},
]
if iteration is not None:
must_conditions.append(
{
"range": {
"iter": {"lte": iteration, "gte": var_state.min_iteration}
}
}
)
else:
must_conditions.append(
{"range": {"iter": {"gte": var_state.min_iteration}}}
)
es_req = {
"size": 1,
"sort": {"iter": "desc"},
"query": {"bool": {"must": must_conditions}},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return res
self._fill_res_and_update_state(
event=hits[0]["_source"], res=res, state=state
)
return res
def _reset_variant_states(self, company_id: str, state: DebugImageSampleState):
metrics = self._get_metric_variant_iterations(
company_id=company_id,
task=state.task,
metric=state.metric if state.navigate_current_metric else None,
)
state.variant_states = [
VariantState(
metric=metric,
name=var_name,
min_iteration=min_iter,
max_iteration=max_iter,
)
for metric, variants in metrics.items()
for var_name, min_iter, max_iter in variants
]
def _get_metric_variant_iterations(
self, company_id: str, task: str, metric: str,
) -> Mapping[str, Sequence[Tuple[str, int, int]]]:
"""
Return valid min and max iterations that the task reported events of the required type
"""
must = [
{"term": {"task": task}},
{"exists": {"field": "url"}},
]
if metric is not None:
must.append({"term": {"metric": metric}})
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=self.event_type,
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args
)
max_variants = int(max_variants // 2)
es_req: dict = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
"last_iter": {"max": {"field": "iter"}},
"urls": {
# group by urls and choose the minimal iteration
# from all the maximal iterations per url
"terms": {
"field": "url",
"order": {"max_iter": "asc"},
"size": 1,
},
"aggs": {
# find max iteration for each url
"max_iter": {"max": {"field": "iter"}}
},
},
},
}
},
}
},
}
es_res = search_company_events(body=es_req, **search_args)
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
variant = variant_bucket["key"]
urls = nested_get(variant_bucket, ("urls", "buckets"))
min_iter = int(urls[0]["max_iter"]["value"])
max_iter = int(variant_bucket["last_iter"]["value"])
return variant, min_iter, max_iter
return {
metric_bucket["key"]: [
get_variant_data(variant_bucket)
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
]
for metric_bucket in nested_get(
es_res, ("aggregations", "metrics", "buckets")
)
}

View File

@@ -0,0 +1,316 @@
from typing import Sequence, Tuple, Optional, Mapping
import attr
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, IntField, ListField, BoolField
from jsonmodels.models import Base
from redis.client import StrictRedis
from .event_common import (
EventType,
uncompress_plot,
EventSettings,
check_empty_data,
search_company_events,
)
from apiserver.apimodels import JsonSerializableMixin
from apiserver.utilities.dicts import nested_get
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.apierrors import errors
class MetricState(Base):
name: str = StringField(default=None)
min_iteration: int = IntField()
max_iteration: int = IntField()
class PlotsSampleState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
iteration: int = IntField()
task: str = StringField()
metric: str = StringField()
metric_states: Sequence[MetricState] = ListField([MetricState])
warning: str = StringField()
navigate_current_metric = BoolField(default=True)
@attr.s(auto_attribs=True)
class MetricSamplesResult(object):
scroll_id: str = None
events: list = []
min_iteration: int = None
max_iteration: int = None
class HistoryPlotsIterator:
event_type = EventType.metrics_plot
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=PlotsSampleState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def get_next_sample(
self,
company_id: str,
task: str,
state_id: str,
navigate_earlier: bool,
next_iteration: bool,
) -> MetricSamplesResult:
"""
Get the samples for next/prev metric on the current iteration
If does not exist then try getting sample for the first/last metric from next/prev iteration
"""
res = MetricSamplesResult(scroll_id=state_id)
state = self.cache_manager.get_state(state_id)
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
if navigate_earlier:
range_operator = "lt"
order = "desc"
else:
range_operator = "gt"
order = "asc"
must_conditions = [
{"term": {"task": state.task}},
]
if state.navigate_current_metric:
must_conditions.append({"term": {"metric": state.metric}})
next_iteration_condition = {
"range": {"iter": {range_operator: state.iteration}}
}
if next_iteration or state.navigate_current_metric:
must_conditions.append(next_iteration_condition)
else:
next_metric_condition = {
"bool": {
"must": [
{"term": {"iter": state.iteration}},
{"range": {"metric": {range_operator: state.metric}}},
]
}
}
must_conditions.append(
{"bool": {"should": [next_metric_condition, next_iteration_condition]}}
)
events = self._get_metric_events_for_condition(
company_id=company_id,
task=state.task,
order=order,
must_conditions=must_conditions,
)
if not events:
return res
self._fill_res_and_update_state(events=events, res=res, state=state)
self.cache_manager.set_state(state=state)
return res
def get_samples_for_metric(
self,
company_id: str,
task: str,
metric: str,
iteration: Optional[int] = None,
refresh: bool = False,
state_id: str = None,
navigate_current_metric: bool = True,
) -> MetricSamplesResult:
"""
Get the sample for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
res = MetricSamplesResult()
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
def init_state(state_: PlotsSampleState):
state_.task = task
state_.metric = metric
state_.navigate_current_metric = navigate_current_metric
self._reset_metric_states(company_id=company_id, state=state_)
def validate_state(state_: PlotsSampleState):
if (
state_.task != task
or state_.navigate_current_metric != navigate_current_metric
or (state_.navigate_current_metric and state_.metric != metric)
):
raise errors.bad_request.InvalidScrollId(
"Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reset_metric_states(company_id=company_id, state=state_)
state: PlotsSampleState
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res.scroll_id = state.id
metric_state = first(ms for ms in state.metric_states if ms.name == metric)
if not metric_state:
return res
res.min_iteration = metric_state.min_iteration
res.max_iteration = metric_state.max_iteration
must_conditions = [
{"term": {"task": task}},
{"term": {"metric": metric}},
]
if iteration is not None:
must_conditions.append({"range": {"iter": {"lte": iteration}}})
events = self._get_metric_events_for_condition(
company_id=company_id,
task=state.task,
order="desc",
must_conditions=must_conditions,
)
if not events:
return res
self._fill_res_and_update_state(events=events, res=res, state=state)
return res
def _reset_metric_states(self, company_id: str, state: PlotsSampleState):
metrics = self._get_metric_iterations(
company_id=company_id,
task=state.task,
metric=state.metric if state.navigate_current_metric else None,
)
state.metric_states = [
MetricState(name=metric, min_iteration=min_iter, max_iteration=max_iter)
for metric, (min_iter, max_iter) in metrics.items()
]
def _get_metric_iterations(
self, company_id: str, task: str, metric: str,
) -> Mapping[str, Tuple[int, int]]:
"""
Return valid min and max iterations that the task reported events of the required type
"""
must = [
{"term": {"task": task}},
]
if metric is not None:
must.append({"term": {"metric": metric}})
query = {"bool": {"must": must}}
es_req: dict = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": 5000,
"order": {"_key": "asc"},
},
"aggs": {
"last_iter": {"max": {"field": "iter"}},
"first_iter": {"min": {"field": "iter"}},
},
}
},
}
es_res = search_company_events(
body=es_req,
es=self.es,
company_id=company_id,
event_type=self.event_type,
)
return {
metric_bucket["key"]: (
int(metric_bucket["first_iter"]["value"]),
int(metric_bucket["last_iter"]["value"]),
)
for metric_bucket in nested_get(
es_res, ("aggregations", "metrics", "buckets")
)
}
@staticmethod
def _fill_res_and_update_state(
events: Sequence[dict], res: MetricSamplesResult, state: PlotsSampleState
):
for event in events:
uncompress_plot(event)
state.metric = events[0]["metric"]
state.iteration = events[0]["iter"]
res.events = events
metric_state = first(
ms for ms in state.metric_states if ms.name == state.metric
)
if metric_state:
res.min_iteration = metric_state.min_iteration
res.max_iteration = metric_state.max_iteration
def _get_metric_events_for_condition(
self, company_id: str, task: str, order: str, must_conditions: Sequence
) -> Sequence:
es_req = {
"size": 0,
"query": {"bool": {"must": must_conditions}},
"aggs": {
"iters": {
"terms": {"field": "iter", "size": 1, "order": {"_key": order}},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": 1,
"order": {"_key": order},
},
"aggs": {
"events": {
"top_hits": {
"sort": {"variant": {"order": "asc"}},
"size": 100,
}
}
},
},
},
}
},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
aggs_result = es_res.get("aggregations")
if not aggs_result:
return []
for level in ("iters", "metrics"):
level_data = aggs_result[level]["buckets"]
if not level_data:
return []
aggs_result = level_data[0]
return [
hit["_source"]
for hit in nested_get(aggs_result, ("events", "hits", "hits"))
]

View File

@@ -0,0 +1,53 @@
from typing import Sequence, Tuple, Callable
from elasticsearch import Elasticsearch
from redis.client import StrictRedis
from apiserver.utilities.dicts import nested_get
from .event_common import EventType
from .metric_events_iterator import MetricEventsIterator, VariantState
class MetricDebugImagesIterator(MetricEventsIterator):
def __init__(self, redis: StrictRedis, es: Elasticsearch):
super().__init__(redis, es, EventType.metrics_image)
def _get_extra_conditions(self) -> Sequence[dict]:
return [{"exists": {"field": "url"}}]
def _get_variant_state_aggs(self) -> Tuple[dict, Callable[[dict, VariantState], None]]:
aggs = {
"urls": {
"terms": {
"field": "url",
"order": {"max_iter": "desc"},
"size": 1, # we need only one url from the most recent iteration
},
"aggs": {
"max_iter": {"max": {"field": "iter"}},
"iters": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": 2, # need two last iterations so that we can take
# the second one as invalid
"_source": "iter",
}
},
},
}
}
def fill_variant_state_data(variant_bucket: dict, state: VariantState):
"""If the image urls get recycled then fill the last_invalid_iteration field"""
top_iter_url = nested_get(variant_bucket, ("urls", "buckets"))[0]
iters = nested_get(top_iter_url, ("iters", "hits", "hits"))
if len(iters) > 1:
state.last_invalid_iteration = nested_get(iters[1], ("_source", "iter"))
return aggs, fill_variant_state_data
def _process_event(self, event: dict) -> dict:
return event
def _get_same_variant_events_order(self) -> dict:
return {"url": {"order": "desc"}}

View File

@@ -1,8 +1,9 @@
import abc
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping
from typing import Sequence, Tuple, Optional, Mapping, Callable
import attr
import dpath
@@ -19,12 +20,13 @@ from apiserver.bll.event.event_common import (
search_company_events,
EventType,
get_metric_variants_condition,
get_max_metric_and_variant_counts,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
class VariantState(Base):
@@ -49,25 +51,24 @@ class TaskScrollState(Base):
self.last_min_iter = self.last_max_iter = None
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
class MetricEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
warning: str = StringField()
@attr.s(auto_attribs=True)
class DebugImagesResult(object):
class MetricEventsResult(object):
metric_events: Sequence[tuple] = []
next_scroll_id: str = None
class DebugImagesIterator:
EVENT_TYPE = EventType.metrics_image
def __init__(self, redis: StrictRedis, es: Elasticsearch):
class MetricEventsIterator:
def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
self.es = es
self.event_type = event_type
self.cache_manager = RedisCacheManager(
state_class=DebugImageEventsScrollState,
state_class=MetricEventsScrollState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
@@ -80,14 +81,14 @@ class DebugImagesIterator:
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
) -> DebugImagesResult:
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
return DebugImagesResult()
) -> MetricEventsResult:
if check_empty_data(self.es, company_id, self.event_type):
return MetricEventsResult()
def init_state(state_: DebugImageEventsScrollState):
def init_state(state_: MetricEventsScrollState):
state_.tasks = self._init_task_states(company_id, task_metrics)
def validate_state(state_: DebugImageEventsScrollState):
def validate_state(state_: MetricEventsScrollState):
"""
Validate that the metrics stored in the state are the same
as requested in the current call.
@@ -99,7 +100,13 @@ class DebugImagesIterator:
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
) as state:
res = DebugImagesResult(next_scroll_id=state.id)
res = MetricEventsResult(next_scroll_id=state.id)
specific_variants_requested = any(
variants
for t, metrics in task_metrics.items()
if metrics
for m, variants in metrics.items()
)
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
res.metric_events = list(
pool.map(
@@ -108,6 +115,7 @@ class DebugImagesIterator:
company_id=company_id,
iter_count=iter_count,
navigate_earlier=navigate_earlier,
specific_variants_requested=specific_variants_requested,
),
state.tasks,
)
@@ -118,11 +126,11 @@ class DebugImagesIterator:
def _reinit_outdated_task_states(
self,
company_id,
state: DebugImageEventsScrollState,
state: MetricEventsScrollState,
task_metrics: Mapping[str, dict],
):
"""
Determine the metrics for which new debug image events were added
Determine the metrics for which new event_type events were added
since their states were initialized and re-init these states
"""
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
@@ -132,7 +140,7 @@ class DebugImagesIterator:
def get_last_update_times_for_task_metrics(
task: Task,
) -> Mapping[str, datetime]:
"""For metrics that reported debug image events get mapping of the metric name to the last update times"""
"""For metrics that reported event_type events get mapping of the metric name to the last update times"""
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
if not metric_stats:
return {}
@@ -140,10 +148,10 @@ class DebugImagesIterator:
requested_metrics = task_metrics[task.id]
return {
stats.metric: stats.event_stats_by_type[
self.EVENT_TYPE.value
self.event_type.value
].last_update
for stats in metric_stats.values()
if self.EVENT_TYPE.value in stats.event_stats_by_type
if self.event_type.value in stats.event_stats_by_type
and (not requested_metrics or stats.metric in requested_metrics)
}
@@ -213,18 +221,37 @@ class DebugImagesIterator:
for task, metric_states in zip(task_metrics, task_metric_states)
]
@abc.abstractmethod
def _get_extra_conditions(self) -> Sequence[dict]:
pass
@abc.abstractmethod
def _get_variant_state_aggs(
self,
) -> Tuple[dict, Callable[[dict, VariantState], None]]:
pass
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, dict], company_id: str
) -> Sequence[MetricState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any debug images
for the variants that reported any event_type events
"""
task, metrics = task_metrics
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
must = [{"term": {"task": task}}, *self._get_extra_conditions()]
if metrics:
must.append(get_metric_variants_condition(metrics))
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=self.event_type
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args
)
max_variants = int(max_variants // 2)
variant_state_aggs, fill_variant_state_data = self._get_variant_state_aggs()
es_req: dict = {
"size": 0,
"query": query,
@@ -232,7 +259,7 @@ class DebugImagesIterator:
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
@@ -240,52 +267,33 @@ class DebugImagesIterator:
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
"urls": {
"terms": {
"field": "url",
"order": {"max_iter": "desc"},
"size": 1, # we need only one url from the most recent iteration
},
"aggs": {
"max_iter": {"max": {"field": "iter"}},
"iters": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": 2, # need two last iterations so that we can take
# the second one as invalid
"_source": "iter",
}
},
},
}
},
**(
{"aggs": variant_state_aggs}
if variant_state_aggs
else {}
),
},
},
}
},
}
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
)
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
if "aggregations" not in es_res:
return []
def init_variant_state(variant: dict):
"""
Return new variant state for the passed variant bucket
If the image urls get recycled then fill the last_invalid_iteration field
"""
state = VariantState(variant=variant["key"])
top_iter_url = dpath.get(variant, "urls/buckets")[0]
iters = dpath.get(top_iter_url, "iters/hits/hits")
if len(iters) > 1:
state.last_invalid_iteration = dpath.get(iters[1], "_source/iter")
if fill_variant_state_data:
fill_variant_state_data(variant, state)
return state
return [
@@ -300,12 +308,21 @@ class DebugImagesIterator:
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
]
@abc.abstractmethod
def _process_event(self, event: dict) -> dict:
pass
@abc.abstractmethod
def _get_same_variant_events_order(self) -> dict:
pass
def _get_task_metric_events(
self,
task_state: TaskScrollState,
company_id: str,
iter_count: int,
navigate_earlier: bool,
specific_variants_requested: bool,
) -> Tuple:
"""
Return task metric events grouped by iterations
@@ -321,7 +338,7 @@ class DebugImagesIterator:
must_conditions = [
{"term": {"task": task_state.task}},
{"terms": {"metric": [m.metric for m in task_state.metrics]}},
{"exists": {"field": "url"}},
*self._get_extra_conditions(),
]
range_condition = None
@@ -332,6 +349,8 @@ class DebugImagesIterator:
if range_condition:
must_conditions.append({"range": {"iter": range_condition}})
metrics_count = len(task_state.metrics)
max_variants = int(EventSettings.max_es_buckets / (metrics_count * iter_count))
es_req = {
"size": 0,
"query": {"bool": {"must": must_conditions}},
@@ -346,20 +365,20 @@ class DebugImagesIterator:
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"size": metrics_count,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
"events": {
"top_hits": {
"sort": {"url": {"order": "desc"}}
"sort": self._get_same_variant_events_order()
}
}
},
@@ -370,9 +389,9 @@ class DebugImagesIterator:
}
},
}
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
with translate_errors_context():
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
)
if "aggregations" not in es_res:
return task_state.task, []
@@ -382,18 +401,26 @@ class DebugImagesIterator:
for m in task_state.metrics
for v in m.variants
}
allow_uninitialized = (
False
if specific_variants_requested
else config.get(
"services.events.events_retrieval.debug_images.allow_uninitialized_variants",
False,
)
)
def is_valid_event(event: dict) -> bool:
key = event.get("metric"), event.get("variant")
if key not in invalid_iterations:
return False
return allow_uninitialized
max_invalid = invalid_iterations[key]
return max_invalid is None or event.get("iter") > max_invalid
def get_iteration_events(it_: dict) -> Sequence:
return [
ev["_source"]
self._process_event(ev["_source"])
for m in dpath.get(it_, "metrics/buckets")
for v in dpath.get(m, "variants/buckets")
for ev in dpath.get(v, "events/hits/hits")

View File

@@ -0,0 +1,25 @@
from typing import Sequence
from elasticsearch import Elasticsearch
from redis.client import StrictRedis
from .event_common import EventType, uncompress_plot
from .metric_events_iterator import MetricEventsIterator
class MetricPlotsIterator(MetricEventsIterator):
def __init__(self, redis: StrictRedis, es: Elasticsearch):
super().__init__(redis, es, EventType.metrics_plot)
def _get_extra_conditions(self) -> Sequence[dict]:
return []
def _get_variant_state_aggs(self):
return None, None
def _process_event(self, event: dict) -> dict:
uncompress_plot(event)
return event
def _get_same_variant_events_order(self) -> dict:
return {"timestamp": {"order": "desc"}}

View File

@@ -1,5 +1,7 @@
from datetime import datetime
from typing import Callable, Tuple
from typing import Callable, Tuple, Sequence, Dict, Optional
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse
@@ -24,6 +26,33 @@ class ModelBLL:
raise errors.bad_request.InvalidModelId(**query)
return model
@staticmethod
def assert_exists(
company_id,
model_ids,
only=None,
allow_public=False,
return_models=True,
) -> Optional[Sequence[Model]]:
model_ids = [model_ids] if isinstance(model_ids, str) else model_ids
ids = set(model_ids)
query = Q(id__in=ids)
q = Model.get_many(
company=company_id,
query=query,
allow_public=allow_public,
return_dicts=False,
)
if only:
q = q.only(*only)
if q.count() != len(ids):
raise errors.bad_request.InvalidModelId(ids=model_ids)
if return_models:
return list(q)
@classmethod
def publish_model(
cls,
@@ -128,3 +157,33 @@ class ModelBLL:
)
return unarchived
@classmethod
def get_model_stats(
cls, company: str, model_ids: Sequence[str],
) -> Dict[str, dict]:
if not model_ids:
return {}
result = Model.aggregate(
[
{
"$match": {
"company": {"$in": [None, "", company]},
"_id": {"$in": model_ids},
}
},
{
"$addFields": {
"labels_count": {"$size": {"$objectToArray": "$labels"}}
}
},
{
"$project": {"labels_count": 1},
},
]
)
return {
r.pop("_id"): r
for r in result
}

View File

@@ -11,7 +11,6 @@ from apiserver.utilities.parameter_key_escaper import (
mongoengine_safe,
)
from apiserver.config_repo import config
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
@@ -42,27 +41,25 @@ class Metadata:
replace_metadata: bool,
**more_updates,
) -> int:
with TimingContext("mongo", "edit_metadata"):
update_cmds = dict()
metadata = cls.metadata_from_api(items)
if replace_metadata:
update_cmds["set__metadata"] = metadata
else:
for key, value in metadata.items():
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
update_cmds = dict()
metadata = cls.metadata_from_api(items)
if replace_metadata:
update_cmds["set__metadata"] = metadata
else:
for key, value in metadata.items():
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
return obj.update(**update_cmds, **more_updates)
return obj.update(**update_cmds, **more_updates)
@classmethod
def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
with TimingContext("mongo", "delete_metadata"):
return obj.update(
**{
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
for key in set(keys)
},
**more_updates,
)
return obj.update(
**{
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
for key in set(keys)
},
**more_updates,
)
@staticmethod
def _process_path(path: str):

View File

@@ -17,6 +17,7 @@ from typing import (
Any,
)
from boltons.iterutils import partition
from mongoengine import Q, Document
from apiserver import database
@@ -28,7 +29,6 @@ from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
from apiserver.database.utils import get_options, get_company_or_none_constraint
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
from .sub_projects import (
_reposition_project_with_children,
@@ -56,50 +56,53 @@ class ProjectBLL:
Remove the source project
Return the amounts of moved entities and subprojects + set of all the affected project ids
"""
with TimingContext("mongo", "move_project"):
if source_id == destination_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
source=source_id
)
source = Project.get(company, source_id)
if source_id == destination_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
source=source_id
)
source = Project.get(company, source_id)
if destination_id:
destination = Project.get(company, destination_id)
if source_id in destination.path:
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
source=source_id, destination=destination_id
)
else:
destination = None
children = _get_sub_projects(
[source.id], _only=("id", "name", "parent", "path")
)[source.id]
children = _get_sub_projects(
[source.id], _only=("id", "name", "parent", "path")
)[source.id]
if destination:
cls.validate_projects_depth(
projects=children,
old_parent_depth=len(source.path) + 1,
new_parent_depth=len(destination.path) + 1,
)
moved_entities = 0
for entity_type in (Task, Model):
moved_entities += entity_type.objects(
company=company,
project=source_id,
system_tags__nin=[EntityVisibility.archived.value],
).update(upsert=False, project=destination_id)
moved_entities = 0
for entity_type in (Task, Model):
moved_entities += entity_type.objects(
company=company,
project=source_id,
system_tags__nin=[EntityVisibility.archived.value],
).update(upsert=False, project=destination_id)
moved_sub_projects = 0
for child in Project.objects(company=company, parent=source_id):
_reposition_project_with_children(
project=child,
children=[c for c in children if c.parent == child.id],
parent=destination,
)
moved_sub_projects += 1
moved_sub_projects = 0
for child in Project.objects(company=company, parent=source_id):
_reposition_project_with_children(
project=child,
children=[c for c in children if c.parent == child.id],
parent=destination,
)
moved_sub_projects += 1
affected = {source.id, *(source.path or [])}
source.delete()
affected = {source.id, *(source.path or [])}
source.delete()
if destination:
destination.update(last_update=datetime.utcnow())
affected.update({destination.id, *(destination.path or [])})
if destination:
destination.update(last_update=datetime.utcnow())
affected.update({destination.id, *(destination.path or [])})
return moved_entities, moved_sub_projects, affected
@@ -122,79 +125,76 @@ class ProjectBLL:
it should be writable. The source location should be writable too.
Return the number of moved projects + set of all the affected project ids
"""
with TimingContext("mongo", "move_project"):
project = Project.get(company, project_id)
old_parent_id = project.parent
old_parent = (
Project.get_for_writing(company=project.company, id=old_parent_id)
if old_parent_id
else None
project = Project.get(company, project_id)
old_parent_id = project.parent
old_parent = (
Project.get_for_writing(company=project.company, id=old_parent_id)
if old_parent_id
else None
)
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
project.id
]
cls.validate_projects_depth(
projects=[project, *children],
old_parent_depth=len(project.path),
new_parent_depth=_get_project_depth(new_location),
)
new_parent = _ensure_project(company=company, user=user, name=new_location)
new_parent_id = new_parent.id if new_parent else None
if old_parent_id == new_parent_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
location=new_parent.name if new_parent else ""
)
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
project.id
]
cls.validate_projects_depth(
projects=[project, *children],
old_parent_depth=len(project.path),
new_parent_depth=_get_project_depth(new_location),
if new_parent and (
project_id == new_parent.id or project_id in new_parent.path
):
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
project=project_id, parent=new_parent.id
)
moved = _reposition_project_with_children(
project, children=children, parent=new_parent
)
new_parent = _ensure_project(company=company, user=user, name=new_location)
new_parent_id = new_parent.id if new_parent else None
if old_parent_id == new_parent_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
location=new_parent.name if new_parent else ""
)
if (
new_parent
and project_id == new_parent.id
or project_id in new_parent.path
):
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
project=project_id, parent=new_parent.id
)
moved = _reposition_project_with_children(
project, children=children, parent=new_parent
)
now = datetime.utcnow()
affected = set()
for p in filter(None, (old_parent, new_parent)):
p.update(last_update=now)
affected.update({p.id, *(p.path or [])})
now = datetime.utcnow()
affected = set()
for p in filter(None, (old_parent, new_parent)):
p.update(last_update=now)
affected.update({p.id, *(p.path or [])})
return moved, affected
return moved, affected
@classmethod
def update(cls, company: str, project_id: str, **fields):
with TimingContext("mongo", "projects_update"):
project = Project.get_for_writing(company=company, id=project_id)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
project = Project.get_for_writing(company=company, id=project_id)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
new_name = fields.pop("name", None)
if new_name:
new_name, new_location = _validate_project_name(new_name)
old_name, old_location = _validate_project_name(project.name)
if new_location != old_location:
raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
fields["name"] = new_name
new_name = fields.pop("name", None)
if new_name:
new_name, new_location = _validate_project_name(new_name)
old_name, old_location = _validate_project_name(project.name)
if new_location != old_location:
raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
fields["name"] = new_name
fields["basename"] = new_name.split("/")[-1]
fields["last_update"] = datetime.utcnow()
updated = project.update(upsert=False, **fields)
fields["last_update"] = datetime.utcnow()
updated = project.update(upsert=False, **fields)
if new_name:
old_name = project.name
project.name = new_name
children = _get_sub_projects(
[project.id], _only=("id", "name", "path")
)[project.id]
_update_subproject_names(
project=project, children=children, old_name=old_name
)
if new_name:
old_name = project.name
project.name = new_name
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
project.id
]
_update_subproject_names(
project=project, children=children, old_name=old_name
)
return updated
return updated
@classmethod
def create(
@@ -222,6 +222,7 @@ class ProjectBLL:
user=user,
company=company,
name=name,
basename=name.split("/")[-1],
description=description,
tags=tags,
system_tags=system_tags,
@@ -296,24 +297,23 @@ class ProjectBLL:
"""
Move a batch of entities to `project` or a project named `project_name` (create if does not exist)
"""
with TimingContext("mongo", "move_under_project"):
project = cls.find_or_create(
user=user,
company=company,
project_id=project,
project_name=project_name,
description="",
)
extra = (
{"set__last_change": datetime.utcnow()}
if hasattr(entity_cls, "last_change")
else {}
)
entity_cls.objects(company=company, id__in=ids).update(
set__project=project, **extra
)
project = cls.find_or_create(
user=user,
company=company,
project_id=project,
project_name=project_name,
description="",
)
extra = (
{"set__last_change": datetime.utcnow()}
if hasattr(entity_cls, "last_change")
else {}
)
entity_cls.objects(company=company, id__in=ids).update(
set__project=project, **extra
)
return project
return project
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
visibility_states = [EntityVisibility.archived, EntityVisibility.active]
@@ -325,6 +325,7 @@ class ProjectBLL:
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
filter_: Mapping[str, Any] = None,
users: Sequence[str] = None,
) -> Tuple[Sequence, Sequence]:
archived = EntityVisibility.archived.value
@@ -349,7 +350,10 @@ class ProjectBLL:
# count tasks per project per status
{
"$match": cls.get_match_conditions(
company=company_id, project_ids=project_ids, filter_=filter_
company=company_id,
project_ids=project_ids,
filter_=filter_,
users=users,
)
},
ensure_valid_fields(),
@@ -454,22 +458,39 @@ class ProjectBLL:
)
group_step[f"{state.value}_max_task_started"] = max_started_subquery(cond)
def get_state_filter() -> dict:
def add_state_to_filter(f: Mapping[str, Any]) -> Mapping[str, Any]:
if not specific_state:
return {}
return f
f = f or {}
new_f = {k: v for k, v in f.items() if k != "system_tags"}
system_tags = [
tag
for tag in f.get("system_tags", [])
if tag
not in (
EntityVisibility.archived.value,
f"-{EntityVisibility.archived.value}",
)
]
if specific_state == EntityVisibility.archived:
return {"system_tags": {"$eq": EntityVisibility.archived.value}}
return {"system_tags": {"$ne": EntityVisibility.archived.value}}
system_tags.append(EntityVisibility.archived.value)
else:
system_tags.append(f"-{EntityVisibility.archived.value}")
new_f["system_tags"] = system_tags
return new_f
runtime_pipeline = [
# only count run time for these types of tasks
{
"$match": {
**cls.get_match_conditions(
company=company_id, project_ids=project_ids, filter_=filter_
),
**get_state_filter(),
}
"$match": cls.get_match_conditions(
company=company_id,
project_ids=project_ids,
filter_=add_state_to_filter(filter_),
users=users,
)
},
ensure_valid_fields(),
{
@@ -504,6 +525,40 @@ class ProjectBLL:
aggregated[pid] = reduce(func, relevant_data)
return aggregated
@classmethod
def get_dataset_stats(
cls, company: str, project_ids: Sequence[str], users: Sequence[str] = None,
) -> Dict[str, dict]:
if not project_ids:
return {}
task_runtime_pipeline = [
{
"$match": {
**cls.get_match_conditions(
company=company,
project_ids=project_ids,
users=users,
filter_={
"system_tags": [f"-{EntityVisibility.archived.value}"]
},
),
"runtime": {"$exists": True, "$gt": {}},
}
},
{"$project": {"project": 1, "runtime": 1, "last_update": 1}},
{"$sort": {"project": 1, "last_update": 1}},
{"$group": {"_id": "$project", "runtime": {"$last": "$runtime"}}},
]
return {
r["_id"]: {
"file_count": r["runtime"].get("ds_file_count", 0),
"total_size": r["runtime"].get("ds_total_size", 0),
}
for r in Task.aggregate(task_runtime_pipeline)
}
@classmethod
def get_project_stats(
cls,
@@ -511,14 +566,21 @@ class ProjectBLL:
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
include_children: bool = True,
return_hidden_children: bool = False,
search_hidden: bool = False,
filter_: Mapping[str, Any] = None,
users: Sequence[str] = None,
user_active_project_ids: Sequence[str] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
if not project_ids:
return {}, {}
child_projects = (
_get_sub_projects(project_ids, _only=("id", "name", "system_tags"))
_get_sub_projects(
project_ids,
_only=("id", "name"),
search_hidden=search_hidden,
allowed_ids=user_active_project_ids,
)
if include_children
else {}
)
@@ -530,6 +592,7 @@ class ProjectBLL:
project_ids=list(project_ids_with_children),
specific_state=specific_state,
filter_=filter_,
users=users,
)
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
@@ -627,24 +690,9 @@ class ProjectBLL:
for project in project_ids
}
def filter_child_projects(project: str) -> Sequence[Project]:
non_filtered_children = child_projects.get(project, [])
if not non_filtered_children or return_hidden_children:
return non_filtered_children
return [
c
for c in non_filtered_children
if not c.system_tags
or EntityVisibility.hidden.value not in c.system_tags
]
children = {
project: sorted(
[
{"id": c.id, "name": c.name}
for c in filter_child_projects(project)
],
[{"id": c.id, "name": c.name} for c in child_projects.get(project, [])],
key=itemgetter("name"),
)
for project in project_ids
@@ -663,22 +711,21 @@ class ProjectBLL:
If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned
"""
with TimingContext("mongo", "active_users_in_projects"):
query = Q(company=company)
if user_ids:
query &= Q(user__in=user_ids)
query = Q(company=company)
if user_ids:
query &= Q(user__in=user_ids)
projects_query = query
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
projects_query &= Q(id__in=project_ids)
projects_query = query
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
projects_query &= Q(id__in=project_ids)
res = set(Project.objects(projects_query).distinct(field="user"))
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="user"))
res = set(Project.objects(projects_query).distinct(field="user"))
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="user"))
return res
return res
@classmethod
def get_project_tags(
@@ -688,21 +735,20 @@ class ProjectBLL:
projects: Sequence[str] = None,
filter_: Dict[str, Sequence[str]] = None,
) -> Tuple[Sequence[str], Sequence[str]]:
with TimingContext("mongo", "get_tags_from_db"):
query = Q(company=company_id)
if filter_:
for name, vals in filter_.items():
if vals:
query &= GetMixin.get_list_field_query(name, vals)
query = Q(company=company_id)
if filter_:
for name, vals in filter_.items():
if vals:
query &= GetMixin.get_list_field_query(name, vals)
if projects:
query &= Q(id__in=_ids_with_children(projects))
if projects:
query &= Q(id__in=_ids_with_children(projects))
tags = Project.objects(query).distinct("tags")
system_tags = (
Project.objects(query).distinct("system_tags") if include_system else []
)
return tags, system_tags
tags = Project.objects(query).distinct("tags")
system_tags = (
Project.objects(query).distinct("system_tags") if include_system else []
)
return tags, system_tags
@classmethod
def get_projects_with_active_user(
@@ -711,7 +757,7 @@ class ProjectBLL:
users: Sequence[str],
project_ids: Optional[Sequence[str]] = None,
allow_public: bool = True,
) -> Sequence[str]:
) -> Tuple[Sequence[str], Sequence[str]]:
"""
Get the projects ids where user created any tasks including all the parents of these projects
If project ids are specified then filter the results by these project ids
@@ -735,13 +781,16 @@ class ProjectBLL:
res = list(res)
if not res:
return res
return res, res
ids_with_parents = _ids_with_parents(res)
if project_ids:
return [pid for pid in ids_with_parents if pid in project_ids]
user_active_project_ids = _ids_with_parents(res)
filtered_ids = (
list(set(user_active_project_ids) & set(project_ids))
if project_ids
else list(user_active_project_ids)
)
return ids_with_parents
return filtered_ids, user_active_project_ids
@classmethod
def get_task_parents(
@@ -810,32 +859,45 @@ class ProjectBLL:
@staticmethod
def get_match_conditions(
company: str, project_ids: Sequence[str], filter_: Mapping[str, Any]
company: str,
project_ids: Sequence[str],
filter_: Mapping[str, Any],
users: Sequence[str],
):
conditions = {
"company": {"$in": [None, "", company]},
"project": {"$in": project_ids},
}
if users:
conditions["user"] = {"$in": users}
if not filter_:
return conditions
for field in ("tags", "system_tags"):
field_filter = filter_.get(field)
if not field_filter:
continue
if not isinstance(field_filter, list) or not all(
isinstance(t, str) for t in field_filter
for field, field_filter in filter_.items():
if not (
field_filter
and isinstance(field_filter, list)
and all(isinstance(t, str) for t in field_filter)
):
raise errors.bad_request.ValidationError(
f"List of strings expected for the field: {field}"
)
conditions[field] = {"$in": field_filter}
exclude, include = partition(field_filter, lambda x: x.startswith("-"))
conditions[field] = {
**({"$in": include} if include else {}),
**({"$nin": [e[1:] for e in exclude]} if exclude else {}),
}
return conditions
@classmethod
def calc_own_contents(
cls, company: str, project_ids: Sequence[str], filter_: Mapping[str, Any] = None
cls,
company: str,
project_ids: Sequence[str],
filter_: Mapping[str, Any] = None,
users: Sequence[str] = None,
) -> Dict[str, dict]:
"""
Returns the amount of task/models per requested project
@@ -848,7 +910,10 @@ class ProjectBLL:
pipeline = [
{
"$match": cls.get_match_conditions(
company=company, project_ids=project_ids, filter_=filter_
company=company,
project_ids=project_ids,
filter_=filter_,
users=users,
)
},
{"$project": {"project": 1}},
@@ -858,10 +923,9 @@ class ProjectBLL:
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
with TimingContext("mongo", "get_security_groups"):
tasks = get_agrregate_res(Task)
models = get_agrregate_res(Model)
return {
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
for pid in project_ids
}
tasks = get_agrregate_res(Task)
models = get_agrregate_res(Model)
return {
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
for pid in project_ids
}

View File

@@ -8,17 +8,18 @@ from apiserver.bll.task.task_cleanup import (
collect_debug_image_urls,
collect_plot_image_urls,
TaskUrls,
_schedule_for_delete,
)
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes
from apiserver.timing_context import TimingContext
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType
from .sub_projects import _ids_with_children
log = config.logger(__file__)
event_bll = EventBLL()
async_events_delete = config.get("services.tasks.async_events_delete", False)
@attr.s(auto_attribs=True)
@@ -32,60 +33,92 @@ class DeleteProjectResult:
def validate_project_delete(company: str, project_id: str):
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path")
company=company, id=project_id, _only=("id", "path", "system_tags")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
is_pipeline = "pipeline" in (project.system_tags or [])
project_ids = _ids_with_children([project_id])
ret = {}
for cls in (Task, Model):
ret[f"{cls.__name__.lower()}s"] = cls.objects(
project__in=project_ids,
).count()
ret[f"{cls.__name__.lower()}s"] = cls.objects(project__in=project_ids).count()
for cls in (Task, Model):
ret[f"non_archived_{cls.__name__.lower()}s"] = cls.objects(
project__in=project_ids,
system_tags__nin=[EntityVisibility.archived.value],
).count()
query = dict(
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
)
name = f"non_archived_{cls.__name__.lower()}s"
if not is_pipeline:
ret[name] = cls.objects(**query).count()
else:
ret[name] = (
cls.objects(**query, type=TaskType.controller).count()
if cls == Task
else 0
)
return ret
def delete_project(
company: str, project_id: str, force: bool, delete_contents: bool
company: str,
user: str,
project_id: str,
force: bool,
delete_contents: bool,
delete_external_artifacts=True,
) -> Tuple[DeleteProjectResult, Set[str]]:
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path")
company=company, id=project_id, _only=("id", "path", "system_tags")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
delete_external_artifacts = delete_external_artifacts and config.get(
"services.async_urls_delete.enabled", False
)
is_pipeline = "pipeline" in (project.system_tags or [])
project_ids = _ids_with_children([project_id])
if not force:
for cls, error in (
(Task, errors.bad_request.ProjectHasTasks),
(Model, errors.bad_request.ProjectHasModels),
):
non_archived = cls.objects(
project__in=project_ids,
system_tags__nin=[EntityVisibility.archived.value],
).only("id")
query = dict(
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
)
if not is_pipeline:
for cls, error in (
(Task, errors.bad_request.ProjectHasTasks),
(Model, errors.bad_request.ProjectHasModels),
):
non_archived = cls.objects(**query).only("id")
if non_archived:
raise error("use force=true to delete", id=project_id)
else:
non_archived = Task.objects(**query, type=TaskType.controller).only("id")
if non_archived:
raise error("use force=true to delete", id=project_id)
raise errors.bad_request.ProjectHasTasks(
"please archive all the runs inside the project", id=project_id
)
if not delete_contents:
with TimingContext("mongo", "update_children"):
for cls in (Model, Task):
updated_count = cls.objects(project__in=project_ids).update(
project=None
)
for cls in (Model, Task):
updated_count = cls.objects(project__in=project_ids).update(project=None)
res = DeleteProjectResult(disassociated_tasks=updated_count)
else:
deleted_models, model_urls = _delete_models(projects=project_ids)
deleted_tasks, event_urls, artifact_urls = _delete_tasks(
deleted_models, model_event_urls, model_urls = _delete_models(
company=company, projects=project_ids
)
deleted_tasks, task_event_urls, artifact_urls = _delete_tasks(
company=company, projects=project_ids
)
event_urls = task_event_urls | model_event_urls
if delete_external_artifacts:
scheduled = _schedule_for_delete(
task_id=project_id,
company=company,
user=user,
urls=event_urls | model_urls | artifact_urls,
can_delete_folders=True,
)
for urls in (event_urls, model_urls, artifact_urls):
urls.difference_update(scheduled)
res = DeleteProjectResult(
deleted_tasks=deleted_tasks,
deleted_models=deleted_models,
@@ -114,9 +147,8 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
return 0, set(), set()
task_ids = {t.id for t in tasks}
with TimingContext("mongo", "delete_tasks_update_children"):
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
event_urls, artifact_urls = set(), set()
for task in tasks:
@@ -131,46 +163,58 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
}
)
event_bll.delete_multi_task_events(company, list(task_ids))
event_bll.delete_multi_task_events(
company, list(task_ids), async_delete=async_events_delete
)
deleted = tasks.delete()
return deleted, event_urls, artifact_urls
def _delete_models(projects: Sequence[str]) -> Tuple[int, Set[str]]:
def _delete_models(
company: str, projects: Sequence[str]
) -> Tuple[int, Set[str], Set[str]]:
"""
Delete project models and update the tasks from other projects
that reference them to reference None.
"""
with TimingContext("mongo", "delete_models"):
models = Model.objects(project__in=projects).only("task", "id", "uri")
if not models:
return 0, set()
models = Model.objects(project__in=projects).only("task", "id", "uri")
if not models:
return 0, set(), set()
model_ids = list({m.id for m in models})
model_ids = list({m.id for m in models})
Task._get_collection().update_many(
filter={
"project": {"$nin": projects},
"models.input.model": {"$in": model_ids},
},
update={"$set": {"models.input.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
model_tasks = list({m.task for m in models if m.task})
if model_tasks:
Task._get_collection().update_many(
filter={
"_id": {"$in": model_tasks},
"project": {"$nin": projects},
"models.input.model": {"$in": model_ids},
"models.output.model": {"$in": model_ids},
},
update={"$set": {"models.input.$[elem].model": None}},
update={"$set": {"models.output.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
model_tasks = list({m.task for m in models if m.task})
if model_tasks:
Task._get_collection().update_many(
filter={
"_id": {"$in": model_tasks},
"project": {"$nin": projects},
"models.output.model": {"$in": model_ids},
},
update={"$set": {"models.output.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
event_urls, model_urls = set(), set()
for m in models:
event_urls.update(collect_debug_image_urls(company, m.id))
event_urls.update(collect_plot_image_urls(company, m.id))
if m.uri:
model_urls.add(m.uri)
urls = {m.uri for m in models if m.uri}
deleted = models.delete()
return deleted, urls
event_bll.delete_multi_task_events(
company, model_ids, async_delete=async_events_delete
)
deleted = models.delete()
return deleted, event_urls, model_urls

View File

@@ -4,6 +4,7 @@ from typing import Tuple, Optional, Sequence, Mapping
from apiserver import database
from apiserver.apierrors import errors
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
name_separator = "/"
@@ -50,6 +51,7 @@ def _ensure_project(
created=now,
last_update=now,
name=name,
basename=name.split("/")[-1],
**(creation_params or dict(description="")),
)
parent = _ensure_project(company, user, location, creation_params=creation_params)
@@ -100,12 +102,21 @@ def _get_writable_project_from_name(
def _get_sub_projects(
project_ids: Sequence[str], _only: Sequence[str] = ("id", "path")
project_ids: Sequence[str],
_only: Sequence[str] = ("id", "path"),
search_hidden=True,
allowed_ids: Sequence[str] = None,
) -> Mapping[str, Sequence[Project]]:
"""
Return the list of child projects of all the levels for the parent project ids
"""
qs = Project.objects(path__in=project_ids)
query = dict(path__in=project_ids)
if not search_hidden:
query["system_tags__nin"] = [EntityVisibility.hidden.value]
if allowed_ids:
query["id__in"] = allowed_ids
qs = Project.objects(**query)
if _only:
_only = set(_only) | {"path"}
qs = qs.only(*_only)

View File

@@ -3,8 +3,10 @@ from datetime import datetime
from typing import Callable, Sequence, Optional, Tuple
from elasticsearch import Elasticsearch
from mongoengine import Q
from apiserver import database
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.bll.queue.queue_metrics import QueueMetrics
@@ -50,8 +52,25 @@ class QueueBLL(object):
queue.save()
return queue
def get_by_name(
self, company_id: str, queue_name: str, only: Optional[Sequence[str]] = None,
) -> Queue:
qs = Queue.objects(name=queue_name, company=company_id)
if only:
qs = qs.only(*only)
return qs.first()
@staticmethod
def _get_task_entries_projection(max_task_entries: int) -> dict:
return dict(slice__entries=max_task_entries)
def get_by_id(
self, company_id: str, queue_id: str, only: Optional[Sequence[str]] = None
self,
company_id: str,
queue_id: str,
only: Optional[Sequence[str]] = None,
max_task_entries: int = None,
) -> Queue:
"""
Get queue by id
@@ -62,6 +81,8 @@ class QueueBLL(object):
qs = Queue.objects(**query)
if only:
qs = qs.only(*only)
if max_task_entries:
qs = qs.fields(**self._get_task_entries_projection(max_task_entries))
queue = qs.first()
if not queue:
raise errors.bad_request.InvalidQueueId(**query)
@@ -120,16 +141,42 @@ class QueueBLL(object):
"""
with translate_errors_context():
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
if queue.entries and not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
if queue.entries:
if not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
from apiserver.bll.task import ChangeStatusRequest
for item in queue.entries:
try:
task = Task.get_for_writing(
company=company_id,
id=item.task,
_only=["id", "status", "enqueue_status", "project"],
)
if not task:
continue
ChangeStatusRequest(
task=task,
new_status=task.enqueue_status or TaskStatus.created,
status_reason="Queue was deleted",
status_message="",
).execute(enqueue_status=None)
except Exception as ex:
log.exception(
f"Failed dequeuing task {item.task} from queue: {queue_id}"
)
queue.delete()
def get_all(
self,
company_id: str,
query_dict: dict,
query: Q = None,
max_task_entries: int = None,
ret_params: dict = None,
) -> Sequence[dict]:
"""Get all the queues according to the query"""
@@ -138,13 +185,26 @@ class QueueBLL(object):
company=company_id,
parameters=query_dict,
query_dict=query_dict,
query=query,
projection_fields=self._get_task_entries_projection(max_task_entries)
if max_task_entries
else None,
ret_params=ret_params,
)
def check_for_workers(self, company_id: str, queue_id: str) -> bool:
for worker in self.worker_bll.get_all(company_id):
if queue_id in worker.queues:
return True
return False
def get_queue_infos(
self,
company_id: str,
query_dict: dict,
query: Q = None,
max_task_entries: int = None,
ret_params: dict = None,
) -> Sequence[dict]:
"""
@@ -155,7 +215,11 @@ class QueueBLL(object):
res = Queue.get_many_with_join(
company=company_id,
query_dict=query_dict,
query=query,
override_projection=projection,
projection_fields=self._get_task_entries_projection(max_task_entries)
if max_task_entries
else None,
ret_params=ret_params,
)
@@ -203,16 +267,22 @@ class QueueBLL(object):
return res
def get_next_task(self, company_id: str, queue_id: str) -> Optional[Entry]:
def get_next_task(
self, company_id: str, queue_id: str, task_id: str = None
) -> Optional[Entry]:
"""
Atomically pop and return the first task from the queue (or None)
:raise errors.bad_request.InvalidQueueId: if the queue does not exist
"""
with translate_errors_context():
query = dict(id=queue_id, company=company_id)
queue = Queue.objects(**query).modify(pop__entries=-1, upsert=False)
queue = Queue.objects(
**query, **({"entries__0__task": task_id} if task_id else {})
).modify(pop__entries=-1, upsert=False)
if not queue:
raise errors.bad_request.InvalidQueueId(**query)
if not task_id or not Queue.objects(**query).first():
raise errors.bad_request.InvalidQueueId(**query)
return
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
@@ -288,3 +358,22 @@ class QueueBLL(object):
)
return new_position
def count_entries(self, company: str, queue_id: str) -> Optional[int]:
res = next(
Queue.aggregate(
[
{
"$match": {
"company": {"$in": [None, "", company]},
"_id": queue_id,
}
},
{"$project": {"count": {"$size": "$entries"}}},
]
),
None,
)
if res is None:
raise errors.bad_request.InvalidQueueId(queue_id=queue_id)
return int(res.get("count"))

View File

@@ -1,8 +1,10 @@
import json
from collections import defaultdict
from datetime import datetime
from time import sleep
from typing import Sequence
import elasticsearch.helpers
from boltons.typeutils import classproperty
from elasticsearch import Elasticsearch
from apiserver.es_factory import es_factory
@@ -11,25 +13,30 @@ from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.queue import Queue, Entry
from apiserver.timing_context import TimingContext
from apiserver.redis_manager import redman
from apiserver.utilities.threads_manager import ThreadsManager
log = config.logger(__file__)
_conf = config.get("services.queues")
_queue_metrics_key_pattern = "queue_metrics_{queue}"
redis = redman.connection("apiserver")
class EsKeys:
WAITING_TIME_FIELD = "average_waiting_time"
QUEUE_LENGTH_FIELD = "queue_length"
TIMESTAMP_FIELD = "timestamp"
QUEUE_FIELD = "queue"
class QueueMetrics:
class EsKeys:
WAITING_TIME_FIELD = "average_waiting_time"
QUEUE_LENGTH_FIELD = "queue_length"
TIMESTAMP_FIELD = "timestamp"
QUEUE_FIELD = "queue"
def __init__(self, es: Elasticsearch):
self.es = es
@staticmethod
def _queue_metrics_prefix_for_company(company_id: str) -> str:
"""Returns the es index prefix for the company"""
return f"queue_metrics_{company_id}_"
return f"queue_metrics_{company_id.lower()}_"
@staticmethod
def _get_es_index_suffix():
@@ -49,7 +56,7 @@ class QueueMetrics:
total_waiting_in_secs = sum((now - e.added).total_seconds() for e in entries)
return total_waiting_in_secs / len(entries)
def log_queue_metrics_to_es(self, company_id: str, queues: Sequence[Queue]) -> bool:
def log_queue_metrics_to_es(self, company_id: str, queues: Sequence[Queue]) -> int:
"""
Calculate and write queue statistics (avg waiting time and queue length) to Elastic
:return: True if the write to es was successful, false otherwise
@@ -63,23 +70,22 @@ class QueueMetrics:
def make_doc(queue: Queue) -> dict:
entries = [e for e in queue.entries if e.added]
return dict(
_index=es_index,
_source={
self.EsKeys.TIMESTAMP_FIELD: timestamp,
self.EsKeys.QUEUE_FIELD: queue.id,
self.EsKeys.WAITING_TIME_FIELD: self._calc_avg_waiting_time(
entries
),
self.EsKeys.QUEUE_LENGTH_FIELD: len(entries),
},
)
return {
EsKeys.TIMESTAMP_FIELD: timestamp,
EsKeys.QUEUE_FIELD: queue.id,
EsKeys.WAITING_TIME_FIELD: self._calc_avg_waiting_time(entries),
EsKeys.QUEUE_LENGTH_FIELD: len(entries),
}
actions = list(map(make_doc, queues))
logged = 0
for q in queues:
queue_doc = make_doc(q)
self.es.index(index=es_index, body=queue_doc)
redis_key = _queue_metrics_key_pattern.format(queue=q.id)
redis.set(redis_key, json.dumps(queue_doc))
logged += 1
es_res = elasticsearch.helpers.bulk(self.es, actions)
added, errors = es_res[:2]
return (added == len(actions)) and not errors
return logged
def _log_current_metrics(self, company_id: str, queue_ids=Sequence[str]):
query = dict(company=company_id)
@@ -90,8 +96,7 @@ class QueueMetrics:
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
return self.es.search(
index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
body=es_req,
index=f"{self._queue_metrics_prefix_for_company(company_id)}*", body=es_req,
)
@classmethod
@@ -105,13 +110,13 @@ class QueueMetrics:
return {
"dates": {
"date_histogram": {
"field": cls.EsKeys.TIMESTAMP_FIELD,
"field": EsKeys.TIMESTAMP_FIELD,
"fixed_interval": f"{interval}s",
"min_doc_count": 1,
},
"aggs": {
"queues": {
"terms": {"field": cls.EsKeys.QUEUE_FIELD},
"terms": {"field": EsKeys.QUEUE_FIELD},
"aggs": cls._get_top_waiting_agg(),
}
},
@@ -128,13 +133,13 @@ class QueueMetrics:
"top_avg_waiting": {
"top_hits": {
"sort": [
{cls.EsKeys.WAITING_TIME_FIELD: {"order": "desc"}},
{cls.EsKeys.QUEUE_LENGTH_FIELD: {"order": "desc"}},
{EsKeys.WAITING_TIME_FIELD: {"order": "desc"}},
{EsKeys.QUEUE_LENGTH_FIELD: {"order": "desc"}},
],
"_source": {
"includes": [
cls.EsKeys.WAITING_TIME_FIELD,
cls.EsKeys.QUEUE_LENGTH_FIELD,
EsKeys.WAITING_TIME_FIELD,
EsKeys.QUEUE_LENGTH_FIELD,
]
},
"size": 1,
@@ -149,6 +154,7 @@ class QueueMetrics:
to_date: float,
interval: int,
queue_ids: Sequence[str],
refresh: bool = False,
) -> dict:
"""
Get the company queue metrics in the specified time range.
@@ -158,7 +164,8 @@ class QueueMetrics:
In case no queue ids are specified the avg across all the
company queues is calculated for each metric
"""
# self._log_current_metrics(company, queue_ids=queue_ids)
if refresh:
self._log_current_metrics(company_id, queue_ids=queue_ids)
if from_date >= to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
@@ -174,7 +181,7 @@ class QueueMetrics:
"aggs": self._get_dates_agg(interval),
}
with translate_errors_context(), TimingContext("es", "get_queue_metrics"):
with translate_errors_context():
res = self._search_company_metrics(company_id, es_req)
if "aggregations" not in res:
@@ -256,7 +263,52 @@ class QueueMetrics:
continue
res = queue_data["top_avg_waiting"]["hits"]["hits"][0]["_source"]
queue_metrics[queue_data["key"]] = {
"queue_length": res[cls.EsKeys.QUEUE_LENGTH_FIELD],
"avg_waiting_time": res[cls.EsKeys.WAITING_TIME_FIELD],
"queue_length": res[EsKeys.QUEUE_LENGTH_FIELD],
"avg_waiting_time": res[EsKeys.WAITING_TIME_FIELD],
}
return queue_metrics
class MetricsRefresher:
threads = ThreadsManager()
@classproperty
def watch_interval_sec(self):
return _conf.get("metrics_refresh_interval_sec", 300)
@classmethod
@threads.register("queue_metrics_refresh_watchdog", daemon=True)
def start(cls, queue_metrics: QueueMetrics = None):
if not cls.watch_interval_sec:
return
if not queue_metrics:
from .queue_bll import QueueBLL
queue_metrics = QueueBLL().metrics
sleep(10)
while True:
try:
for queue in Queue.objects():
timestamp = es_factory.get_timestamp_millis()
doc_time = 0
try:
redis_key = _queue_metrics_key_pattern.format(queue=queue.id)
data = redis.get(redis_key)
if data:
queue_doc = json.loads(data)
doc_time = int(queue_doc.get(EsKeys.TIMESTAMP_FIELD))
except Exception as ex:
log.exception(
f"Error reading queue metrics data for queue {queue.id}: {str(ex)}"
)
if (
not doc_time
or (timestamp - doc_time) > cls.watch_interval_sec * 1000
):
queue_metrics.log_queue_metrics_to_es(queue.company, [queue])
except Exception as ex:
log.exception(f"Failed collecting queue metrics: {str(ex)}")
sleep(60)

View File

@@ -4,7 +4,6 @@ from typing import Optional, TypeVar, Generic, Type, Callable
from redis import StrictRedis
from apiserver import database
from apiserver.timing_context import TimingContext
T = TypeVar("T")
@@ -31,20 +30,17 @@ class RedisCacheManager(Generic[T]):
def set_state(self, state: T) -> None:
redis_key = self._get_redis_key(state.id)
with TimingContext("redis", "cache_set_state"):
self.redis.set(redis_key, state.to_json())
self.redis.expire(redis_key, self.expiration_interval)
self.redis.set(redis_key, state.to_json())
self.redis.expire(redis_key, self.expiration_interval)
def get_state(self, state_id) -> Optional[T]:
redis_key = self._get_redis_key(state_id)
with TimingContext("redis", "cache_get_state"):
response = self.redis.get(redis_key)
response = self.redis.get(redis_key)
if response:
return self.state_class.from_json(response)
def delete_state(self, state_id) -> None:
with TimingContext("redis", "cache_delete_state"):
self.redis.delete(self._get_redis_key(state_id))
self.redis.delete(self._get_redis_key(state_id))
def _get_redis_key(self, state_id):
return f"{self.state_class}/{state_id}"

View File

@@ -1,6 +1,6 @@
from datetime import datetime
import operator
from threading import Thread, Lock
from threading import Lock
from time import sleep
import attr
@@ -9,76 +9,83 @@ import psutil
from apiserver.utilities.threads_manager import ThreadsManager
class ResourceMonitor(Thread):
@attr.s(auto_attribs=True)
class Sample:
cpu_usage: float = 0.0
mem_used_gb: float = 0
mem_free_gb: float = 0
stat_threads = ThreadsManager("Statistics")
@classmethod
def _apply(cls, op, *samples):
return cls(
**{
field: op(*(getattr(sample, field) for sample in samples))
for field in attr.fields_dict(cls)
}
)
def min(self, sample):
return self._apply(min, self, sample)
def max(self, sample):
return self._apply(max, self, sample)
def avg(self, sample, count):
res = self._apply(lambda x: x * count, self)
res = self._apply(operator.add, res, sample)
res = self._apply(lambda x: x / (count + 1), res)
return res
def __init__(self, sample_interval_sec=5):
super(ResourceMonitor, self).__init__(daemon=True)
self.sample_interval_sec = sample_interval_sec
self._lock = Lock()
self._clear()
def _clear(self):
sample = self._get_sample()
self._avg = sample
self._min = sample
self._max = sample
self._clear_time = datetime.utcnow()
self._count = 1
@attr.s(auto_attribs=True)
class Sample:
cpu_usage: float = 0.0
mem_used_gb: float = 0
mem_free_gb: float = 0
@classmethod
def _get_sample(cls) -> Sample:
return cls.Sample(
def _apply(cls, op, *samples):
return cls(
**{
field: op(*(getattr(sample, field) for sample in samples))
for field in attr.fields_dict(cls)
}
)
def min(self, sample):
return self._apply(min, self, sample)
def max(self, sample):
return self._apply(max, self, sample)
def avg(self, sample, count):
res = self._apply(lambda x: x * count, self)
res = self._apply(operator.add, res, sample)
res = self._apply(lambda x: x / (count + 1), res)
return res
@classmethod
def get_current_sample(cls) -> "Sample":
return cls(
cpu_usage=psutil.cpu_percent(),
mem_used_gb=psutil.virtual_memory().used / (1024 ** 3),
mem_free_gb=psutil.virtual_memory().free / (1024 ** 3),
)
def run(self):
while not ThreadsManager.terminating:
sleep(self.sample_interval_sec)
sample = self._get_sample()
class ResourceMonitor:
class Accumulator:
def __init__(self):
sample = Sample.get_current_sample()
self.avg = sample
self.min = sample
self.max = sample
self.time = datetime.utcnow()
self.count = 1
with self._lock:
self._min = self._min.min(sample)
self._max = self._max.max(sample)
self._avg = self._avg.avg(sample, self._count)
self._count += 1
def add_sample(self, sample: Sample):
self.min = self.min.min(sample)
self.max = self.max.max(sample)
self.avg = self.avg.avg(sample, self.count)
self.count += 1
def get_stats(self) -> dict:
sample_interval_sec = 5
_lock = Lock()
accumulator = Accumulator()
@classmethod
@stat_threads.register("resource_monitor", daemon=True)
def start(cls):
while True:
sleep(cls.sample_interval_sec)
sample = Sample.get_current_sample()
with cls._lock:
cls.accumulator.add_sample(sample)
@classmethod
def get_stats(cls) -> dict:
""" Returns current resource statistics and clears internal resource statistics """
with self._lock:
min_ = attr.asdict(self._min)
max_ = attr.asdict(self._max)
avg = attr.asdict(self._avg)
interval = datetime.utcnow() - self._clear_time
self._clear()
with cls._lock:
min_ = attr.asdict(cls.accumulator.min)
max_ = attr.asdict(cls.accumulator.max)
avg = attr.asdict(cls.accumulator.avg)
interval = datetime.utcnow() - cls.accumulator.time
cls.accumulator = cls.Accumulator()
return {
"interval_sec": interval.total_seconds(),

View File

@@ -21,9 +21,8 @@ from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
from apiserver.utilities.json import dumps
from apiserver.utilities.threads_manager import ThreadsManager
from apiserver.version import __version__ as current_version
from .resource_monitor import ResourceMonitor
from .resource_monitor import ResourceMonitor, stat_threads
log = config.logger(__file__)
@@ -31,17 +30,19 @@ worker_bll = WorkerBLL()
class StatisticsReporter:
threads = ThreadsManager("Statistics", resource_monitor=ResourceMonitor)
send_queue = queue.Queue()
supported = config.get("apiserver.statistics.supported", True)
@classmethod
def start(cls):
if not cls.supported:
return
ResourceMonitor.start()
cls.start_sender()
cls.start_reporter()
@classmethod
@threads.register("reporter", daemon=True)
@stat_threads.register("reporter", daemon=True)
def start_reporter(cls):
"""
Periodically send statistics reports for companies who have opted in.
@@ -54,7 +55,7 @@ class StatisticsReporter:
hours=config.get("apiserver.statistics.report_interval_hours", 24)
)
sleep(report_interval.total_seconds())
while not ThreadsManager.terminating:
while True:
try:
for company in Company.objects(
defaults__stats_option__enabled=True
@@ -68,7 +69,7 @@ class StatisticsReporter:
sleep(report_interval.total_seconds())
@classmethod
@threads.register("sender", daemon=True)
@stat_threads.register("sender", daemon=True)
def start_sender(cls):
if not cls.supported:
return
@@ -85,7 +86,7 @@ class StatisticsReporter:
WarningFilter.attach()
while not ThreadsManager.terminating:
while True:
try:
report = cls.send_queue.get()
@@ -111,7 +112,7 @@ class StatisticsReporter:
"uuid": get_server_uuid(),
"queues": {"count": Queue.objects(company=company_id).count()},
"users": {"count": User.objects(company=company_id).count()},
"resources": cls.threads.resource_monitor.get_stats(),
"resources": ResourceMonitor.get_stats(),
"experiments": next(
iter(cls._get_experiments_stats(company_id).values()), {}
),

View File

@@ -5,7 +5,6 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.database.utils import hash_field_name
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
@@ -53,23 +52,18 @@ class Artifacts:
artifacts: Sequence[ApiArtifact],
force: bool,
) -> int:
with TimingContext("mongo", "update_artifacts"):
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
force=force,
)
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
artifacts = {
get_artifact_id(a): Artifact(**a)
for a in (api_artifact.to_struct() for api_artifact in artifacts)
}
artifacts = {
get_artifact_id(a): Artifact(**a)
for a in (api_artifact.to_struct() for api_artifact in artifacts)
}
update_cmds = {
f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items()
}
return update_task(task, update_cmds=update_cmds)
update_cmds = {
f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items()
}
return update_task(task, update_cmds=update_cmds)
@classmethod
def delete_artifacts(
@@ -79,19 +73,14 @@ class Artifacts:
artifact_ids: Sequence[ArtifactId],
force: bool,
) -> int:
with TimingContext("mongo", "delete_artifacts"):
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
force=force,
)
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
artifact_ids = [
get_artifact_id(a)
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
]
delete_cmds = {
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
}
artifact_ids = [
get_artifact_id(a)
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
]
delete_cmds = {
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
}
return update_task(task, update_cmds=delete_cmds)
return update_task(task, update_cmds=delete_cmds)

View File

@@ -15,7 +15,6 @@ from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.config_repo import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
@@ -68,36 +67,35 @@ class HyperParams:
hyperparams: Sequence[HyperParamKey],
force: bool,
) -> int:
with TimingContext("mongo", "delete_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
)
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
)
with_param, without_param = iterutils.partition(
hyperparams, key=lambda p: bool(p.name)
)
sections_to_delete = {p.section for p in without_param}
delete_cmds = {
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
for section in sections_to_delete
}
with_param, without_param = iterutils.partition(
hyperparams, key=lambda p: bool(p.name)
)
sections_to_delete = {p.section for p in without_param}
delete_cmds = {
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
for section in sections_to_delete
}
for item in with_param:
section = ParameterKeyEscaper.escape(item.section)
if item.section in sections_to_delete:
raise errors.bad_request.FieldsConflict(
"Cannot delete section field if the whole section was scheduled for deletion"
)
name = ParameterKeyEscaper.escape(item.name)
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
for item in with_param:
section = ParameterKeyEscaper.escape(item.section)
if item.section in sections_to_delete:
raise errors.bad_request.FieldsConflict(
"Cannot delete section field if the whole section was scheduled for deletion"
)
name = ParameterKeyEscaper.escape(item.name)
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
return update_task(
task, update_cmds=delete_cmds, set_last_update=not properties_only
)
return update_task(
task, update_cmds=delete_cmds, set_last_update=not properties_only
)
@classmethod
def edit_params(
@@ -108,34 +106,31 @@ class HyperParams:
replace_hyperparams: str,
force: bool,
) -> int:
with TimingContext("mongo", "edit_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
)
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
)
update_cmds = dict()
hyperparams = cls._db_dicts_from_list(hyperparams)
if replace_hyperparams == ReplaceHyperparams.all:
update_cmds["set__hyperparams"] = hyperparams
elif replace_hyperparams == ReplaceHyperparams.section:
for section, value in hyperparams.items():
update_cmds = dict()
hyperparams = cls._db_dicts_from_list(hyperparams)
if replace_hyperparams == ReplaceHyperparams.all:
update_cmds["set__hyperparams"] = hyperparams
elif replace_hyperparams == ReplaceHyperparams.section:
for section, value in hyperparams.items():
update_cmds[f"set__hyperparams__{mongoengine_safe(section)}"] = value
else:
for section, section_params in hyperparams.items():
for name, value in section_params.items():
update_cmds[
f"set__hyperparams__{mongoengine_safe(section)}"
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
] = value
else:
for section, section_params in hyperparams.items():
for name, value in section_params.items():
update_cmds[
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
] = value
return update_task(
task, update_cmds=update_cmds, set_last_update=not properties_only
)
return update_task(
task, update_cmds=update_cmds, set_last_update=not properties_only
)
@classmethod
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
@@ -191,17 +186,16 @@ class HyperParams:
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
with TimingContext("mongo", "get_configuration_names"):
tasks = Task.aggregate(pipeline)
tasks = Task.aggregate(pipeline)
return {
task["_id"]: {
"names": sorted(
ParameterKeyEscaper.unescape(name) for name in task["names"]
)
}
for task in tasks
return {
task["_id"]: {
"names": sorted(
ParameterKeyEscaper.unescape(name) for name in task["names"]
)
}
for task in tasks
}
@classmethod
def edit_configuration(
@@ -212,36 +206,30 @@ class HyperParams:
replace_configuration: bool,
force: bool,
) -> int:
with TimingContext("mongo", "edit_configuration"):
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force
)
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
update_cmds = dict()
configuration = {
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
for c in configuration
}
if replace_configuration:
update_cmds["set__configuration"] = configuration
else:
for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
update_cmds = dict()
configuration = {
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
for c in configuration
}
if replace_configuration:
update_cmds["set__configuration"] = configuration
else:
for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
return update_task(task, update_cmds=update_cmds)
return update_task(task, update_cmds=update_cmds)
@classmethod
def delete_configuration(
cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
) -> int:
with TimingContext("mongo", "delete_configuration"):
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force
)
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
return update_task(task, update_cmds=delete_cmds)
return update_task(task, update_cmds=delete_cmds)

View File

@@ -39,7 +39,7 @@ class NonResponsiveTasksWatchdog:
@threads.register("non_responsive_tasks_watchdog", daemon=True)
def start(cls):
sleep(cls.settings.watch_interval_sec)
while not ThreadsManager.terminating:
while True:
watch_interval = cls.settings.watch_interval_sec
if cls.settings.enabled:
try:

View File

@@ -121,18 +121,31 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
nested_set(fields, new_path, new_param)
nested_delete(fields, old_params_field)
for param_field in ("hyperparams", "configuration"):
params = fields.get(param_field)
if params:
escaped_params = {
ParameterKeyEscaper.escape(key): {
ParameterKeyEscaper.escape(k): v for k, v in value.items()
}
if isinstance(value, dict)
else value
for key, value in params.items()
def ensure_non_empty(k: str, desc: str) -> str:
if not k:
raise errors.bad_request.ValidationError(
f"Empty {desc} name is not allowed"
)
return k
params = fields.get("hyperparams")
if params:
escaped_params = {
ParameterKeyEscaper.escape(ensure_non_empty(key, "section")): {
ParameterKeyEscaper.escape(ensure_non_empty(k, "parameter")): v
for k, v in value.items()
}
fields[param_field] = escaped_params
for key, value in params.items()
}
fields["hyperparams"] = escaped_params
params = fields.get("configuration")
if params:
escaped_params = {
ParameterKeyEscaper.escape(ensure_non_empty(key, "configuration")): value
for key, value in params.items()
}
fields["configuration"] = escaped_params
def params_unprepare_from_saved(fields, copy_to_legacy=False):
@@ -186,7 +199,7 @@ def escape_paths(paths: Sequence[str]) -> Sequence[str]:
for old_prefix, new_prefix in (
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
("execution.model_desc", "configuration"),
("execution.docker_cmd", "container")
("execution.docker_cmd", "container"),
):
path: str
paths = [path.replace(old_prefix, new_prefix) for path in paths]

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
from typing import Collection, Sequence, Tuple, Optional, Dict
import six
from mongoengine import Q
@@ -35,7 +35,6 @@ from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from apiserver.timing_context import TimingContext
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
@@ -66,11 +65,10 @@ class TaskBLL:
"""
with translate_errors_context():
query = dict(id=task_id, company=company_id)
with TimingContext("mongo", "task_with_access"):
if requires_write_access:
task = Task.get_for_writing(_only=only, **query)
else:
task = Task.get(_only=only, **query, include_public=allow_public)
if requires_write_access:
task = Task.get_for_writing(_only=only, **query)
else:
task = Task.get(_only=only, **query, include_public=allow_public)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
@@ -88,15 +86,14 @@ class TaskBLL:
only_fields = list(only_fields)
only_fields = only_fields + ["status"]
with TimingContext("mongo", "task_by_id_all"):
tasks = Task.get_many(
company=company_id,
query=Q(id=task_id),
allow_public=allow_public,
override_projection=only_fields,
return_dicts=False,
)
task = None if not tasks else tasks[0]
tasks = Task.get_many(
company=company_id,
query=Q(id=task_id),
allow_public=allow_public,
override_projection=only_fields,
return_dicts=False,
)
task = None if not tasks else tasks[0]
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
@@ -111,7 +108,7 @@ class TaskBLL:
company_id, task_ids, only=None, allow_public=False, return_tasks=True
) -> Optional[Sequence[Task]]:
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
with translate_errors_context(), TimingContext("mongo", "task_exists"):
with translate_errors_context():
ids = set(task_ids)
q = Task.get_many(
company=company_id,
@@ -260,58 +257,55 @@ class TaskBLL:
not in [TaskSystemTags.development, EntityVisibility.archived.value]
]
with TimingContext("mongo", "clone task"):
parent_task = (
task.parent
if task.parent and not task.parent.startswith(deleted_prefix)
else None
)
new_task = Task(
id=create_id(),
user=user_id,
company=company_id,
created=now,
last_update=now,
last_change=now,
name=name or task.name,
comment=comment or task.comment,
parent=parent or parent_task,
project=project or task.project,
tags=tags or task.tags,
system_tags=system_tags or clean_system_tags(task.system_tags),
type=task.type,
script=task.script,
output=Output(destination=task.output.destination)
if task.output
else None,
models=Models(input=input_models or task.models.input),
container=escape_dict(container) or task.container,
execution=execution_dict,
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)
cls.validate(
new_task,
validate_models=validate_references or input_models,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
new_task.save()
parent_task = (
task.parent
if task.parent and not task.parent.startswith(deleted_prefix)
else task.id
)
new_task = Task(
id=create_id(),
user=user_id,
company=company_id,
created=now,
last_update=now,
last_change=now,
name=name or task.name,
comment=comment or task.comment,
parent=parent or parent_task,
project=project or task.project,
tags=tags or task.tags,
system_tags=system_tags or clean_system_tags(task.system_tags),
type=task.type,
script=task.script,
output=Output(destination=task.output.destination) if task.output else None,
models=Models(input=input_models or task.models.input),
container=escape_dict(container) or task.container,
execution=execution_dict,
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)
cls.validate(
new_task,
validate_models=validate_references or input_models,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
new_task.save()
if task.project == new_task.project:
updated_tags = tags
updated_system_tags = system_tags
else:
updated_tags = new_task.tags
updated_system_tags = new_task.system_tags
org_bll.update_tags(
company_id,
Tags.Task,
project=new_task.project,
tags=updated_tags,
system_tags=updated_system_tags,
)
update_project_time(new_task.project)
if task.project == new_task.project:
updated_tags = tags
updated_system_tags = system_tags
else:
updated_tags = new_task.tags
updated_system_tags = new_task.system_tags
org_bll.update_tags(
company_id,
Tags.Task,
project=new_task.project,
tags=updated_tags,
system_tags=updated_system_tags,
)
update_project_time(new_task.project)
return new_task, new_project_data
@@ -381,7 +375,7 @@ class TaskBLL:
last_update: datetime = None,
last_iteration: int = None,
last_iteration_max: int = None,
last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
last_scalar_events: Dict[str, Dict[str, dict]] = None,
last_events: Dict[str, Dict[str, dict]] = None,
**extra_updates,
):
@@ -406,18 +400,43 @@ class TaskBLL:
elif last_iteration_max is not None:
extra_updates.update(max__last_iteration=last_iteration_max)
if last_scalar_values is not None:
if last_scalar_events is not None:
max_values = config.get("services.tasks.max_last_metrics", 2000)
total_metrics = set()
if max_values:
query = dict(id=task_id)
to_add = sum(len(v) for m, v in last_scalar_events.items())
if to_add <= max_values:
query[f"unique_metrics__{max_values-to_add}__exists"] = True
task = Task.objects(**query).only("unique_metrics").first()
if task and task.unique_metrics:
total_metrics = set(task.unique_metrics)
def op_path(op, *path):
return "__".join((op, "last_metrics") + path)
new_metrics = []
for metric_key, metric_data in last_scalar_events.items():
for variant_key, variant_data in metric_data.items():
metric = (
f"{variant_data.get('metric')}/{variant_data.get('variant')}"
)
if max_values:
if (
len(total_metrics) >= max_values
and metric not in total_metrics
):
continue
total_metrics.add(metric)
for path, value in last_scalar_values:
if path[-1] == "min_value":
extra_updates[op_path("min", *path[:-1], "min_value")] = value
elif path[-1] == "max_value":
extra_updates[op_path("max", *path[:-1], "max_value")] = value
else:
extra_updates[op_path("set", *path)] = value
new_metrics.append(metric)
path = f"last_metrics__{metric_key}__{variant_key}"
for key, value in variant_data.items():
if key == "min_value":
extra_updates[f"min__{path}__min_value"] = value
elif key == "max_value":
extra_updates[f"max__{path}__max_value"] = value
elif key in ("metric", "variant", "value"):
extra_updates[f"set__{path}__{key}"] = value
if new_metrics:
extra_updates["add_to_set__unique_metrics"] = new_metrics
if last_events is not None:
@@ -446,7 +465,11 @@ class TaskBLL:
def dequeue_and_change_status(
cls, task: Task, company_id: str, status_message: str, status_reason: str,
):
cls.dequeue(task, company_id)
try:
cls.dequeue(task, company_id)
except errors.bad_request.InvalidQueueOrTaskNotQueued:
# dequeue may fail if the queue was deleted
pass
return ChangeStatusRequest(
task=task,

View File

@@ -1,64 +1,31 @@
from datetime import datetime
from itertools import chain
from operator import attrgetter
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
from typing import Sequence, Set, Tuple
import attr
from boltons.iterutils import partition
from mongoengine import QuerySet, Document
from boltons.iterutils import partition, bucketize, first
from mongoengine import NotUniqueError
from pymongo.errors import DuplicateKeyError
from apiserver.apierrors import errors
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_bll import PlotFields
from apiserver.bll.event.event_common import EventType
from apiserver.bll.task.utils import deleted_prefix
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
from apiserver.timing_context import TimingContext
from apiserver.database.model.url_to_delete import (
StorageType,
UrlToDelete,
FileType,
DeletionStatus,
)
from apiserver.database.utils import id as db_id
log = config.logger(__file__)
event_bll = EventBLL()
T = TypeVar("T", bound=Document)
class DocumentGroup(List[T]):
"""
Operate on a list of documents as if they were a query result
"""
def __init__(self, document_type: Type[T], documents: Iterable[T]):
super(DocumentGroup, self).__init__(documents)
self.type = document_type
@property
def ids(self) -> Set[str]:
return {obj.id for obj in self}
def objects(self, *args, **kwargs) -> QuerySet:
return self.type.objects(id__in=self.ids, *args, **kwargs)
class TaskOutputs(Generic[T]):
"""
Split task outputs of the same type by the ready state
"""
published: DocumentGroup[T]
draft: DocumentGroup[T]
def __init__(
self,
is_published: Callable[[T], bool],
document_type: Type[T],
children: Iterable[T],
):
"""
:param is_published: predicate returning whether items is considered published
:param document_type: type of output
:param children: output documents
"""
self.published, self.draft = map(
lambda x: DocumentGroup(document_type, x),
partition(children, key=is_published),
)
async_events_delete = config.get("services.tasks.async_events_delete", False)
@attr.s(auto_attribs=True)
@@ -101,64 +68,112 @@ class CleanupResult:
)
def collect_plot_image_urls(company: str, task: str) -> Set[str]:
def collect_plot_image_urls(company: str, task_or_model: str) -> Set[str]:
urls = set()
next_scroll_id = None
with TimingContext("es", "collect_plot_image_urls"):
while True:
events, next_scroll_id = event_bll.get_plot_image_urls(
company_id=company, task_id=task, scroll_id=next_scroll_id
)
if not events:
break
for event in events:
event_urls = event.get(PlotFields.source_urls)
if event_urls:
urls.update(set(event_urls))
while True:
events, next_scroll_id = event_bll.get_plot_image_urls(
company_id=company, task_id=task_or_model, scroll_id=next_scroll_id
)
if not events:
break
for event in events:
event_urls = event.get(PlotFields.source_urls)
if event_urls:
urls.update(set(event_urls))
return urls
def collect_debug_image_urls(company: str, task: str) -> Set[str]:
def collect_debug_image_urls(company: str, task_or_model: str) -> Set[str]:
"""
Return the set of unique image urls
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
"""
metrics = event_bll.get_metrics_and_variants(
company_id=company, task_id=task, event_type=EventType.metrics_image
)
if not metrics:
return set()
task_metrics = {task: {m: [] for m in metrics}}
scroll_id = None
after_key = None
urls = set()
while True:
res = event_bll.debug_images_iterator.get_task_events(
company_id=company,
task_metrics=task_metrics,
iter_count=10,
state_id=scroll_id,
res, after_key = event_bll.get_debug_image_urls(
company_id=company, task_id=task_or_model, after_key=after_key,
)
if not res.metric_events or not any(
iterations for _, iterations in res.metric_events
):
urls.update(res)
if not after_key:
break
scroll_id = res.next_scroll_id
for task, iterations in res.metric_events:
urls.update(ev.get("url") for it in iterations for ev in it["events"])
urls.discard({None})
return urls
supported_storage_types = {
"https://": StorageType.fileserver,
"http://": StorageType.fileserver,
}
def _schedule_for_delete(
company: str, user: str, task_id: str, urls: Set[str], can_delete_folders: bool,
) -> Set[str]:
urls_per_storage = bucketize(
urls,
key=lambda u: first(
type_
for prefix, type_ in supported_storage_types.items()
if u.startswith(prefix)
),
)
urls_per_storage.pop(None, None)
processed_urls = set()
for storage_type, storage_urls in urls_per_storage.items():
delete_folders = (storage_type == StorageType.fileserver) and can_delete_folders
scheduled_to_delete = set()
for url in storage_urls:
folder = None
if delete_folders:
folder, _, _ = url.rpartition("/")
to_delete = folder or url
if to_delete in scheduled_to_delete:
processed_urls.add(url)
continue
try:
UrlToDelete(
id=db_id(),
company=company,
user=user,
url=to_delete,
task=task_id,
created=datetime.utcnow(),
storage_type=storage_type,
type=FileType.folder if folder else FileType.file,
).save()
except (DuplicateKeyError, NotUniqueError):
existing = UrlToDelete.objects(company=company, url=to_delete).first()
if existing:
existing.update(
user=user,
task=task_id,
created=datetime.utcnow(),
retry_count=0,
unset__last_failure_time=1,
unset__last_failure_reason=1,
status=DeletionStatus.created,
)
processed_urls.add(url)
scheduled_to_delete.add(to_delete)
return processed_urls
def cleanup_task(
company: str,
user: str,
task: Task,
force: bool = False,
update_children=True,
return_file_urls=False,
delete_output_models=True,
delete_external_artifacts=True,
) -> CleanupResult:
"""
Validate task deletion and delete/modify all its output.
@@ -166,10 +181,14 @@ def cleanup_task(
:param force: whether to delete task with published outputs
:return: count of delete and modified items
"""
models = verify_task_children_and_ouptuts(task, force)
published_models, draft_models, in_use_model_ids = verify_task_children_and_ouptuts(
task, force
)
delete_external_artifacts = delete_external_artifacts and config.get(
"services.async_urls_delete.enabled", False
)
event_urls, artifact_urls, model_urls = set(), set(), set()
if return_file_urls:
if return_file_urls or delete_external_artifacts:
event_urls = collect_debug_image_urls(task.company, task.id)
event_urls.update(collect_plot_image_urls(task.company, task.id))
if task.execution and task.execution.artifacts:
@@ -178,30 +197,63 @@ def cleanup_task(
for a in task.execution.artifacts.values()
if a.mode == ArtifactModes.output and a.uri
}
model_urls = {m.uri for m in models.draft.objects().only("uri") if m.uri}
model_urls = {
m.uri for m in draft_models if m.uri and m.id not in in_use_model_ids
}
deleted_task_id = f"{deleted_prefix}{task.id}"
updated_children = 0
if update_children:
with TimingContext("mongo", "update_task_children"):
updated_children = Task.objects(parent=task.id).update(
parent=deleted_task_id
updated_children = Task.objects(parent=task.id).update(parent=deleted_task_id)
deleted_models = 0
updated_models = 0
for models, allow_delete in ((draft_models, True), (published_models, False)):
if not models:
continue
if delete_output_models and allow_delete:
model_ids = set(m.id for m in models if m.id not in in_use_model_ids)
for m_id in model_ids:
if return_file_urls or delete_external_artifacts:
event_urls.update(collect_debug_image_urls(task.company, m_id))
event_urls.update(collect_plot_image_urls(task.company, m_id))
try:
event_bll.delete_task_events(
task.company,
m_id,
allow_locked=True,
model=True,
async_delete=async_events_delete,
)
except errors.bad_request.InvalidModelId as ex:
log.info(f"Error deleting events for the model {m_id}: {str(ex)}")
deleted_models += Model.objects(id__in=list(model_ids)).delete()
if in_use_model_ids:
Model.objects(id__in=list(in_use_model_ids)).update(unset__task=1)
continue
if update_children:
updated_models += Model.objects(id__in=[m.id for m in models]).update(
task=deleted_task_id
)
else:
updated_children = 0
else:
Model.objects(id__in=[m.id for m in models]).update(unset__task=1)
if models.draft and delete_output_models:
with TimingContext("mongo", "delete_models"):
deleted_models = models.draft.objects().delete()
else:
deleted_models = 0
event_bll.delete_task_events(
task.company, task.id, allow_locked=force, async_delete=async_events_delete
)
if models.published and update_children:
with TimingContext("mongo", "update_task_models"):
updated_models = models.published.objects().update(task=deleted_task_id)
else:
updated_models = 0
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
if delete_external_artifacts:
scheduled = _schedule_for_delete(
task_id=task.id,
company=company,
user=user,
urls=event_urls | model_urls | artifact_urls,
can_delete_folders=not in_use_model_ids and not published_models,
)
for urls in (event_urls, model_urls, artifact_urls):
urls.difference_update(scheduled)
return CleanupResult(
deleted_models=deleted_models,
@@ -217,62 +269,56 @@ def cleanup_task(
)
def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Model]:
def verify_task_children_and_ouptuts(
task, force: bool
) -> Tuple[Sequence[Model], Sequence[Model], Set[str]]:
if not force:
with TimingContext("mongo", "count_published_children"):
published_children_count = Task.objects(
parent=task.id, status=TaskStatus.published
).count()
if published_children_count:
raise errors.bad_request.TaskCannotBeDeleted(
"has children, use force=True",
task=task.id,
children=published_children_count,
)
with TimingContext("mongo", "get_task_models"):
models = TaskOutputs(
attrgetter("ready"),
Model,
Model.objects(task=task.id).only("id", "task", "ready"),
)
if not force and models.published:
published_children_count = Task.objects(
parent=task.id, status=TaskStatus.published
).count()
if published_children_count:
raise errors.bad_request.TaskCannotBeDeleted(
"has output models, use force=True",
"has children, use force=True",
task=task.id,
models=len(models.published),
children=published_children_count,
)
model_fields = ["id", "ready", "uri"]
published_models, draft_models = partition(
Model.objects(task=task.id).only(*model_fields), key=attrgetter("ready"),
)
if not force and published_models:
raise errors.bad_request.TaskCannotBeDeleted(
"has output models, use force=True",
task=task.id,
models=len(published_models),
)
if task.models and task.models.output:
with TimingContext("mongo", "get_task_output_model"):
model_ids = [m.model for m in task.models.output]
for output_model in Model.objects(id__in=model_ids):
if output_model.ready:
if not force:
raise errors.bad_request.TaskCannotBeDeleted(
"has output model, use force=True",
task=task.id,
model=output_model.id,
)
models.published.append(output_model)
else:
models.draft.append(output_model)
model_ids = [m.model for m in task.models.output]
for output_model in Model.objects(id__in=model_ids).only(*model_fields):
if output_model.ready:
if not force:
raise errors.bad_request.TaskCannotBeDeleted(
"has output model, use force=True",
task=task.id,
model=output_model.id,
)
published_models.append(output_model)
else:
draft_models.append(output_model)
if models.draft:
with TimingContext("mongo", "get_execution_models"):
model_ids = models.draft.ids
dependent_tasks = Task.objects(models__input__model__in=model_ids).only(
"id", "models"
in_use_model_ids = {}
if draft_models:
model_ids = {m.id for m in draft_models}
dependent_tasks = Task.objects(models__input__model__in=list(model_ids)).only(
"id", "models"
)
in_use_model_ids = model_ids & {
m.model
for m in chain.from_iterable(
t.models.input for t in dependent_tasks if t.models
)
input_models = {
m.model
for m in chain.from_iterable(
t.models.input for t in dependent_tasks if t.models
)
}
if input_models:
models.draft = DocumentGroup(
Model, (m for m in models.draft if m.id not in input_models)
)
}
return models
return published_models, draft_models, in_use_model_ids

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Callable, Any, Tuple, Union
from typing import Callable, Any, Tuple, Union, Sequence
from apiserver.apierrors import errors, APIError
from apiserver.bll.queue import QueueBLL
@@ -25,6 +25,7 @@ from apiserver.database.model.task.task import (
)
from apiserver.utilities.dicts import nested_set
log = config.logger(__file__)
queue_bll = QueueBLL()
@@ -83,10 +84,7 @@ def unarchive_task(
def dequeue_task(
task_id: str,
company_id: str,
status_message: str,
status_reason: str,
task_id: str, company_id: str, status_message: str, status_reason: str,
) -> Tuple[int, dict]:
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
@@ -94,10 +92,7 @@ def dequeue_task(
raise errors.bad_request.InvalidTaskId(**query)
res = TaskBLL.dequeue_and_change_status(
task,
company_id,
status_message=status_message,
status_reason=status_reason,
task, company_id, status_message=status_message, status_reason=status_reason,
)
return 1, res
@@ -108,9 +103,23 @@ def enqueue_task(
queue_id: str,
status_message: str,
status_reason: str,
queue_name: str = None,
validate: bool = False,
force: bool = False,
) -> Tuple[int, dict]:
if queue_id and queue_name:
raise errors.bad_request.ValidationError(
"Either queue id or queue name should be provided"
)
if queue_name:
queue = queue_bll.get_by_name(
company_id=company_id, queue_name=queue_name, only=("id",)
)
if not queue:
queue = queue_bll.create(company_id=company_id, name=queue_name)
queue_id = queue.id
if not queue_id:
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
@@ -155,15 +164,41 @@ def enqueue_task(
return 1, res
def move_tasks_to_trash(tasks: Sequence[str]) -> int:
try:
collection_name = Task._get_collection_name()
trash_collection_name = f"{collection_name}__trash"
Task.aggregate(
[
{"$match": {"_id": {"$in": tasks}}},
{
"$merge": {
"into": trash_collection_name,
"on": "_id",
"whenMatched": "replace",
"whenNotMatched": "insert",
}
},
],
allow_disk_use=True,
)
except Exception as ex:
log.error(f"Error copying tasks to trash {str(ex)}")
return Task.objects(id__in=tasks).delete()
def delete_task(
task_id: str,
company_id: str,
user_id: str,
move_to_trash: bool,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
status_message: str,
status_reason: str,
delete_external_artifacts: bool,
) -> Tuple[int, Task, CleanupResult]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
@@ -193,25 +228,22 @@ def delete_task(
pass
cleanup_res = cleanup_task(
task,
company=company_id,
user=user_id,
task=task,
force=force,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
)
if move_to_trash:
collection_name = task._get_collection_name()
archived_collection = "{}__trash".format(collection_name)
task.switch_collection(archived_collection)
try:
# A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force
# an insert. However, if for some reason such an ID exists, let's make sure we'll keep going.
task.save(force_insert=True)
except Exception:
pass
task.switch_collection(collection_name)
# make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation
task.save()
else:
task.delete()
task.delete()
update_project_time(task.project)
return 1, task, cleanup_res
@@ -219,10 +251,12 @@ def delete_task(
def reset_task(
task_id: str,
company_id: str,
user_id: str,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
clear_all: bool,
delete_external_artifacts: bool,
) -> Tuple[dict, CleanupResult, dict]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
@@ -241,16 +275,20 @@ def reset_task(
pass
cleaned_up = cleanup_task(
task,
company=company_id,
user=user_id,
task=task,
force=force,
update_children=False,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
)
updates.update(
set__last_iteration=DEFAULT_LAST_ITERATION,
set__last_metrics={},
set__unique_metrics=[],
set__metric_stats={},
set__models__output=[],
set__runtime={},

View File

@@ -9,7 +9,6 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
from apiserver.timing_context import TimingContext
from apiserver.utilities.attrs import typed_attrs
valid_statuses = get_options(TaskStatus)
@@ -55,7 +54,7 @@ class ChangeStatusRequest(object):
fields.update({safe_mongoengine_key(k): v for k, v in kwargs.items()})
with translate_errors_context(), TimingContext("mongo", "task_status"):
with translate_errors_context():
# atomic change of task status by querying the task with the EXPECTED status before modifying it
params = fields.copy()
params.update(control)

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from apiserver.apierrors import errors
from apiserver.apimodels.users import CreateRequest
from apiserver.database.errors import translate_errors_context
@@ -12,7 +14,7 @@ class UserBLL:
if user_id and User.objects(id=user_id).only("id"):
raise errors.bad_request.UserIdExists(id=user_id)
user = User(**request.to_struct())
user = User(**request.to_struct(), created=datetime.utcnow())
user.save(force_insert=True)
@staticmethod

View File

@@ -1,9 +1,11 @@
import itertools
from datetime import datetime, timedelta
from time import time
from typing import Sequence, Set, Optional
import attr
import elasticsearch.helpers
from boltons.iterutils import partition
from apiserver.es_factory import es_factory
from apiserver.apierrors import APIError
@@ -25,7 +27,6 @@ from apiserver.database.model.project import Project
from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
from .stats import WorkerStats
@@ -51,6 +52,7 @@ class WorkerBLL:
queues: Sequence[str] = None,
timeout: int = 0,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> WorkerEntry:
"""
Register a worker
@@ -76,7 +78,7 @@ class WorkerBLL:
raise bad_request.InvalidUserId(**query)
company = Company.objects(id=company_id).only("id", "name").first()
if not company:
raise server_error.InternalError("invalid company", company=company_id)
raise bad_request.InvalidId("invalid company", company=company_id)
queue_objs = Queue.objects(company=company_id, id__in=queues).only("id")
if len(queue_objs) < len(queues):
@@ -95,9 +97,10 @@ class WorkerBLL:
register_timeout=timeout,
last_activity_time=now,
tags=tags,
system_tags=system_tags,
)
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
self._save_worker_data(entry)
return entry
@@ -109,15 +112,20 @@ class WorkerBLL:
:param worker: worker ID
:raise bad_request.WorkerNotRegistered: the worker was not previously registered
"""
with TimingContext("redis", "workers_unregister"):
res = self.redis.delete(
company_id, self._get_worker_key(company_id, user_id, worker)
)
res = self.redis.delete(
company_id, self._get_worker_key(company_id, user_id, worker)
)
if not res and not config.get("apiserver.workers.auto_unregister", False):
raise bad_request.WorkerNotRegistered(worker=worker)
def status_report(
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
self,
company_id: str,
user_id: str,
ip: str,
report: StatusReportRequest,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> None:
"""
Write worker status report
@@ -138,12 +146,14 @@ class WorkerBLL:
if tags is not None:
entry.tags = tags
if system_tags is not None:
entry.system_tags = system_tags
if report.machine_stats:
self._log_stats_to_es(
company_id=company_id,
company_name=entry.company.name,
worker=report.worker,
worker=entry.key,
timestamp=report.timestamp,
task=report.task,
machine_stats=report.machine_stats,
@@ -176,7 +186,9 @@ class WorkerBLL:
if task.project:
project = Project.objects(id=task.project).only("name").first()
if project:
entry.project = IdNameEntry(id=project.id, name=project.name)
entry.project = IdNameEntry(
id=project.id, name=project.name
)
entry.last_report_time = now
except APIError:
@@ -189,7 +201,11 @@ class WorkerBLL:
self._save_worker(entry)
def get_all(
self, company_id: str, last_seen: Optional[int] = None
self,
company_id: str,
last_seen: Optional[int] = None,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerEntry]:
"""
Get all the company workers that were active during the last_seen period
@@ -198,7 +214,7 @@ class WorkerBLL:
:return:
"""
try:
workers = self._get(company_id)
workers = self._get(company_id, user_tags=tags, system_tags=system_tags)
except Exception as e:
raise server_error.DataError("failed loading worker entries", err=e.args[0])
@@ -213,13 +229,22 @@ class WorkerBLL:
return workers
def get_all_with_projection(
self, company_id: str, last_seen: int
self,
company_id: str,
last_seen: int,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerResponseEntry]:
helpers = list(
map(
WorkerConversionHelper.from_worker_entry,
self.get_all(company_id=company_id, last_seen=last_seen),
self.get_all(
company_id=company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
),
)
)
@@ -310,8 +335,7 @@ class WorkerBLL:
"""
key = self._get_worker_key(company_id, user_id, worker)
with TimingContext("redis", "get_worker"):
data = self.redis.get(key)
data = self.redis.get(key)
if data:
try:
@@ -338,24 +362,119 @@ class WorkerBLL:
raise bad_request.InvalidWorkerId(worker=worker)
@staticmethod
def _get_tagged_workers_key(company: str, tags_field: str, tag: str) -> str:
"""Build redis key from company, user and worker_id"""
return f"workers.{tags_field}_{company}_{tag}"
@staticmethod
def _get_all_workers_key(company: str) -> str:
"""Build redis key from company, user and worker_id"""
return f"workers_{company}"
def _save_worker_data(self, entry: WorkerEntry):
self.redis.setex(
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
)
company_id = entry.company.id
expiration = int(time()) + entry.register_timeout
worker_item = {entry.key: expiration}
self.redis.zadd(self._get_all_workers_key(company_id), worker_item)
for tags, tags_field in (
(entry.tags, "tags"),
(entry.system_tags, "systemtags"),
):
for tag in tags:
name = self._get_tagged_workers_key(company_id, tags_field, tag)
self.redis.zadd(name, worker_item)
def _save_worker(self, entry: WorkerEntry) -> None:
"""Save worker entry in Redis"""
try:
self.redis.setex(
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
)
self._save_worker_data(entry)
except Exception:
msg = "Failed saving worker entry"
log.exception(msg)
def _get(
self, company: str, user: str = "*", worker_id: str = "*"
self,
company: str,
user: str = "*",
worker_id: str = "*",
user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerEntry]:
"""Get worker entries matching the company and user, worker patterns"""
match = self._get_worker_key(company, user, worker_id)
with TimingContext("redis", "workers_get_all"):
res = self.redis.scan_iter(match)
return [WorkerEntry.from_json(self.redis.get(r)) for r in res]
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
if user == "*":
return in_keys
user_bytes = user.encode()
return {k for k in in_keys if user_bytes in k}
if user_tags or system_tags:
worker_keys = set()
for tags, tags_field in (
(user_tags, "tags"),
(system_tags, "systemtags"),
):
if not tags:
continue
timestamp = int(time())
include, exclude = partition(tags, key=lambda x: x[0] != "-")
if include:
tagged_workers = set()
for tag in include:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
tagged_workers.update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
tagged_workers = filter_by_user(tagged_workers)
worker_keys = (
worker_keys.intersection(tagged_workers)
if worker_keys
else tagged_workers
)
if not worker_keys:
return []
if exclude:
if not worker_keys:
all_workers_key = self._get_all_workers_key(company)
self.redis.zremrangebyscore(
all_workers_key, min=0, max=timestamp
)
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
worker_keys = filter_by_user(worker_keys)
if not worker_keys:
return []
for tag in exclude:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag[1:]
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
worker_keys.difference_update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
if not worker_keys:
return []
else:
match = self._get_worker_key(company, user, "*")
worker_keys = self.redis.scan_iter(match)
entries = []
for key in worker_keys:
data = self.redis.get(key)
if data:
entries.append(WorkerEntry.from_json(data))
return entries
@staticmethod
def _get_es_index_suffix():

View File

@@ -8,7 +8,6 @@ from apiserver.apimodels.workers import AggregationType, GetStatsRequest, StatIt
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
@@ -20,7 +19,7 @@ class WorkerStats:
@staticmethod
def worker_stats_prefix_for_company(company_id: str) -> str:
"""Returns the es index prefix for the company"""
return f"worker_stats_{company_id}_"
return f"worker_stats_{company_id.lower()}_"
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
return self.es.search(
@@ -126,7 +125,7 @@ class WorkerStats:
query_terms.append(QueryBuilder.terms("worker", request.worker_ids))
es_req["query"] = {"bool": {"must": query_terms}}
with translate_errors_context(), TimingContext("es", "get_worker_stats"):
with translate_errors_context():
data = self._search_company_stats(company_id, es_req)
return self._extract_results(data, request.items, request.split_by_variant)
@@ -223,9 +222,7 @@ class WorkerStats:
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext(
"es", "get_worker_activity_report"
):
with translate_errors_context():
data = self._search_company_stats(company_id, es_req)
if "aggregations" not in data:

View File

@@ -146,4 +146,11 @@
max_backoff_sec: 5
}
getting_started_info {
"agentName": "clearml",
"configure": "clearml-init",
"install": "pip install clearml",
"packageName": "clearml"
}
}

View File

@@ -1,3 +1,5 @@
fileserver = "http://localhost:8081"
elastic {
events {
hosts: [{host: "127.0.0.1", port: 9200}]

View File

@@ -0,0 +1,12 @@
# if set to True then on task delete/reset external file urls for know storage types are scheduled for async delete
# otherwise they are returned to a client for the client side delete
enabled: false
max_retries: 3
retry_timeout_sec: 60
fileserver {
# fileserver url prefixes. Evaluated in the order of priority
# Can be in the form <schema>://host:port/path or /path
url_prefixes: ["https://files.community-master.hosted.allegro.ai/"]
timeout_sec: 300
}

View File

@@ -12,12 +12,23 @@ events_retrieval {
# should not exceed the amount of concurrent connections set in the ES driver
max_metrics_concurrency: 4
# If set then max_metrics_count and max_variants_count are calculated dynamically on user data
dynamic_metrics_count: true
# The percentage from the ES aggs limit (10000) to use for the max_metrics and max_variants calculation
dynamic_metrics_count_threshold: 80
# the max amount of metrics to aggregate on
max_metrics_count: 100
# the max amount of variants to aggregate on
max_variants_count: 100
debug_images {
# Allow to return the debug images for the variants with uninitialized valid iterations border
allow_uninitialized_variants: true
}
max_raw_scalars_size: 200000
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
@@ -28,4 +39,7 @@ events_retrieval {
validate_plot_str: false
# If not 0 then the plots equal or greater to the size will be stored compressed in the DB
plot_compression_threshold: 100000
plot_compression_threshold: 100000
# async events delete threshold
max_async_deleted_events_per_sec: 1000

View File

@@ -0,0 +1,8 @@
{
metrics_before_from_date: 3600
# interval in seconds to update queue metrics. Put 0 to disable
metrics_refresh_interval_sec: 300
# the queues with these tags will not be returned from get_all/get_all_ex unless id or name specified
# or search_hidden is set
hidden_tags: [k8s-glue]
}

View File

@@ -19,4 +19,11 @@ hyperparam_values {
# cache ttl sec
cache_ttl_sec: 86400
}
}
# the maximum amount of unique last metrics/variants combinations
# for which the last values are stored in a task
max_last_metrics: 2000
# if set then call to tasks.delete/cleanup does not wait for ES events deletion
async_events_delete: false

View File

@@ -29,6 +29,9 @@ OVERRIDE_PORT_ENV_KEY = (
)
OVERRIDE_CONNECTION_STRING_ENV_KEY = "CLEARML_MONGODB_SERVICE_CONNECTION_STRING"
OVERRIDE_USERNAME_ENV_KEY = "CLEARML_MONGODB_SERVICE_USERNAME"
OVERRIDE_PASSWORD_ENV_KEY = "CLEARML_MONGODB_SERVICE_PASSWORD"
OVERRIDE_QUERY_ENV_KEY = "CLEARML_MONGODB_SERVICE_QUERY"
class DatabaseEntry(models.Base):
@@ -52,14 +55,23 @@ class DatabaseFactory:
override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY)
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
override_username = getenv(OVERRIDE_USERNAME_ENV_KEY)
override_password = getenv(OVERRIDE_PASSWORD_ENV_KEY)
override_query = getenv(OVERRIDE_QUERY_ENV_KEY)
if override_connection_string:
log.info(f"Using override mongodb connection string {override_connection_string}")
log.info(f"Using override mongodb connection string template {override_connection_string}")
else:
if override_hostname:
log.info(f"Using override mongodb host {override_hostname}")
if override_port:
log.info(f"Using override mongodb port {override_port}")
if override_username:
log.info(f"Using override mongodb username {override_username}")
if override_password:
log.info(f"Using override mongodb password ******")
if override_query:
log.info(f"Using override mongodb query {override_query}")
for key, alias in get_items(Database).items():
if key not in db_entries:
@@ -69,12 +81,20 @@ class DatabaseFactory:
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
if override_connection_string:
entry.host = override_connection_string
con_str = f"{override_connection_string.rstrip('/')}/{key}"
log.info(f"Using override mongodb connection string for {alias}: {con_str}")
entry.host = con_str
else:
if override_hostname:
entry.host = furl(entry.host).set(host=override_hostname).url
if override_port:
entry.host = furl(entry.host).set(port=override_port).url
if override_username:
entry.host = furl(entry.host).set(username=override_username).url
if override_password:
entry.host = furl(entry.host).set(password=override_password).url
if override_query:
entry.host = furl(entry.host).set(query=override_query).url
try:
entry.validate()

View File

@@ -166,7 +166,10 @@ class MongoEngineErrorsHandler(object):
@classmethod
@throws_default_error(errors.server_error.InternalError)
def invalid_query_error(cls, e, message, **_):
pass
if e.args:
inner = e.args[0]
if isinstance(inner, LookUpError):
cls.lookup_error(inner, message)
@contextmanager

View File

@@ -26,7 +26,7 @@ from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database import Database
from apiserver.database.errors import MakeGetAllQueryError
from apiserver.database.projection import project_dict, ProjectionHelper
from apiserver.database.projection import ProjectionHelper
from apiserver.database.props import PropsMixin
from apiserver.database.query import RegexQ, RegexWrapper, RegexQCombination
from apiserver.database.utils import (
@@ -36,6 +36,7 @@ from apiserver.database.utils import (
field_exists,
)
from apiserver.redis_manager import redman
from apiserver.utilities.dicts import project_dict, exclude_fields_from_dict
log = config.logger("dbmodel")
@@ -55,17 +56,25 @@ class ProperDictMixin(object):
strip_private=True,
only=None,
extra_dict=None,
exclude=None,
) -> dict:
return self.properize_dict(
self.to_mongo(use_db_field=False).to_dict(),
strip_private=strip_private,
only=only,
extra_dict=extra_dict,
exclude=exclude,
)
@classmethod
def properize_dict(
cls, d, strip_private=True, only=None, extra_dict=None, normalize_id=True
cls,
d,
strip_private=True,
only=None,
extra_dict=None,
exclude=None,
normalize_id=True,
):
res = d
if normalize_id and "_id" in res:
@@ -76,6 +85,9 @@ class ProperDictMixin(object):
res = project_dict(res, only)
if extra_dict:
res.update(extra_dict)
if exclude:
exclude_fields_from_dict(res, exclude)
return res
@@ -368,18 +380,31 @@ class GetMixin(PropsMixin):
if data is not None:
if not isinstance(data, list):
data = [data]
for d in data: # type: str
m = ACCESS_REGEX.match(d)
if not m:
continue
try:
value = parse_datetime(m.group("value"))
prefix = m.group("prefix")
modifier = ACCESS_MODIFIER.get(prefix)
f = field if not modifier else "__".join((field, modifier))
dict_query[f] = value
except (ValueError, OverflowError):
pass
# date time fields also support simplified range queries. Check if this is the case
if len(data) == 2 and not any(
d.startswith(mod)
for d in data
if d is not None
for mod in ACCESS_MODIFIER
):
query &= cls.get_range_field_query(field, data)
else:
for d in data: # type: str
m = ACCESS_REGEX.match(d)
if not m:
continue
try:
value = parse_datetime(m.group("value"))
prefix = m.group("prefix")
modifier = ACCESS_MODIFIER.get(prefix)
f = (
field
if not modifier
else "__".join((field, modifier))
)
dict_query[f] = value
except (ValueError, OverflowError):
pass
for field, value in parameters.items():
for keys, func in cls._multi_field_param_prefix.items():
@@ -497,6 +522,8 @@ class GetMixin(PropsMixin):
def validate_order_by(cls, parameters, search_text) -> Sequence:
"""
Validate and extract order_by params as a list
If ordering is specified then make sure that id field is part of it
This guarantees the unique order when paging
"""
order_by = parameters.get(cls._ordering_key)
if not order_by:
@@ -509,6 +536,9 @@ class GetMixin(PropsMixin):
"text score cannot be used in order_by when search text is not used"
)
if not any(id_field in order_by for id_field in ("id", "-id")):
order_by.append("id")
return order_by
@classmethod
@@ -648,6 +678,7 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection=None,
expand_reference_ids=True,
projection_fields: dict = None,
ret_params: dict = None,
):
"""
@@ -684,6 +715,7 @@ class GetMixin(PropsMixin):
query=query,
query_options=query_options,
allow_public=allow_public,
projection_fields=projection_fields,
ret_params=ret_params,
)
@@ -704,6 +736,45 @@ class GetMixin(PropsMixin):
v for k, v in cls._field_collation_overrides.items() if field.startswith(k)
)
@classmethod
def get_count(
cls: Union["GetMixin", Document],
company,
query_dict: dict = None,
query_options: QueryParameterOptions = None,
query: Q = None,
allow_public=False,
) -> int:
_query = cls._get_combined_query(
company=company,
query_dict=query_dict,
query_options=query_options,
query=query,
allow_public=allow_public,
)
return cls.objects(_query).count()
@classmethod
def _get_combined_query(
cls,
company,
query_dict: dict = None,
query_options: QueryParameterOptions = None,
query: Q = None,
allow_public=False,
) -> Q:
if query_dict is not None:
q = cls.prepare_query(
parameters=query_dict,
company=company,
parameters_options=query_options,
allow_public=allow_public,
)
else:
q = cls._prepare_perm_query(company, allow_public=allow_public)
return (q & query) if query else q
@classmethod
def get_many(
cls,
@@ -715,6 +786,7 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection: Collection[str] = None,
return_dicts=True,
projection_fields: dict = None,
ret_params: dict = None,
):
"""
@@ -749,16 +821,13 @@ class GetMixin(PropsMixin):
if override_collation:
break
if query_dict is not None:
q = cls.prepare_query(
parameters=query_dict,
company=company,
parameters_options=query_options,
allow_public=allow_public,
)
else:
q = cls._prepare_perm_query(company, allow_public=allow_public)
_query = (q & query) if query else q
_query = cls._get_combined_query(
company=company,
query_dict=query_dict,
query_options=query_options,
query=query,
allow_public=allow_public,
)
if return_dicts:
data_getter = partial(
@@ -767,6 +836,7 @@ class GetMixin(PropsMixin):
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
projection_fields=projection_fields,
)
return cls.get_data_with_scroll_support(
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
@@ -777,6 +847,7 @@ class GetMixin(PropsMixin):
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
projection_fields=projection_fields,
)
@classmethod
@@ -801,6 +872,7 @@ class GetMixin(PropsMixin):
parameters=None,
override_projection=None,
override_collation=None,
projection_fields: dict = None,
):
"""
Fetch all documents matching a provided query.
@@ -843,6 +915,9 @@ class GetMixin(PropsMixin):
if exclude:
qs = qs.exclude(*exclude)
if projection_fields:
qs = qs.fields(**projection_fields)
if start is not None and size:
# add paging
qs = qs.skip(start).limit(size)
@@ -884,6 +959,7 @@ class GetMixin(PropsMixin):
parameters: dict = None,
override_projection: Collection[str] = None,
override_collation: dict = None,
projection_fields: dict = None,
) -> Sequence[dict]:
"""
Fetch all documents matching a provided query. For the first order by field
@@ -941,8 +1017,15 @@ class GetMixin(PropsMixin):
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
if projection_fields:
query_sets = [qs.fields(**projection_fields) for qs in query_sets]
if start is None or not size:
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
return [
obj.to_proper_dict(only=include, exclude=exclude)
for qs in query_sets
for obj in qs
]
# add paging
ret = []
@@ -950,7 +1033,7 @@ class GetMixin(PropsMixin):
for i, qs in enumerate(query_sets):
last_size = len(ret)
ret.extend(
obj.to_proper_dict(only=include)
obj.to_proper_dict(only=include, exclude=exclude)
for obj in (qs.skip(start) if start else qs).limit(size)
)
added = len(ret) - last_size

View File

@@ -36,6 +36,7 @@ class Model(AttributedDocument):
("company", "framework"),
("company", "name"),
("company", "user"),
("company", "uri"),
{
"name": "%s.model.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
@@ -91,3 +92,6 @@ class Model(AttributedDocument):
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)
def get_index_company(self) -> str:
return self.company or self.company_origin or ""

View File

@@ -9,7 +9,7 @@ from apiserver.database.model.base import GetMixin
class Project(AttributedDocument):
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "description"),
pattern_fields=("name", "basename", "description"),
list_fields=("tags", "system_tags", "id", "parent", "path"),
range_fields=("last_update",),
)
@@ -21,6 +21,7 @@ class Project(AttributedDocument):
"parent",
"path",
("company", "name"),
("company", "basename"),
{
"name": "%s.project.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$description"],
@@ -37,6 +38,7 @@ class Project(AttributedDocument):
min_length=3,
sparse=True,
)
basename = StrippedStringField(required=True)
description = StringField()
created = DateTimeField(required=True)
tags = SafeSortedListField(StringField(required=True))

View File

@@ -259,6 +259,7 @@ class Task(AttributedDocument):
last_change = DateTimeField()
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
company_origin = StringField(exclude_by_default=True)
duration = IntField() # task duration in seconds

View File

@@ -0,0 +1,51 @@
from enum import Enum
from mongoengine import StringField, DateTimeField, IntField, EnumField
from apiserver.database import Database, strict
from apiserver.database.model import AttributedDocument
class StorageType(str, Enum):
fileserver = "fileserver"
unknown = "unknown"
class FileType(str, Enum):
file = "file"
folder = "folder"
class DeletionStatus(str, Enum):
created = "created"
retrying = "retrying"
failed = "failed"
class UrlToDelete(AttributedDocument):
_field_collation_overrides = {
"url": AttributedDocument._numeric_locale,
}
meta = {
"db_alias": Database.backend,
"strict": strict,
"indexes": [
("company", "user", "task"),
"storage_type",
"created",
"retry_count",
"type",
],
}
id = StringField(primary_key=True)
url = StringField(required=True, unique_with="company")
task = StringField(required=True)
created = DateTimeField(required=True)
storage_type = EnumField(StorageType, default=StorageType.unknown)
type = EnumField(FileType, default=FileType.file)
retry_count = IntField(default=0)
last_failure_time = DateTimeField()
last_failure_reason = StringField()
status = EnumField(DeletionStatus, default=DeletionStatus.created)

View File

@@ -1,4 +1,4 @@
from mongoengine import Document, StringField, DynamicField
from mongoengine import Document, StringField, DynamicField, DateTimeField
from apiserver.database import Database, strict
from apiserver.database.model import DbModelMixin
@@ -20,3 +20,4 @@ class User(DbModelMixin, Document):
given_name = StringField(user_set_allowed=True)
avatar = StringField()
preferences = DynamicField(default="", exclude_by_default=True)
created = DateTimeField()

View File

@@ -1,9 +1,7 @@
import threading
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby, chain
from typing import Sequence, Dict, Callable, Tuple, Any, Type
import dpath.path
from typing import Sequence, Dict, Callable
from apiserver.apierrors import errors
from apiserver.database.props import PropsMixin
@@ -11,65 +9,6 @@ from apiserver.database.props import PropsMixin
SEP = "."
def project_dict(data, projection, separator=SEP):
"""
Project partial data from a dictionary into a new dictionary
:param data: Input dictionary
:param projection: List of dictionary paths (each a string with field names separated using a separator)
:param separator: Separator (default is '.')
:return: A new dictionary containing only the projected parts from the original dictionary
"""
assert isinstance(data, dict)
result = {}
def copy_path(path_parts, source, destination):
src, dst = source, destination
try:
for depth, path_part in enumerate(path_parts[:-1]):
src_part = src[path_part]
if isinstance(src_part, dict):
src = src_part
dst = dst.setdefault(path_part, {})
elif isinstance(src_part, (list, tuple)):
if path_part not in dst:
dst[path_part] = [{} for _ in range(len(src_part))]
elif not isinstance(dst[path_part], (list, tuple)):
raise TypeError(
"Incompatible destination type %s for %s (list expected)"
% (type(dst), separator.join(path_parts[: depth + 1]))
)
elif not len(dst[path_part]) == len(src_part):
raise ValueError(
"Destination list length differs from source length for %s"
% separator.join(path_parts[: depth + 1])
)
dst[path_part] = [
copy_path(path_parts[depth + 1 :], s, d)
for s, d in zip(src_part, dst[path_part])
]
return destination
else:
raise TypeError(
"Unsupported projection type %s for %s"
% (type(src), separator.join(path_parts[: depth + 1]))
)
last_part = path_parts[-1]
dst[last_part] = src[last_part]
except KeyError:
# Projection field not in source, no biggie.
pass
return destination
for projection_path in sorted(projection):
copy_path(
path_parts=projection_path.split(separator), source=data, destination=result
)
return result
class _ReferenceProxy(dict):
def __init__(self, id):
super(_ReferenceProxy, self).__init__(**({"id": id} if id else {}))
@@ -110,9 +49,6 @@ class ProjectionHelper(object):
self._ref_projection = None
self._proxy_manager = _ProxyManager()
# Cached dpath paths for each of the result documents
self._cached_results_paths: Dict[int, Sequence[Tuple[Any, Type]]] = {}
self._parse_projection(projection)
def _collect_projection_fields(self, doc_cls, projection):
@@ -275,25 +211,26 @@ class ProjectionHelper(object):
norm_path = doc_cls.get_dpath_translated_path(path)
globlist = norm_path.strip(SEP).split(SEP)
obj_paths = self._cached_results_paths.get(id(obj))
if obj_paths is None:
obj_paths = self._cached_results_paths[id(obj)] = list(
dpath.path.paths(obj, dirs=True, skip=True)
)
paths = [p for p in obj_paths if dpath.path.match(p, globlist)]
def search_and_replace(p: Sequence[Tuple[str, Type]]) -> Any:
def _search_and_replace(target: dict, p: Sequence[str]) -> Sequence[str]:
parent = None
target = obj
for part in p:
parent = target
target = target[part[0]]
if parent and factory:
parent[p[-1][0]] = factory(target)
return target
for idx, part in enumerate(p):
if isinstance(target, dict) and part in target:
parent = target
target = target[part]
elif isinstance(target, list) and part == "*":
return list(
chain.from_iterable(
_search_and_replace(t, p[idx + 1 :]) for t in target
)
)
else:
return []
return [search_and_replace(p) for p in paths]
if parent and factory:
parent[p[-1]] = factory(target)
return [target]
return _search_and_replace(obj, globlist)
def project(self, results, projection_func):
"""

View File

@@ -1,6 +1,6 @@
import hashlib
from inspect import ismethod, getmembers
from typing import Sequence, Tuple, Set, Optional, Callable, Any
from typing import Sequence, Tuple, Set, Optional, Callable, Any, Mapping
from uuid import uuid4
from mongoengine import EmbeddedDocumentField, ListField, Document, Q
@@ -203,18 +203,22 @@ def _names_set(*names: str) -> Set[str]:
return set(names) | set(f"-{name}" for name in names)
system_tag_names = {
_system_tag_names = {
"model": _names_set("active", "archived"),
"project": _names_set("archived", "public", "default"),
"task": _names_set("active", "archived", "development"),
"queue": _names_set("default"),
}
system_tag_prefixes = {"task": _names_set("annotat")}
_system_tag_prefixes = {"task": _names_set("annotat")}
def partition_tags(
entity: str, tags: Sequence[str], system_tags: Optional[Sequence[str]] = ()
entity: str,
tags: Sequence[str],
system_tags: Optional[Sequence[str]] = (),
system_tag_names: Mapping = _system_tag_names,
system_tag_prefixes: Mapping = _system_tag_prefixes,
) -> Tuple[Sequence[str], Sequence[str]]:
"""
Partition the given tags sequence into system and user-defined tags

View File

@@ -35,6 +35,12 @@
},
"value": {
"type": "float"
},
"company_id": {
"type": "keyword"
},
"model_event": {
"type": "boolean"
}
}
}

View File

@@ -0,0 +1,211 @@
from argparse import ArgumentParser
from collections import defaultdict
from datetime import datetime, timedelta
from functools import partial
from itertools import chain
from pathlib import Path
from time import sleep
from typing import Sequence, Tuple
import requests
from furl import furl
from mongoengine import Q
from apiserver.config_repo import config
from apiserver.database import db
from apiserver.database.model.url_to_delete import (
UrlToDelete,
DeletionStatus,
StorageType,
)
log = config.logger(f"JOB-{Path(__file__).name}")
conf = config.get("services.async_urls_delete")
max_retries = conf.get("max_retries", 3)
retry_timeout = timedelta(seconds=conf.get("retry_timeout_sec", 60))
fileserver_timeout = conf.get("fileserver.timeout_sec", 300)
UrlPrefix = Tuple[str, str]
def validate_fileserver_access(fileserver_host: str) -> str:
fileserver_host = fileserver_host or config.get("hosts.fileserver", None)
if not fileserver_host:
log.error(f"Fileserver host not configured")
exit(1)
res = requests.get(url=fileserver_host)
res.raise_for_status()
return fileserver_host
def mark_retry_failed(ids: Sequence[str], reason: str):
UrlToDelete.objects(id__in=ids).update(
last_failure_time=datetime.utcnow(),
last_failure_reason=reason,
inc__retry_count=1,
)
UrlToDelete.objects(id__in=ids, retry_count__gte=max_retries).update(
status=DeletionStatus.failed
)
def mark_failed(query: Q, reason: str):
UrlToDelete.objects(query).update(
status=DeletionStatus.failed,
last_failure_time=datetime.utcnow(),
last_failure_reason=reason,
)
def delete_fileserver_urls(
urls_query: Q, fileserver_host: str, url_prefixes: Sequence[UrlPrefix]
):
to_delete = list(UrlToDelete.objects(urls_query).limit(10000))
if not to_delete:
return
def resolve_path(url_: UrlToDelete) -> str:
parsed = furl(url_.url)
url_host = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme else None
url_path = str(parsed.path)
for host, path_prefix in url_prefixes:
if host and url_host != host:
continue
if path_prefix and not url_path.startswith(path_prefix + "/"):
continue
return url_path[len(path_prefix or ""):]
raise ValueError("could not map path")
paths = set()
path_to_id_mapping = defaultdict(list)
for url in to_delete:
try:
path = resolve_path(url)
path = path.strip("/")
if not path:
raise ValueError("Empty path")
except Exception as ex:
err = str(ex)
log.warn(f"Error getting path for {url.url}: {err}")
mark_failed(Q(id=url.id), err)
continue
paths.add(path)
path_to_id_mapping[path].append(url.id)
if not paths:
return
ids_to_delete = set(chain.from_iterable(path_to_id_mapping.values()))
try:
res = requests.post(
url=furl(fileserver_host).add(path="delete_many").url,
json={"files": list(paths)},
timeout=fileserver_timeout,
)
res.raise_for_status()
except Exception as ex:
err = str(ex)
log.warn(f"Error deleting {len(paths)} files from fileserver: {err}")
mark_retry_failed(list(ids_to_delete), err)
return
res_data = res.json()
deleted_ids = set(
chain.from_iterable(
path_to_id_mapping.get(path, [])
for path in list(res_data.get("deleted", {}))
)
)
if deleted_ids:
UrlToDelete.objects(id__in=list(deleted_ids)).delete()
log.info(f"{len(deleted_ids)} files deleted from the fileserver")
failed_ids = set()
for err, error_ids in res_data.get("errors", {}).items():
error_ids = list(
chain.from_iterable(path_to_id_mapping.get(path, []) for path in error_ids)
)
mark_retry_failed(error_ids, err)
log.warning(
f"Failed to delete {len(error_ids)} files from the fileserver due to: {err}"
)
failed_ids.update(error_ids)
missing_ids = ids_to_delete - deleted_ids - failed_ids
if missing_ids:
mark_retry_failed(list(missing_ids), "Not succeeded")
def _get_fileserver_url_prefixes(fileserver_host: str) -> Sequence[UrlPrefix]:
def _parse_url_prefix(prefix) -> UrlPrefix:
url = furl(prefix)
host = f"{url.scheme}://{url.netloc}" if url.scheme else None
return host, str(url.path).rstrip("/")
url_prefixes = [
_parse_url_prefix(p) for p in conf.get("fileserver.url_prefixes", [])
]
if not any(fileserver_host == host for host, _ in url_prefixes):
url_prefixes.append((fileserver_host, ""))
return url_prefixes
def run_delete_loop(fileserver_host: str):
fileserver_host = validate_fileserver_access(fileserver_host)
storage_delete_funcs = {
StorageType.fileserver: partial(
delete_fileserver_urls,
fileserver_host=fileserver_host,
url_prefixes=_get_fileserver_url_prefixes(fileserver_host),
),
}
while True:
now = datetime.utcnow()
urls_query = (
Q(status__ne=DeletionStatus.failed)
& Q(retry_count__lt=max_retries)
& (
Q(last_failure_time__exists=False)
| Q(last_failure_time__lt=now - retry_timeout)
)
)
url_to_delete: UrlToDelete = UrlToDelete.objects(
urls_query & Q(storage_type__in=list(storage_delete_funcs))
).order_by("retry_count").limit(1).first()
if not url_to_delete:
sleep(10)
continue
company = url_to_delete.company
user = url_to_delete.user
storage_type = url_to_delete.storage_type
log.info(
f"Deleting {storage_type} objects for company: {company}, user: {user}"
)
company_storage_urls_query = urls_query & Q(
company=company, storage_type=storage_type,
)
storage_delete_funcs[storage_type](urls_query=company_storage_urls_query)
def main():
parser = ArgumentParser(description=__doc__)
parser.add_argument(
"--fileserver-host", "-fh", help="Fileserver host address", type=str,
)
args = parser.parse_args()
db.initialize()
run_delete_loop(args.fileserver_host)
if __name__ == "__main__":
main()

View File

@@ -1,8 +1,10 @@
import importlib.util
from datetime import datetime
from inspect import signature
from logging import Logger
from pathlib import Path
import pymongo.database
from mongoengine.connection import get_db
from packaging.version import Version, parse
@@ -80,8 +82,15 @@ def _apply_migrations(log: Logger):
if not func:
continue
try:
sig = signature(func)
kwargs = {}
if len(sig.parameters) == 2:
name, param = list(sig.parameters.items())[-1]
key = name.replace("_", "-")
if issubclass(param.annotation, pymongo.database.Database) and key in dbs:
kwargs[name] = get_db(key)
log.info(f"Applying {script.stem}/{func_name}()")
func(get_db(alias))
func(get_db(alias), **kwargs)
except Exception:
log.exception(f"Failed applying {script}:{func_name}()")
raise ValueError(

View File

@@ -24,6 +24,7 @@ from typing import (
Callable,
)
from urllib.parse import unquote, urlparse
from uuid import uuid4, UUID, uuid5
from zipfile import ZipFile, ZIP_BZIP2
import mongoengine
@@ -71,6 +72,8 @@ class PrePopulate:
r"['\"]source['\"]:\s?['\"](https?://(?:localhost:8081|files.*?)/.*?)['\"]",
flags=re.IGNORECASE,
)
_name_guid_ns = UUID("bda3acc1-e612-506c-bade-80071b6cf039")
_example_id_prefix = "e-"
task_cls: Type[Task]
project_cls: Type[Project]
model_cls: Type[Model]
@@ -690,6 +693,58 @@ class PrePopulate:
continue
yield clean
@staticmethod
def _new_id(_):
return str(uuid4()).replace("-", "")
@classmethod
def _hash_id(cls, name: str):
return str(uuid5(cls._name_guid_ns, name)).replace("-", "")
@classmethod
def _example_id(cls, orig_id: str):
if not orig_id or orig_id.startswith(cls._example_id_prefix):
return orig_id
return cls._example_id_prefix + orig_id
@classmethod
def _private_id(cls, orig_id: str):
if not orig_id or not orig_id.startswith(cls._example_id_prefix):
return orig_id
return orig_id[len(cls._example_id_prefix) :]
@classmethod
def _generate_new_ids(
cls, reader: ZipFile, entity_files: Sequence, metadata: Mapping[str, Any],
) -> Mapping[str, str]:
if not metadata or not any(
metadata.get(key) for key in ("new_ids", "example_ids", "private_ids")
):
return {}
ids = {}
for entity_file in entity_files:
with reader.open(entity_file) as f:
is_project = splitext(entity_file.orig_filename)[0].endswith(".Project")
if metadata.get("new_ids"):
id_func = cls._new_id
elif metadata.get("example_ids"):
id_func = cls._example_id if not is_project else cls._hash_id
elif metadata.get("private_ids"):
id_func = cls._private_id if not is_project else cls._new_id
for item in cls.json_lines(f):
doc = json.loads(item)
orig_id = doc.get("_id")
if orig_id:
ids[orig_id] = (
id_func(orig_id)
if id_func != cls._hash_id
else id_func(doc.get("name"))
)
return ids
@classmethod
def _import(
cls,
@@ -704,37 +759,42 @@ class PrePopulate:
Start from entities since event import will require the tasks already in DB
"""
event_file_ending = cls.events_file_suffix + ".json"
entity_files = (
entity_files = [
fi
for fi in reader.filelist
if not fi.orig_filename.endswith(event_file_ending)
and fi.orig_filename != cls.metadata_filename
)
]
metadata = metadata or {}
old_to_new_ids = cls._generate_new_ids(reader, entity_files, metadata)
tasks = []
for entity_file in entity_files:
with reader.open(entity_file) as f:
full_name = splitext(entity_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
res = cls._import_entity(f, full_name, company_id, user_id, metadata)
res = cls._import_entity(
f, full_name, company_id, user_id, metadata, old_to_new_ids
)
if res:
tasks = res
if sort_tasks_by_last_updated:
tasks = sorted(tasks, key=attrgetter("last_update"))
new_to_old_ids = {v: k for k, v in old_to_new_ids.items()}
for task in tasks:
old_task_id = new_to_old_ids.get(task.id, task.id)
events_file = first(
fi
for fi in reader.filelist
if fi.orig_filename.endswith(task.id + event_file_ending)
if fi.orig_filename.endswith(old_task_id + event_file_ending)
)
if not events_file:
continue
with reader.open(events_file) as f:
full_name = splitext(events_file.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
cls._import_events(f, full_name, company_id, user_id)
cls._import_events(f, company_id, user_id, task.id)
@classmethod
def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]:
@@ -746,6 +806,15 @@ class PrePopulate:
module = importlib.import_module(module_name)
return getattr(module, class_name)
@staticmethod
def _upgrade_project_data(project_data: dict) -> dict:
if not project_data.get("basename"):
name: str = project_data["name"]
_, _, basename = name.rpartition("/")
project_data["basename"] = basename
return project_data
@staticmethod
def _upgrade_model_data(model_data: dict) -> dict:
metadata_key = "metadata"
@@ -838,6 +907,7 @@ class PrePopulate:
company_id: str,
user_id: str,
metadata: Mapping[str, Any],
old_to_new_ids: Mapping[str, str] = None,
) -> Optional[Sequence[Task]]:
cls_ = cls._get_entity_type(full_name)
print(f"Writing {cls_.__name__.lower()}s into database")
@@ -846,8 +916,14 @@ class PrePopulate:
data_upgrade_funcs: Mapping[Type, Callable] = {
cls.task_cls: cls._upgrade_task_data,
cls.model_cls: cls._upgrade_model_data,
cls.project_cls: cls._upgrade_project_data,
}
for item in cls.json_lines(f):
if old_to_new_ids:
for old_id, new_id in old_to_new_ids.items():
# replace ids only when they are standalone strings
# otherwise artifacts uris that contain old ids may get damaged
item = item.replace(f'"{old_id}"', f'"{new_id}"')
upgrade_func = data_upgrade_funcs.get(cls_)
if upgrade_func:
item = json.dumps(upgrade_func(json.loads(item)))
@@ -884,11 +960,13 @@ class PrePopulate:
return tasks
@classmethod
def _import_events(cls, f: IO[bytes], full_name: str, company_id: str, _):
_, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
def _import_events(cls, f: IO[bytes], company_id: str, _, task_id: str):
print(f"Writing events for task {task_id} into database")
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
events = [json.loads(item) for item in events_chunk]
for ev in events:
ev["task"] = task_id
ev["company_id"] = company_id
cls.event_bll.add_events(
company_id, events=events, worker="", allow_locked_tasks=True
company_id, events=events, worker="", allow_locked=True
)

View File

@@ -53,6 +53,7 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
name=user_name,
given_name=given_name,
family_name=family_name,
created=datetime.utcnow(),
).save()
return user_id

View File

@@ -0,0 +1,12 @@
from pymongo.collection import Collection
from pymongo.database import Database
def migrate_backend(db: Database):
projects: Collection = db["project"]
for doc in projects.find({"basename": None}):
name: str = doc["name"]
_, _, basename = name.rpartition("/")
projects.update_one(
{"_id": doc["_id"]}, {"$set": {"basename": basename}},
)

View File

@@ -0,0 +1,15 @@
from pymongo.collection import Collection
from pymongo.database import Database
def migrate_backend(db: Database, auth_db: Database):
users: Collection = db["user"]
auth_users: Collection = auth_db["user"]
created_field = "created"
for doc in users.find({created_field: {"$exists": False}}):
auth_user = auth_users.find_one({"_id": doc["_id"]}, projection=[created_field])
if not auth_user or created_field not in auth_user:
continue
users.update_one(
{"_id": doc["_id"]}, {"$set": {created_field: auth_user[created_field]}}
)

View File

@@ -1,4 +1,4 @@
attrs>=19.1.0
attrs>=22.1.0
bcrypt>=3.1.4
boltons>=19.1.0
boto3==1.14.13
@@ -21,14 +21,13 @@ nested_dict>=1.61
packaging==20.3
psutil>=5.6.5
pyhocon>=0.3.35
pyjwt<2.0.0
pyjwt>=2.4.0
pymongo[srv]==3.12.0
python-rapidjson>=0.6.3
redis==3.5.3
redis-py-cluster>=2.1.3
related>=0.7.2
requests>=2.13.0
semantic_version>=2.8.3,<3
six
tqdm
validators>=0.12.4
validators>=0.12.4

View File

@@ -262,6 +262,38 @@ get_credentials {
}
}
edit_credentials {
allow_roles = [ "*" ]
internal: false
"2.19" {
description: """Updates the label of the existing credentials for the authenticated user."""
request {
type: object
required: [ access_key ]
properties {
access_key {
type: string
description: Existing credentials key
}
label {
type: string
description: New credentials label
}
}
}
response {
type: object
properties {
updated {
description: "Number of credentials updated"
type: integer
enum: [0, 1]
}
}
}
}
}
revoke_credentials {
allow_roles = [ "*" ]
internal: false

View File

@@ -302,6 +302,48 @@ _definitions {
}
}
}
plots_response_task_metrics {
type: object
properties {
task {
type: string
description: Task ID
}
iterations {
type: array
items {
type: object
properties {
iter {
type: integer
description: Iteration number
}
events {
type: array
items {
type: object
description: Plot event
}
}
}
}
}
}
}
plots_response {
type: object
properties {
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
metrics {
type: array
description: "Plot events grouped by tasks and iterations"
items {"$ref": "#/definitions/plots_response_task_metrics"}
}
}
}
debug_image_sample_response {
type: object
properties {
@@ -323,6 +365,28 @@ _definitions {
}
}
}
plot_sample_response {
type: object
properties {
scroll_id {
type: string
description: "Scroll ID to pass to the next calls to get_plot_sample or next_plot_sample"
}
events {
description: "Plot events"
type: array
items { type: object}
}
min_iteration {
type: integer
description: "minimal valid iteration for the metric"
}
max_iteration {
type: integer
description: "maximal valid iteration for the metric"
}
}
}
}
add {
"2.1" {
@@ -342,13 +406,27 @@ add {
additionalProperties: true
}
}
"2.22": ${add."2.1"} {
request.properties {
model_event {
type: boolean
description: If set then the event is for a model. Otherwise for a task. Cannot be used with task log events. If used in batch then all the events should be marked the same
default: false
}
allow_locked {
type: boolean
description: Allow adding events to published tasks or models
default: false
}
}
}
}
add_batch {
"2.1" {
description: "Adds a batch of events in a single call (json-lines format, stream-friendly)"
batch_request: {
action: add
version: 1.5
version: 2.1
}
response {
type: object
@@ -359,10 +437,16 @@ add_batch {
}
}
}
"2.22": ${add_batch."2.1"} {
batch_request: {
action: add
version: 2.22
}
}
}
delete_for_task {
"2.1" {
description: "Delete all task event. *This cannot be undone!*"
description: "Delete all task events. *This cannot be undone!*"
request {
type: object
required: [
@@ -391,6 +475,37 @@ delete_for_task {
}
}
}
delete_for_model {
"2.22" {
description: "Delete all model events. *This cannot be undone!*"
request {
type: object
required: [
model
]
properties {
model {
type: string
description: "Model ID"
}
allow_locked {
type: boolean
description: "Allow deleting events even if the model is locked"
default: false
}
}
}
response {
type: object
properties {
deleted {
type: boolean
description: "Number of deleted events"
}
}
}
}
}
debug_images {
"2.1" {
description: "Get all debug images of a task"
@@ -485,6 +600,55 @@ debug_images {
}
}
}
"2.22": ${debug_images."2.14"} {
model_events {
type: boolean
description: If set then the retrieving model events. Otherwise task events
default: false
}
}
}
plots {
"2.20" {
description: "Get plot events for the requested amount of iterations per each task"
request {
type: object
required: [
metrics
]
properties {
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/task_metric_variants" }
}
iters {
type: integer
description: "Max number of latest iterations for which to return debug images"
}
navigate_earlier {
type: boolean
description: "If set then events are retreived from latest iterations to earliest ones. Otherwise from earliest iterations to the latest. The default is True"
}
refresh {
type: boolean
description: "If set then scroll will be moved to the latest iterations. The default is False"
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
}
}
}
response {"$ref": "#/definitions/plots_response"}
}
"2.22": ${plots."2.20"} {
model_events {
type: boolean
description: If set then the retrieving model plots. Otherwise task plots
default: false
}
}
}
get_debug_image_sample {
"2.12": {
@@ -521,6 +685,20 @@ get_debug_image_sample {
}
response {"$ref": "#/definitions/debug_image_sample_response"}
}
"2.20": ${get_debug_image_sample."2.12"} {
request.properties.navigate_current_metric {
description: If set then subsequent navigation with next_debug_image_sample is done on the debug images for the passed metric only. Otherwise for all the metrics
type: boolean
default: true
}
}
"2.22": ${get_debug_image_sample."2.20"} {
model_events {
type: boolean
description: If set then the retrieving model debug images. Otherwise task debug images
default: false
}
}
}
next_debug_image_sample {
"2.12": {
@@ -546,6 +724,99 @@ next_debug_image_sample {
}
response {"$ref": "#/definitions/debug_image_sample_response"}
}
"2.22": ${next_debug_image_sample."2.12"} {
request.properties.next_iteration {
type: boolean
default: false
description: If set then navigate to the next/previous iteration
}
model_events {
type: boolean
description: If set then the retrieving model debug images. Otherwise task debug images
default: false
}
}
}
get_plot_sample {
"2.20": {
description: "Return plots for the provided iteration"
request {
type: object
required: [task, metric]
properties {
task {
description: "Task ID"
type: string
}
metric {
description: "Metric name"
type: string
}
iteration {
description: "The iteration to bring plot from. If not specified then the latest reported plot is retrieved"
type: integer
}
refresh {
description: "If set then scroll state will be refreshed to reflect the latest changes in the plots"
type: boolean
}
scroll_id {
type: string
description: "Scroll ID from the previous call to get_plot_sample or empty"
}
navigate_current_metric {
description: If set then subsequent navigation with next_plot_sample is done on the plots for the passed metric only. Otherwise for all the metrics
type: boolean
default: true
}
}
}
response {"$ref": "#/definitions/plot_sample_response"}
}
"2.22": ${get_plot_sample."2.20"} {
model_events {
type: boolean
description: If set then the retrieving model plots. Otherwise task plots
default: false
}
}
}
next_plot_sample {
"2.20": {
description: "Get the plot for the next metric for the same iteration or for the next iteration"
request {
type: object
required: [task, scroll_id]
properties {
task {
description: "Task ID"
type: string
}
scroll_id {
type: string
description: "Scroll ID from the previous call to get_plot_sample"
}
navigate_earlier {
type: boolean
description: """If set then get the either previous metric events from the current iteration or (if does not exist) the last metric events from the previous iteration.
Otherwise next metric events from the current iteration or first metric events from the next iteration"""
}
}
}
response {"$ref": "#/definitions/plot_sample_response"}
}
"2.22": ${next_plot_sample."2.20"} {
request.properties.next_iteration {
type: boolean
default: false
description: If set then navigate to the next/previous iteration
}
model_events {
type: boolean
description: If set then the retrieving model plots. Otherwise task plots
default: false
}
}
}
get_task_metrics{
"2.7": {
@@ -578,6 +849,13 @@ get_task_metrics{
}
}
}
"2.22": ${get_task_metrics."2.7"} {
model_events {
type: boolean
description: If set then get metrics from model events. Otherwise from task events
default: false
}
}
}
get_task_log {
"1.5" {
@@ -795,6 +1073,13 @@ get_task_events {
}
}
}
"2.22": ${get_task_events."2.1"} {
model_events {
type: boolean
description: If set then get retrieving model events. Otherwise task events
default: false
}
}
}
download_task_log {
@@ -896,6 +1181,13 @@ get_task_plots {
default: false
}
}
"2.22": ${get_task_plots."2.16"} {
model_events {
type: boolean
description: If set then the retrieving model events. Otherwise task events
default: false
}
}
}
get_multi_task_plots {
"2.1" {
@@ -953,6 +1245,13 @@ get_multi_task_plots {
default: false
}
}
"2.22": ${get_multi_task_plots."2.16"} {
model_events {
type: boolean
description: If set then the retrieving model events. Otherwise task events
default: false
}
}
}
get_vector_metrics_and_variants {
"2.1" {
@@ -981,6 +1280,13 @@ get_vector_metrics_and_variants {
}
}
}
"2.22": ${get_vector_metrics_and_variants."2.1"} {
model_events {
type: boolean
description: If set then the retrieving model events. Otherwise task events
default: false
}
}
}
vector_metrics_iter_histogram {
"2.1" {
@@ -1019,6 +1325,13 @@ vector_metrics_iter_histogram {
}
}
}
"2.22": ${vector_metrics_iter_histogram."2.1"} {
model_events {
type: boolean
description: If set then the retrieving model events. Otherwise task events
default: false
}
}
}
scalar_metrics_iter_histogram {
"2.1" {
@@ -1072,6 +1385,13 @@ scalar_metrics_iter_histogram {
}
}
}
"2.22": ${scalar_metrics_iter_histogram."2.14"} {
model_events {
type: boolean
description: If set then the retrieving model events. Otherwise task events
default: false
}
}
}
multi_task_scalar_metrics_iter_histogram {
"2.1" {
@@ -1087,7 +1407,7 @@ multi_task_scalar_metrics_iter_histogram {
type: array
items {
type: string
description: "List of task Task IDs"
description: "Task ID"
}
}
samples {
@@ -1111,6 +1431,69 @@ multi_task_scalar_metrics_iter_histogram {
additionalProperties: true
}
}
"2.22": ${multi_task_scalar_metrics_iter_histogram."2.1"} {
model_events {
type: boolean
description: If set then the retrieving model events. Otherwise task events
default: false
}
}
}
get_task_single_value_metrics {
"2.20" {
description: Get single value metrics for the passed tasks
request {
type: object
required: [tasks]
properties {
tasks {
description: "List of task Task IDs"
type: array
items {
type: string
description: "Task ID"
}
}
}
}
response {
type: object
properties {
tasks {
description: Single value metrics grouped by task
type: array
items {
type: object
properties {
task {
type: string
description: Task ID
}
values {
type: array
items {
type: object
properties {
metric { type: string }
variant { type: string}
value { type: number }
timestamp { type: number }
}
}
}
}
}
}
}
}
}
"2.22": ${get_task_single_value_metrics."2.20"} {
model_events {
type: boolean
description: If set then the retrieving model events. Otherwise task events
default: false
}
}
}
get_task_latest_scalar_values {
"2.1" {
@@ -1190,6 +1573,13 @@ get_scalar_metrics_and_variants {
}
}
}
"2.22": ${get_scalar_metrics_and_variants."2.1"} {
model_events {
type: boolean
description: If set then the retrieving model events. Otherwise task events
default: false
}
}
}
get_scalar_metric_data {
"2.1" {
@@ -1239,6 +1629,13 @@ get_scalar_metric_data {
default: false
}
}
"2.22": ${get_scalar_metric_data."2.16"} {
model_events {
type: boolean
description: If set then the retrieving model events. Otherwise task events
default: false
}
}
}
scalar_metrics_iter_raw {
"2.16" {
@@ -1303,6 +1700,13 @@ scalar_metrics_iter_raw {
}
}
}
"2.22": ${scalar_metrics_iter_raw."2.16"} {
model_events {
type: boolean
description: If set then the retrieving model events. Otherwise task events
default: false
}
}
}
clear_scroll {
"2.18" {
@@ -1324,4 +1728,38 @@ clear_scroll {
additionalProperties: false
}
}
}
clear_task_log {
"2.19" {
description: Remove old logs from task
request {
type: object
required: [task]
properties {
task {
description: Task ID
type: string
}
allow_locked {
type: boolean
description: Allow deleting events even if the task is locked
default: false
}
threshold_sec {
description: The amount of seconds ago to retain the log records. The older log records will be deleted. If not passed or 0 then all the log records for the task will be deleted
type: integer
}
}
}
response {
type: object
properties {
deleted {
description: The number of deleted log records
type: integer
}
}
}
}
}

View File

@@ -104,6 +104,16 @@ _definitions {
"$ref": "#/definitions/metadata_item"
}
}
stats {
description: "Model statistics"
type: object
properties {
labels_count {
description: Number of the model labels
type: integer
}
}
}
}
}
published_task_item {
@@ -224,6 +234,13 @@ get_all_ex {
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
"2.20": ${get_all_ex."2.15"} {
request.properties.include_stats {
description: "If true, include models statistic in response"
type: boolean
default: false
}
}
}
get_all {
"2.1" {
@@ -303,6 +320,14 @@ get_all {
type: array
items { type: string }
}
last_update {
description: "List of last_update constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
type: array
items {
type: string
pattern: "^(>=|>|<=|<)?.*$"
}
}
_all_ {
description: "Multi-field pattern condition (all fields match pattern)"
"$ref": "#/definitions/multi_field_pattern_data"

View File

@@ -102,4 +102,78 @@ get_user_companies {
}
}
}
}
}
get_entities_count {
"2.20": {
description: "Get counts for the company entities according to the passed search criteria"
request {
type: object
properties {
projects {
type: object
additionalProperties: true
description: Search criteria for projects
}
tasks {
type: object
additionalProperties: true
description: Search criteria for experiments
}
models {
type: object
additionalProperties: true
description: Search criteria for models
}
pipelines {
type: object
additionalProperties: true
description: Search criteria for pipelines
}
datasets {
type: object
additionalProperties: true
description: Search criteria for datasets
}
}
}
response {
type: object
properties {
projects {
type: integer
description: The number of projects matching the criteria
}
tasks {
type: integer
description: The number of experiments matching the criteria
}
models {
type: integer
description: The number of models matching the criteria
}
pipelines {
type: integer
description: The number of pipelines matching the criteria
}
datasets {
type: integer
description: The number of datasets matching the criteria
}
}
}
}
"2.22": ${get_entities_count."2.20"} {
request.properties {
search_hidden {
description: "If set to 'true' then hidden projects and tasks are included in the search results"
type: boolean
default: false
}
active_users {
descritpion: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
type: array
items: {type: string}
}
}
}
}

View File

@@ -25,6 +25,10 @@ _definitions {
description: "Project name"
type: string
}
basename {
description: "Project base name"
type: string
}
description {
description: "Project description"
type: string
@@ -42,11 +46,6 @@ _definitions {
type: string
format: "date-time"
}
last_update {
description: "Last update time"
type: string
format: "date-time"
}
tags {
description: "User-defined tags"
type: array
@@ -156,6 +155,10 @@ _definitions {
description: "Project name"
type: string
}
basename {
description: "Project base name"
type: string
}
description {
description: "Project description"
type: string
@@ -173,11 +176,6 @@ _definitions {
type: string
format: "date-time"
}
last_update {
description: "Last update time"
type: string
format: "date-time"
}
tags {
description: "User-defined tags"
type: array
@@ -192,6 +190,11 @@ _definitions {
description: "The default output destination URL for new tasks under this project"
type: string
}
last_update {
description: """Last project update time. Reflects the last time the project metadata was changed or a task in this project has changed status"""
type: string
format: "date-time"
}
// extra properties
stats {
description: "Additional project stats"
@@ -214,6 +217,28 @@ _definitions {
}
}
}
own_tasks {
description: "The amount of tasks under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
}
own_models {
description: "The amount of models under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
}
dataset_stats {
description: Project dataset statistics
type: object
properties {
file_count {
type: integer
description: The number of files stored in the dataset
}
total_size {
type: integer
description: The total dataset size in bytes
}
}
}
}
}
metric_variant_result {
@@ -385,6 +410,10 @@ get_all {
description: "Get only projects whose name matches this pattern (python regular expression syntax)"
type: string
}
basename {
description: "Project base name"
type: string
}
description {
description: "Get only projects whose description matches this pattern (python regular expression syntax)"
type: string
@@ -530,18 +559,6 @@ get_all_ex {
}
}
}
response {
properties {
own_tasks {
description: "The amount of tasks under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
}
own_models {
description: "The amount of models under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
}
}
}
}
"2.14": ${get_all_ex."2.13"} {
request.properties.search_hidden {
@@ -582,15 +599,16 @@ get_all_ex {
}
"2.17": ${get_all_ex."2.16"} {
request.properties.include_stats_filter {
description: The filter for selecting entities that participate in statistics calculation
description: The filter for selecting entities that participate in statistics calculation. For each task field that you want to filter on pass the list of allowed values. Prepend the value with '-' to exclude
type: object
properties {
system_tags {
description: The list of allowed system tags
type: array
items { type: string }
}
}
additionalProperties: true
}
}
"2.20": ${get_all_ex."2.17"} {
request.properties.include_dataset_stats {
description: "If true, include project dataset statistic in response"
type: boolean
default: false
}
}
}

View File

@@ -112,6 +112,12 @@ get_by_id {
}
}
}
"2.20": ${get_by_id."2.4"} {
request.properties.max_task_entries {
description: Max number of queue task entries to return
type: integer
}
}
}
// typescript generation hack
get_all_ex {
@@ -140,6 +146,19 @@ get_all_ex {
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
"2.20": ${get_all_ex."2.15"} {
request.properties.max_task_entries {
description: Max number of queue task entries to return
type: integer
}
}
"2.21": ${get_all_ex."2.20"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden queues are included in the search results"
type: boolean
default: false
}
}
}
get_all {
"2.4" {
@@ -226,6 +245,19 @@ get_all {
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
"2.20": ${get_all."2.15"} {
request.properties.max_task_entries {
description: Max number of queue task entries to return
type: integer
}
}
"2.21": ${get_all."2.20"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden queues are included in the search results"
type: boolean
default: false
}
}
}
get_default {
"2.4" {
@@ -431,6 +463,33 @@ get_next_task {
}
}
}
"2.14": ${get_next_task."2.4"} {
request.properties.get_task_info {
description: "If set then additional task info is returned"
type: boolean
default: false
}
response.properties.task_info {
description: "Info about the returned task. Returned only if get_task_info is set to True"
type: object
properties {
company {
description: Task company ID
type: string
}
user {
description: ID of the user who created the task
type: string
}
}
}
}
"2.21": ${get_next_task."2.14"} {
request.properties.task {
description: Task company ID
type: string
}
}
}
remove_task {
"2.4" {
@@ -634,6 +693,13 @@ get_queue_metrics : {
}
}
}
"2.20": ${get_queue_metrics."2.4"} {
request.properties.refresh {
type: boolean
default: false
description: If set then the new queue metrics is taken
}
}
}
add_or_update_metadata {
"2.13" {
@@ -700,3 +766,51 @@ delete_metadata {
}
}
}
peek_task {
"2.15" {
description: "Peek the next task from a given queue"
request {
type: object
required: [ queue ]
properties {
queue {
description: "ID of the queue"
type: string
}
}
}
response {
type: object
properties {
task {
description: "Task ID"
type: string
}
}
}
}
}
get_num_entries {
"2.15" {
description: "Get the number of task entries in the given queue"
request {
type: object
required: [ queue ]
properties {
queue {
description: "ID of the queue"
type: string
}
}
}
response {
type: object
properties {
num {
description: "Number of entries"
type: integer
}
}
}
}
}

View File

@@ -1489,6 +1489,13 @@ reset {
}
}
}
"2.21": ${reset."2.13"} {
request.properties.delete_external_artifacts {
description: "If set to 'true' then BE will try to delete the extenal artifacts associated with the task from the fileserver (if configured to do so)"
type: boolean
default: true
}
}
}
reset_many {
"2.13": ${_definitions.batch_operation} {
@@ -1541,6 +1548,13 @@ reset_many {
}
}
}
"2.21": ${reset_many."2.13"} {
request.properties.delete_external_artifacts {
description: "If set to 'true' then BE will try to delete the extenal artifacts associated with the tasks from the fileserver (if configured to do so)"
type: boolean
default: true
}
}
}
delete_many {
"2.13": ${_definitions.batch_operation} {
@@ -1591,6 +1605,13 @@ delete_many {
}
}
}
"2.21": ${delete_many."2.13"} {
request.properties.delete_external_artifacts {
description: "If set to 'true' then BE will try to delete the extenal artifacts associated with the tasks from the fileserver (if configured to do so)"
type: boolean
default: true
}
}
}
delete {
"2.1" {
@@ -1655,6 +1676,13 @@ delete {
}
}
}
"2.21": ${delete."2.13"} {
request.properties.delete_external_artifacts {
description: "If set to 'true' then BE will try to delete the extenal artifacts associated with the task from the fileserver (if configured to do so)"
type: boolean
default: true
}
}
}
archive {
"2.12" {
@@ -1898,7 +1926,7 @@ Fails if the following parameters in the task were not filled:
]
properties {
queue {
description: "Queue id. If not provided, task is added to the default queue."
description: "Queue id. If not provided and no queue name is passed then task is added to the default queue."
type: string
}
}
@@ -1914,6 +1942,23 @@ Fails if the following parameters in the task were not filled:
}
}
}
"2.19": ${enqueue."1.5"} {
request.properties.queue_name {
description: The name of the queue. If the queue does not exist then it is auto-created. Cannot be used together with the queue id
type: string
}
}
"2.22": ${enqueue."2.19"} {
request.properties.verify_watched_queue {
description: If passed then check wheter there are any workers watiching the queue
type: boolean
default: false
}
response.properties.queue_watched {
description: Returns true if there are workers or autscalers working with the queue
type: boolean
}
}
}
enqueue_many {
"2.13": ${_definitions.change_many_request} {
@@ -1922,7 +1967,7 @@ enqueue_many {
properties {
ids.description: "IDs of the tasks to enqueue"
queue {
description: "Queue id. If not provided, tasks are added to the default queue."
description: "Queue id. If not provided and no queue name is passed then tasks are added to the default queue."
type: string
}
validate_tasks {
@@ -1941,6 +1986,23 @@ enqueue_many {
}
}
}
"2.19": ${enqueue_many."2.13"} {
request.properties.queue_name {
description: The name of the queue. If the queue does not exist then it is auto-created. Cannot be used together with the queue id
type: string
}
}
"2.22": ${enqueue_many."2.19"} {
request.properties.verify_watched_queue {
description: If passed then check wheter there are any workers watiching the queue
type: boolean
default: false
}
response.properties.queue_watched {
description: Returns true if there are workers or autscalers working with the queue
type: boolean
}
}
}
dequeue {
"1.5" {
@@ -2019,6 +2081,18 @@ completed {
} ${_references.status_change_request}
response: ${_definitions.update_response}
}
"2.20": ${completed."2.2"} {
request.properties.publish {
type: boolean
default: false
description: If set and the task is completed successfully then it is published
}
response.properties.published {
description: "Number of tasks published (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
ping {

View File

@@ -139,6 +139,22 @@ get_current_user {
}
}
}
"2.20": ${get_current_user."2.1"} {
response {
properties {
getting_started {
type: object
description: Getting stated info
additionalProperties: true
}
created {
type: string
description: User creation time
format: date-time
}
}
}
}
}
get_all_ex {

View File

@@ -1,152 +1,336 @@
{
_description: "Provides an API for worker machines, allowing workers to report status and get tasks for execution"
_definitions {
metrics_category {
_description: "Provides an API for worker machines, allowing workers to report status and get tasks for execution"
_definitions {
metrics_category {
type: object
properties {
name {
type: string
description: "Name of the metrics category."
}
metric_keys {
type: array
items { type: string }
description: "The names of the metrics in the category."
}
}
}
aggregation_type {
type: string
enum: [ avg, min, max ]
description: "Metric aggregation type"
}
stat_item {
type: object
properties {
key {
type: string
description: "Name of a metric"
}
category {
"$ref": "#/definitions/aggregation_type"
}
}
}
aggregation_stats {
type: object
properties {
aggregation {
"$ref": "#/definitions/aggregation_type"
}
values {
type: array
description: "List of values corresponding to the dates in metric statistics"
items { type: number }
}
}
}
metric_stats {
type: object
properties {
metric {
type: string
description: "Name of the metric ("cpu_usage", "memory_used" etc.)"
}
variant {
type: string
description: "Name of the metric component. Set only if 'split_by_variant' was set in the request"
}
dates {
type: array
description: "List of timestamps (in seconds from epoch) in the acceding order. The timestamps are separated by the requested interval. Timestamps where no workers activity was recorded are omitted."
items { type: integer }
}
stats {
type: array
description: "Statistics data by type"
items { "$ref": "#/definitions/aggregation_stats" }
}
}
}
worker_stats {
type: object
properties {
worker {
type: string
description: "ID of the worker"
}
metrics {
type: array
description: "List of the metrics statistics for the worker"
items { "$ref": "#/definitions/metric_stats" }
}
}
}
activity_series {
type: object
properties {
dates {
type: array
description: "List of timestamps (in seconds from epoch) in the acceding order. The timestamps are separated by the requested interval."
items {type: integer}
}
counts {
type: array
description: "List of worker counts corresponding to the timestamps in the dates list. None values are returned for the dates with no workers."
items {type: integer}
}
}
}
worker {
type: object
properties {
id {
description: "Worker ID"
type: string
}
user {
description: "Associated user (under whose credentials are used by the worker daemon)"
"$ref": "#/definitions/id_name_entry"
}
company {
description: "Associated company"
"$ref": "#/definitions/id_name_entry"
}
ip {
description: "IP of the worker"
type: string
}
register_time {
description: "Registration time"
type: string
format: "date-time"
}
last_activity_time {
description: "Last activity time (even if an error occurred)"
type: string
format: "date-time"
}
last_report_time {
description: "Last successful report time"
type: string
format: "date-time"
}
task {
description: "Task currently being run by the worker"
"$ref": "#/definitions/current_task_entry"
}
project {
description: "Project in which currently executing task resides"
"$ref": "#/definitions/id_name_entry"
}
queue {
description: "Queue from which running task was taken"
"$ref": "#/definitions/queue_entry"
}
queues {
description: "List of queues on which the worker is listening"
type: array
items { "$ref": "#/definitions/queue_entry" }
}
tags {
description: "User tags for the worker"
type: array
items: { type: string }
}
system_tags {
description: "System tags for the worker"
type: array
items: { type: string }
}
key {
description: "Worker entry key"
type: string
}
}
}
id_name_entry {
type: object
properties {
id {
description: "Worker ID"
type: string
}
name {
description: "Worker name"
type: string
}
}
}
current_task_entry = ${_definitions.id_name_entry} {
properties {
running_time {
description: "Task running time"
type: integer
}
last_iteration {
description: "Last task iteration"
type: integer
}
}
}
queue_entry = ${_definitions.id_name_entry} {
properties {
next_task {
description: "Next task in the queue"
"$ref": "#/definitions/id_name_entry"
}
num_tasks {
description: "Number of task entries in the queue"
type: integer
}
}
}
machine_stats {
type: object
properties {
cpu_usage {
description: "Average CPU usage per core"
type: array
items { type: number }
}
gpu_usage {
description: "Average GPU usage per GPU card"
type: array
items { type: number }
}
memory_used {
description: "Used memory MBs"
type: integer
}
memory_free {
description: "Free memory MBs"
type: integer
}
gpu_memory_free {
description: "GPU free memory MBs"
type: array
items { type: integer }
}
gpu_memory_used {
description: "GPU used memory MBs"
type: array
items { type: integer }
}
network_tx {
description: "Mbytes per second"
type: integer
}
network_rx {
description: "Mbytes per second"
type: integer
}
disk_free_home {
description: "Mbytes free space of /home drive"
type: integer
}
disk_free_temp {
description: "Mbytes free space of /tmp drive"
type: integer
}
disk_read {
description: "Mbytes read per second"
type: integer
}
disk_write {
description: "Mbytes write per second"
type: integer
}
cpu_temperature {
description: "CPU temperature"
type: array
items { type: number }
}
gpu_temperature {
description: "GPU temperature"
type: array
items { type: number }
}
}
}
}
get_all {
"2.4" {
description: "Returns information on all registered workers."
request {
type: object
properties {
name {
type: string
description: "Name of the metrics category."
}
metric_keys {
type: array
items { type: string }
description: "The names of the metrics in the category."
last_seen {
description: """Filter out workers not active for more than last_seen seconds.
A value or 0 or 'none' will disable the filter."""
type: integer
default: 3600
}
}
}
aggregation_type {
type: string
enum: [ avg, min, max ]
description: "Metric aggregation type"
}
stat_item {
response {
type: object
properties {
key {
type: string
description: "Name of a metric"
}
category {
"$ref": "#/definitions/aggregation_type"
workers {
type: array
items { "$ref": "#/definitions/worker" }
}
}
}
aggregation_stats {
type: object
properties {
aggregation {
"$ref": "#/definitions/aggregation_type"
}
values {
type: array
description: "List of values corresponding to the dates in metric statistics"
items { type: number }
}
}
}
"2.20": ${get_all."2.4"} {
request.properties.tags {
description: The list of allowed worker tags. Prepend tag value with '-' in order to exclude
type: array
items { type: string }
}
metric_stats {
type: object
properties {
metric {
type: string
description: "Name of the metric ("cpu_usage", "memory_used" etc.)"
}
variant {
type: string
description: "Name of the metric component. Set only if 'split_by_variant' was set in the request"
}
dates {
type: array
description: "List of timestamps (in seconds from epoch) in the acceding order. The timestamps are separated by the requested interval. Timestamps where no workers activity was recorded are omitted."
items { type: integer }
}
stats {
type: array
description: "Statistics data by type"
items { "$ref": "#/definitions/aggregation_stats" }
}
}
}
"2.22": ${get_all."2.20"} {
request.properties.system_tags {
description: The list of allowed worker system tags. Prepend tag value with '-' in order to exclude
type: array
items { type: string }
}
worker_stats {
}
}
register {
"2.4" {
description: "Register a worker in the system. Called by the Worker Daemon."
request {
required: [ worker ]
type: object
properties {
worker {
type: string
description: "ID of the worker"
}
metrics {
type: array
description: "List of the metrics statistics for the worker"
items { "$ref": "#/definitions/metric_stats" }
}
}
}
activity_series {
type: object
properties {
dates {
type: array
description: "List of timestamps (in seconds from epoch) in the acceding order. The timestamps are separated by the requested interval."
items {type: integer}
}
counts {
type: array
description: "List of worker counts corresponding to the timestamps in the dates list. None values are returned for the dates with no workers."
items {type: integer}
}
}
}
worker {
type: object
properties {
id {
description: "Worker ID"
description: "Worker id. Must be unique in company."
type: string
}
user {
description: "Associated user (under whose credentials are used by the worker daemon)"
"$ref": "#/definitions/id_name_entry"
}
company {
description: "Associated company"
"$ref": "#/definitions/id_name_entry"
}
ip {
description: "IP of the worker"
type: string
}
register_time {
description: "Registration time"
type: string
format: "date-time"
}
last_activity_time {
description: "Last activity time (even if an error occurred)"
type: string
format: "date-time"
}
last_report_time {
description: "Last successful report time"
type: string
format: "date-time"
}
task {
description: "Task currently being run by the worker"
"$ref": "#/definitions/current_task_entry"
}
project {
description: "Project in which currently executing task resides"
"$ref": "#/definitions/id_name_entry"
}
queue {
description: "Queue from which running task was taken"
"$ref": "#/definitions/queue_entry"
timeout {
description: "Registration timeout in seconds. If timeout seconds have passed since the worker's last call to register or status_report, the worker is automatically removed from the list of registered workers."
type: integer
default: 600
}
queues {
description: "List of queues on which the worker is listening"
description: "List of queue IDs on which the worker is listening."
type: array
items { "$ref": "#/definitions/queue_entry" }
items { type: string }
}
tags {
description: "User tags for the worker"
@@ -155,348 +339,199 @@
}
}
}
id_name_entry {
response {
type: object
properties {}
}
}
"2.22": ${register."2.4"} {
request.properties.system_tags {
description: "System tags for the worker"
type: array
items: { type: string }
}
}
}
unregister {
"2.4" {
description: "Unregister a worker in the system. Called by the Worker Daemon."
request {
required: [ worker ]
type: object
properties {
id {
description: "ID"
type: string
}
name {
description: "Name"
worker {
description: "Worker id. Must be unique in company."
type: string
}
}
}
current_task_entry = ${_definitions.id_name_entry} {
properties {
running_time {
description: "Task running time"
type: integer
}
last_iteration {
description: "Last task iteration"
type: integer
}
}
response {
type: object
properties {}
}
queue_entry = ${_definitions.id_name_entry} {
properties {
next_task {
description: "Next task in the queue"
"$ref": "#/definitions/id_name_entry"
}
num_tasks {
description: "Number of task entries in the queue"
type: integer
}
}
}
machine_stats {
}
}
status_report {
"2.4" {
description: "Called periodically by the worker daemon to report machine status"
request {
required: [
worker
timestamp
]
type: object
properties {
cpu_usage {
description: "Average CPU usage per core"
type: array
items { type: number }
worker {
description: "Worker id."
type: string
}
gpu_usage {
description: "Average GPU usage per GPU card"
type: array
items { type: number }
task {
description: "ID of a task currently being run by the worker. If no task is sent, the worker's task field will be cleared."
type: string
}
memory_used {
description: "Used memory MBs"
queue {
description: "ID of the queue from which task was received. If no queue is sent, the worker's queue field will be cleared."
type: string
}
queues {
description: "List of queue IDs on which the worker is listening. If null, the worker's queues list will not be updated."
type: array
items { type: string }
}
timestamp {
description: "UNIX time in seconds since epoch."
type: integer
}
memory_free {
description: "Free memory MBs"
type: integer
machine_stats {
description: "The machine statistics."
"$ref": "#/definitions/machine_stats"
}
gpu_memory_free {
description: "GPU free memory MBs"
tags {
description: "New user tags for the worker"
type: array
items { type: integer }
items: { type: string }
}
gpu_memory_used {
description: "GPU used memory MBs"
}
}
response {
type: object
properties {}
}
}
"2.22": ${status_report."2.4"} {
request.properties.system_tags {
description: "New system tags for the worker"
type: array
items: { type: string }
}
}
}
get_metric_keys {
"2.4" {
description: "Returns worker statistics metric keys grouped by categories."
request {
type: object
properties {
worker_ids {
description: "List of worker ids to collect metrics for. If not provided or empty then all the company workers metrics are analyzed."
type: array
items { type: integer }
items { type: string }
}
network_tx {
description: "Mbytes per second"
type: integer
}
network_rx {
description: "Mbytes per second"
type: integer
}
disk_free_home {
description: "Mbytes free space of /home drive"
type: integer
}
disk_free_temp {
description: "Mbytes free space of /tmp drive"
type: integer
}
disk_read {
description: "Mbytes read per second"
type: integer
}
disk_write {
description: "Mbytes write per second"
type: integer
}
cpu_temperature {
description: "CPU temperature"
}
}
response {
type: object
properties {
categories {
type: array
items { type: number }
description: "List of unique metric categories found in the statistics of the requested workers."
items { "$ref": "#/definitions/metrics_category" }
}
gpu_temperature {
description: "GPU temperature"
}
}
}
}
get_stats {
"2.4" {
description: "Returns statistics for the selected workers and time range aggregated by date intervals."
request {
type: object
required: [ from_date, to_date, interval, items ]
properties {
worker_ids {
description: "List of worker ids to collect metrics for. If not provided or empty then all the company workers metrics are analyzed."
type: array
items { type: number }
items { type: string }
}
from_date {
description: "Starting time (in seconds from epoch) for collecting statistics"
type: number
}
to_date {
description: "Ending time (in seconds from epoch) for collecting statistics"
type: number
}
interval {
description: "Time interval in seconds for a single statistics point. The minimal value is 1"
type: integer
}
items {
description: "List of metric keys and requested statistics"
type: array
items { "$ref": "#/definitions/stat_item" }
}
split_by_variant {
description: "If true then break statistics by hardware sub types"
type: boolean
default: false
}
}
}
response {
type: object
properties {
workers {
type: array
description: "List of the requested workers with their statistics"
items { "$ref": "#/definitions/worker_stats" }
}
}
}
}
get_all {
"2.4" {
description: "Returns information on all registered workers."
request {
type: object
properties {
last_seen {
description: """Filter out workers not active for more than last_seen seconds.
A value or 0 or 'none' will disable the filter."""
type: integer
default: 3600
}
}
get_activity_report {
"2.4" {
description: "Returns count of active company workers in the selected time range."
request {
type: object
required: [ from_date, to_date, interval ]
properties {
from_date {
description: "Starting time (in seconds from epoch) for collecting statistics"
type: number
}
to_date {
description: "Ending time (in seconds from epoch) for collecting statistics"
type: number
}
interval {
description: "Time interval in seconds for a single statistics point. The minimal value is 1"
type: integer
}
}
response {
type: object
properties {
workers {
type: array
items { "$ref": "#/definitions/worker" }
}
}
response {
type: object
properties {
total {
description: "Activity series that include all the workers that sent reports in the given time interval."
"$ref": "#/definitions/activity_series"
}
active {
description: "Activity series that include only workers that worked on a task in the given time interval."
"$ref": "#/definitions/activity_series"
}
}
}
}
register {
"2.4" {
description: "Register a worker in the system. Called by the Worker Daemon."
request {
required: [ worker ]
type: object
properties {
worker {
description: "Worker id. Must be unique in company."
type: string
}
timeout {
description: "Registration timeout in seconds. If timeout seconds have passed since the worker's last call to register or status_report, the worker is automatically removed from the list of registered workers."
type: integer
default: 600
}
queues {
description: "List of queue IDs on which the worker is listening."
type: array
items { type: string }
}
tags {
description: "User tags for the worker"
type: array
items: { type: string }
}
}
}
response {
type: object
properties {}
}
}
}
unregister {
"2.4" {
description: "Unregister a worker in the system. Called by the Worker Daemon."
request {
required: [ worker ]
type: object
properties {
worker {
description: "Worker id. Must be unique in company."
type: string
}
}
}
response {
type: object
properties {}
}
}
}
status_report {
"2.4" {
description: "Called periodically by the worker daemon to report machine status"
request {
required: [
worker
timestamp
]
type: object
properties {
worker {
description: "Worker id."
type: string
}
task {
description: "ID of a task currently being run by the worker. If no task is sent, the worker's task field will be cleared."
type: string
}
queue {
description: "ID of the queue from which task was received. If no queue is sent, the worker's queue field will be cleared."
type: string
}
queues {
description: "List of queue IDs on which the worker is listening. If null, the worker's queues list will not be updated."
type: array
items { type: string }
}
timestamp {
description: "UNIX time in seconds since epoch."
type: integer
}
machine_stats {
description: "The machine statistics."
"$ref": "#/definitions/machine_stats"
}
tags {
description: "New user tags for the worker"
type: array
items: { type: string }
}
}
}
response {
type: object
properties {}
}
}
}
get_metric_keys {
"2.4" {
description: "Returns worker statistics metric keys grouped by categories."
request {
type: object
properties {
worker_ids {
description: "List of worker ids to collect metrics for. If not provided or empty then all the company workers metrics are analyzed."
type: array
items { type: string }
}
}
}
response {
type: object
properties {
categories {
type: array
description: "List of unique metric categories found in the statistics of the requested workers."
items { "$ref": "#/definitions/metrics_category" }
}
}
}
}
}
get_stats {
"2.4" {
description: "Returns statistics for the selected workers and time range aggregated by date intervals."
request {
type: object
required: [ from_date, to_date, interval, items ]
properties {
worker_ids {
description: "List of worker ids to collect metrics for. If not provided or empty then all the company workers metrics are analyzed."
type: array
items { type: string }
}
from_date {
description: "Starting time (in seconds from epoch) for collecting statistics"
type: number
}
to_date {
description: "Ending time (in seconds from epoch) for collecting statistics"
type: number
}
interval {
description: "Time interval in seconds for a single statistics point. The minimal value is 1"
type: integer
}
items {
description: "List of metric keys and requested statistics"
type: array
items { "$ref": "#/definitions/stat_item" }
}
split_by_variant {
description: "If true then break statistics by hardware sub types"
type: boolean
default: false
}
}
}
response {
type: object
properties {
workers {
type: array
description: "List of the requested workers with their statistics"
items { "$ref": "#/definitions/worker_stats" }
}
}
}
}
}
get_activity_report {
"2.4" {
description: "Returns count of active company workers in the selected time range."
request {
type: object
required: [ from_date, to_date, interval ]
properties {
from_date {
description: "Starting time (in seconds from epoch) for collecting statistics"
type: number
}
to_date {
description: "Ending time (in seconds from epoch) for collecting statistics"
type: number
}
interval {
description: "Time interval in seconds for a single statistics point. The minimal value is 1"
type: integer
}
}
}
response {
type: object
properties {
total {
description: "Activity series that include all the workers that sent reports in the given time interval."
"$ref": "#/definitions/activity_series"
}
active {
description: "Activity series that include only workers that worked on a task in the given time interval."
"$ref": "#/definitions/activity_series"
}
}
}
}
}
}
}

View File

@@ -6,6 +6,8 @@ from flask_compress import Compress
from flask_cors import CORS
from packaging.version import Version
from apiserver.bll.queue.queue_metrics import MetricsRefresher
from apiserver.bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
from apiserver.database import db
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
from apiserver.config import info
@@ -26,7 +28,6 @@ from apiserver.service_repo import ServiceRepo
from apiserver.sync import distributed_lock
from apiserver.updates import check_updates_thread
from apiserver.utilities.env import get_bool
from apiserver.utilities.threads_manager import ThreadsManager
log = config.logger(__file__)
@@ -119,6 +120,8 @@ class AppSequence:
def _start_worker(self):
check_updates_thread.start()
StatisticsReporter.start()
MetricsRefresher.start()
NonResponsiveTasksWatchdog.start()
def _on_worker_stop(self):
ThreadsManager.terminating = True
pass

View File

@@ -313,6 +313,7 @@ class APICall(DataContainer):
_redacted_headers = {
HEADER_AUTHORIZATION: " ",
"Cookie": "=",
"X-Jwt-Payload": "",
}
""" Headers whose value should be redacted. Maps header name to partition char """
@@ -692,6 +693,10 @@ class APICall(DataContainer):
# this will allow us to debug authorization errors).
for header, sep in self._redacted_headers.items():
if header in headers:
prefix, _, redact = headers[header].partition(sep)
if sep:
prefix, _, redact = headers[header].partition(sep)
else:
prefix = sep = ""
redact = headers[header]
headers[header] = prefix + sep + f"<{len(redact)} bytes redacted>"
return headers

View File

@@ -11,7 +11,6 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import User, Entities, Credentials
from apiserver.database.model.company import Company
from apiserver.database.utils import get_options
from apiserver.timing_context import TimingContext
from .fixed_user import FixedUser
from .identity import Identity
from .payload import Payload, Token, Basic, AuthType
@@ -88,9 +87,7 @@ def authorize_credentials(auth_data, service, action, call):
query = Q(id=fixed_user.user_id)
with TimingContext("mongo", "user_by_cred"), translate_errors_context(
"authorizing request"
):
with translate_errors_context("authorizing request"):
user = User.objects(query).first()
if not user:
raise errors.unauthorized.InvalidCredentials(
@@ -108,8 +105,7 @@ def authorize_credentials(auth_data, service, action, call):
}
)
with TimingContext("mongo", "company_by_id"):
company = Company.objects(id=user.company).only("id", "name").first()
company = Company.objects(id=user.company).only("id", "name").first()
if not company:
raise errors.unauthorized.InvalidCredentials("invalid user company")

View File

@@ -2,6 +2,8 @@ import jwt
from datetime import datetime, timedelta
from jwt.algorithms import get_default_algorithms
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.model.auth import Role
@@ -9,22 +11,24 @@ from apiserver.database.model.auth import Role
from .auth_type import AuthType
from .payload import Payload
token_secret = config.get('secure.auth.token_secret')
token_secret = config.get("secure.auth.token_secret")
log = config.logger(__file__)
class Token(Payload):
default_expiration_sec = config.get('apiserver.auth.default_expiration_sec')
default_expiration_sec = config.get("apiserver.auth.default_expiration_sec")
def __init__(self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_):
def __init__(
self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_
):
super(Token, self).__init__(
AuthType.bearer_token, identity=identity, entities=entities)
AuthType.bearer_token, identity=identity, entities=entities
)
self.exp = exp
self.iat = iat
self.nbf = nbf
self._env = env or config.get('env', '<unknown>')
self._env = env or config.get("env", "<unknown>")
@property
def env(self):
@@ -65,7 +69,15 @@ class Token(Payload):
@classmethod
def decode(cls, encoded_token, verify=True):
return jwt.decode(encoded_token, token_secret, verify=verify)
options = (
{"verify_signature": False, "verify_exp": True} if not verify else None
)
return jwt.decode(
encoded_token,
token_secret,
algorithms=get_default_algorithms(),
options=options,
)
@classmethod
def from_encoded_token(cls, encoded_token, verify=True):
@@ -74,23 +86,24 @@ class Token(Payload):
token = Token.from_dict(decoded)
assert isinstance(token, Token)
if not token.identity:
raise errors.unauthorized.InvalidToken('token missing identity')
raise errors.unauthorized.InvalidToken("token missing identity")
return token
except Exception as e:
raise errors.unauthorized.InvalidToken('failed parsing token, %s' % e.args[0])
raise errors.unauthorized.InvalidToken(
"failed parsing token, %s" % e.args[0]
)
@classmethod
def create_encoded_token(cls, identity, expiration_sec=None, entities=None, **extra_payload):
def create_encoded_token(
cls, identity, expiration_sec=None, entities=None, **extra_payload
):
if identity.role not in (Role.system,):
# limit expiration time for all roles but an internal service
expiration_sec = expiration_sec or cls.default_expiration_sec
now = datetime.utcnow()
token = cls(
identity=identity,
entities=entities,
iat=now)
token = cls(identity=identity, entities=entities, iat=now)
if expiration_sec:
# add 'expiration' claim

View File

@@ -8,6 +8,7 @@ import jsonmodels.models
from apiserver.apierrors import APIError, errors
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.utilities.partial_version import PartialVersion
from .apicall import APICall
from .auth import Identity
@@ -38,7 +39,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.17")
_max_version = PartialVersion("2.22")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (
@@ -283,7 +284,8 @@ class ServiceRepo(object):
# In case call does not require authorization, parsing the identity.company might raise an exception
company = cls._get_company(call, endpoint)
ret = endpoint.func(call, company, call.data_model)
with translate_errors_context():
ret = endpoint.func(call, company, call.data_model)
# allow endpoints to return dict or model (instead of setting them explicitly on the call)
if ret is not None:

View File

@@ -14,6 +14,7 @@ from apiserver.apimodels.auth import (
RevokeCredentialsRequest,
EditUserReq,
CreateCredentialsRequest,
EditCredentialsRequest,
)
from apiserver.apimodels.base import UpdateResponse
from apiserver.bll.auth import AuthBLL
@@ -122,6 +123,27 @@ def create_credentials(call: APICall, _, request: CreateCredentialsRequest):
call.result.data_model = CreateCredentialsResponse(credentials=credentials)
@endpoint("auth.edit_credentials")
def edit_credentials(call: APICall, company_id: str, request: EditCredentialsRequest):
identity = call.identity
access_key = request.access_key
updated = User.objects(
id=identity.user,
company=company_id,
credentials__match={"key": access_key},
).update_one(set__credentials__S__label=request.label)
if not updated:
raise errors.bad_request.InvalidAccessKey(
"invalid user or invalid access key",
user=identity.user,
access_key=access_key,
company=company_id,
)
call.result.data = {"updated": updated}
@endpoint(
"auth.revoke_credentials",
request_data_model=RevokeCredentialsRequest,

View File

@@ -2,7 +2,7 @@ import itertools
import math
from collections import defaultdict
from operator import itemgetter
from typing import Sequence, Optional
from typing import Sequence, Optional, Union, Tuple
import attr
import jsonmodels.fields
@@ -12,32 +12,57 @@ from apiserver.apierrors import errors
from apiserver.apimodels.events import (
MultiTaskScalarMetricsIterHistogramRequest,
ScalarMetricsIterHistogramRequest,
DebugImagesRequest,
DebugImageResponse,
MetricEventsRequest,
MetricEventsResponse,
MetricEvents,
IterationEvents,
TaskMetricsRequest,
LogEventsRequest,
LogOrderEnum,
GetDebugImageSampleRequest,
NextDebugImageSampleRequest,
NextHistorySampleRequest,
MetricVariants as ApiMetrics,
TaskPlotsRequest,
TaskEventsRequest,
ScalarMetricsIterRawRequest,
ClearScrollRequest,
ClearTaskLogRequest,
SingleValueMetricsRequest,
GetVariantSampleRequest, GetMetricSamplesRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants
from apiserver.bll.event.events_iterator import Scroll
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.bll.model import ModelBLL
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities import json, extract_properties_to_lists
task_bll = TaskBLL()
event_bll = EventBLL()
model_bll = ModelBLL()
def _assert_task_or_model_exists(
company_id: str, task_ids: Union[str, Sequence[str]], model_events: bool
) -> Union[Sequence[Model], Sequence[Task]]:
if model_events:
return model_bll.assert_exists(
company_id,
task_ids,
allow_public=True,
only=("id", "name", "company", "company_origin"),
)
return task_bll.assert_exists(
company_id,
task_ids,
allow_public=True,
only=("id", "name", "company", "company_origin"),
)
@endpoint("events.add")
@@ -45,7 +70,7 @@ def add(call: APICall, company_id, _):
data = call.data.copy()
allow_locked = data.pop("allow_locked", False)
added, err_count, err_info = event_bll.add_events(
company_id, [data], call.worker, allow_locked_tasks=allow_locked
company_id, [data], call.worker, allow_locked=allow_locked
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@@ -56,7 +81,12 @@ def add_batch(call: APICall, company_id, _):
if events is None or len(events) == 0:
raise errors.bad_request.BatchContainsNoItems()
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
added, err_count, err_info = event_bll.add_events(
company_id,
events,
call.worker,
allow_locked=events[0].get("allow_locked", False),
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@@ -223,12 +253,13 @@ def download_task_log(call, company_id, _):
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
def get_vector_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
model_events = call.data["model_events"]
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
task.get_index_company(), task_id, EventType.metrics_vector
task_or_model.get_index_company(), task_id, EventType.metrics_vector
)
)
@@ -236,12 +267,13 @@ def get_vector_metrics_and_variants(call, company_id, _):
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
def get_scalar_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
model_events = call.data["model_events"]
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
task.get_index_company(), task_id, EventType.metrics_scalar
task_or_model.get_index_company(), task_id, EventType.metrics_scalar
)
)
@@ -253,13 +285,14 @@ def get_scalar_metrics_and_variants(call, company_id, _):
)
def vector_metrics_iter_histogram(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
model_events = call.data["model_events"]
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
)[0]
metric = call.data["metric"]
variant = call.data["variant"]
iterations, vectors = event_bll.get_vector_metrics_per_iter(
task.get_index_company(), task_id, metric, variant
task_or_model.get_index_company(), task_id, metric, variant
)
call.result.data = dict(
metric=metric, variant=variant, vectors=vectors, iterations=iterations
@@ -284,11 +317,10 @@ def make_response(
@endpoint("events.get_task_events", request_data_model=TaskEventsRequest)
def get_task_events(call, company_id, request: TaskEventsRequest):
def get_task_events(_, company_id, request: TaskEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",),
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=request.model_events,
)[0]
key = ScalarKeyEnum.iter
@@ -320,7 +352,7 @@ def get_task_events(call, company_id, request: TaskEventsRequest):
if request.count_total and total is None:
total = event_bll.events_iterator.count_task_events(
event_type=request.event_type,
company_id=task.company,
company_id=task_or_model.get_index_company(),
task_id=task_id,
metric_variants=metric_variants,
)
@@ -334,7 +366,7 @@ def get_task_events(call, company_id, request: TaskEventsRequest):
res = event_bll.events_iterator.get_task_events(
event_type=request.event_type,
company_id=task.company,
company_id=task_or_model.get_index_company(),
task_id=task_id,
batch_size=batch_size,
key=ScalarKeyEnum.iter,
@@ -363,18 +395,20 @@ def get_scalar_metric_data(call, company_id, _):
metric = call.data["metric"]
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
model_events = call.data.get("model_events", False)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
)[0]
result = event_bll.get_task_events(
task.get_index_company(),
task_or_model.get_index_company(),
task_id,
event_type=EventType.metrics_scalar,
sort=[{"iter": {"order": "desc"}}],
metric=metric,
scroll_id=scroll_id,
no_scroll=no_scroll,
model_events=model_events,
)
call.result.data = dict(
@@ -396,7 +430,7 @@ def get_task_latest_scalar_values(call, company_id, _):
index_company, task_id
)
last_iters = event_bll.get_last_iters(
company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
company_id=index_company, event_type=EventType.all, task_id=task_id, iters=1
).get(task_id)
call.result.data = dict(
metrics=metrics,
@@ -415,11 +449,11 @@ def get_task_latest_scalar_values(call, company_id, _):
def scalar_metrics_iter_histogram(
call, company_id, request: ScalarMetricsIterHistogramRequest
):
task = task_bll.assert_exists(
company_id, request.task, allow_public=True, only=("company", "company_origin")
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events
)[0]
metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
task.get_index_company(),
company_id=task_or_model.get_index_company(),
task_id=request.task,
samples=request.samples,
key=request.key,
@@ -427,50 +461,84 @@ def scalar_metrics_iter_histogram(
call.result.data = metrics
def _get_task_or_model_index_company(
company_id: str, task_ids: Sequence[str], model_events=False,
) -> Tuple[str, Sequence[Task]]:
"""
Verify that all tasks exists and belong to store data in the same company index
Return company and tasks
"""
tasks_or_models = _assert_task_or_model_exists(
company_id, task_ids, model_events=model_events,
)
unique_ids = set(task_ids)
if len(tasks_or_models) < len(unique_ids):
invalid = tuple(unique_ids - {t.id for t in tasks_or_models})
error_cls = (
errors.bad_request.InvalidModelId
if model_events
else errors.bad_request.InvalidTaskId
)
raise error_cls(company=company_id, ids=invalid)
companies = {t.get_index_company() for t in tasks_or_models}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
return companies.pop(), tasks_or_models
@endpoint(
"events.multi_task_scalar_metrics_iter_histogram",
request_data_model=MultiTaskScalarMetricsIterHistogramRequest,
)
def multi_task_scalar_metrics_iter_histogram(
call, company_id, req_model: MultiTaskScalarMetricsIterHistogramRequest
call, company_id, request: MultiTaskScalarMetricsIterHistogramRequest
):
task_ids = req_model.tasks
task_ids = request.tasks
if isinstance(task_ids, str):
task_ids = [s.strip() for s in task_ids.split(",")]
# Note, bll already validates task ids as it needs their names
company, tasks_or_models = _get_task_or_model_index_company(
company_id, task_ids, request.model_events
)
call.result.data = dict(
metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter(
company_id,
task_ids=task_ids,
samples=req_model.samples,
allow_public=True,
key=req_model.key,
company_id=company,
tasks=tasks_or_models,
samples=request.samples,
key=request.key,
)
)
@endpoint("events.get_task_single_value_metrics")
def get_task_single_value_metrics(
call, company_id: str, request: SingleValueMetricsRequest
):
company, tasks_or_models = _get_task_or_model_index_company(
company_id, request.tasks, request.model_events
)
res = event_bll.metrics.get_task_single_value_metrics(company, tasks_or_models)
call.result.data = dict(
tasks=[{"task": task, "values": values} for task, values in res.items()]
)
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
def get_multi_task_plots_v1_7(call, company_id, _):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
tasks = task_bll.assert_exists(
company_id=company_id,
only=("id", "name", "company", "company_origin"),
task_ids=task_ids,
allow_public=True,
)
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids)
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
next(iter(companies)),
company,
task_ids,
event_type=EventType.metrics_plot,
sort=[{"iter": {"order": "desc"}}],
@@ -478,7 +546,7 @@ def get_multi_task_plots_v1_7(call, company_id, _):
scroll_id=scroll_id,
)
tasks = {t.id: t.name for t in tasks}
tasks = {t.id: t.name for t in tasks_or_models}
return_events = _get_top_iter_unique_events_per_task(
result.events, max_iters=iters, tasks=tasks
@@ -493,36 +561,29 @@ def get_multi_task_plots_v1_7(call, company_id, _):
@endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"])
def get_multi_task_plots(call, company_id, req_model):
def get_multi_task_plots(call, company_id, _):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
model_events = call.data.get("model_events", False)
tasks = task_bll.assert_exists(
company_id=call.identity.company,
only=("id", "name", "company", "company_origin"),
task_ids=task_ids,
allow_public=True,
company, tasks_or_models = _get_task_or_model_index_company(
company_id, task_ids, model_events
)
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
result = event_bll.get_task_events(
next(iter(companies)),
company,
task_ids,
event_type=EventType.metrics_plot,
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id,
no_scroll=no_scroll,
model_events=model_events,
)
tasks = {t.id: t.name for t in tasks}
tasks = {t.id: t.name for t in tasks_or_models}
return_events = _get_top_iter_unique_events_per_task(
result.events, max_iters=iters, tasks=tasks
@@ -589,17 +650,18 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
iters = request.iters
scroll_id = request.scroll_id
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=request.model_events
)[0]
result = event_bll.get_task_plots(
task.get_index_company(),
task_or_model.get_index_company(),
tasks=[task_id],
sort=[{"iter": {"order": "desc"}}],
last_iterations_per_plot=iters,
scroll_id=scroll_id,
no_scroll=request.no_scroll,
metric_variants=_get_metric_variants_from_request(request.metrics),
model_events=request.model_events,
)
return_events = result.events
@@ -612,6 +674,46 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
)
@endpoint(
"events.plots",
request_data_model=MetricEventsRequest,
response_data_model=MetricEventsResponse,
)
def task_plots(call, company_id, request: MetricEventsRequest):
task_metrics = defaultdict(dict)
for tm in request.metrics:
task_metrics[tm.task][tm.metric] = tm.variants
for metrics in task_metrics.values():
if None in metrics:
metrics.clear()
company, _ = _get_task_or_model_index_company(
company_id, task_ids=list(task_metrics), model_events=request.model_events
)
result = event_bll.plots_iterator.get_task_events(
company_id=company,
task_metrics=task_metrics,
iter_count=request.iters,
navigate_earlier=request.navigate_earlier,
refresh=request.refresh,
state_id=request.scroll_id,
)
call.result.data_model = MetricEventsResponse(
scroll_id=result.next_scroll_id,
metrics=[
MetricEvents(
task=task,
iterations=[
IterationEvents(iter=iteration["iter"], events=iteration["events"])
for iteration in iterations
],
)
for (task, iterations) in result.metric_events
],
)
@endpoint("events.debug_images", required_fields=["task"])
def get_debug_images_v1_7(call, company_id, _):
task_id = call.data["task"]
@@ -654,17 +756,19 @@ def get_debug_images_v1_8(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
model_events = call.data.get("model_events", False)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
tasks_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
)[0]
result = event_bll.get_task_events(
task.get_index_company(),
tasks_or_model.get_index_company(),
task_id,
event_type=EventType.metrics_image,
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id,
model_events=model_events,
)
return_events = result.events
@@ -681,10 +785,10 @@ def get_debug_images_v1_8(call, company_id, _):
@endpoint(
"events.debug_images",
min_version="2.7",
request_data_model=DebugImagesRequest,
response_data_model=DebugImageResponse,
request_data_model=MetricEventsRequest,
response_data_model=MetricEventsResponse,
)
def get_debug_images(call, company_id, request: DebugImagesRequest):
def get_debug_images(call, company_id, request: MetricEventsRequest):
task_metrics = defaultdict(dict)
for tm in request.metrics:
task_metrics[tm.task][tm.metric] = tm.variants
@@ -692,21 +796,12 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
if None in metrics:
metrics.clear()
tasks = task_bll.assert_exists(
company_id,
task_ids=list(task_metrics),
allow_public=True,
only=("company", "company_origin"),
company, _ = _get_task_or_model_index_company(
company_id, task_ids=list(task_metrics), model_events=request.model_events
)
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
result = event_bll.debug_images_iterator.get_task_events(
company_id=next(iter(companies)),
company_id=company,
task_metrics=task_metrics,
iter_count=request.iters,
navigate_earlier=request.navigate_earlier,
@@ -714,7 +809,7 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
state_id=request.scroll_id,
)
call.result.data_model = DebugImageResponse(
call.result.data_model = MetricEventsResponse(
scroll_id=result.next_scroll_id,
metrics=[
MetricEvents(
@@ -732,20 +827,21 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
@endpoint(
"events.get_debug_image_sample",
min_version="2.12",
request_data_model=GetDebugImageSampleRequest,
request_data_model=GetVariantSampleRequest,
)
def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest):
task = task_bll.assert_exists(
company_id, task_ids=[request.task], allow_public=True, only=("company",)
def get_debug_image_sample(call, company_id, request: GetVariantSampleRequest):
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events,
)[0]
res = event_bll.debug_sample_history.get_debug_image_for_variant(
company_id=task.company,
res = event_bll.debug_image_sample_history.get_sample_for_variant(
company_id=task_or_model.get_index_company(),
task=request.task,
metric=request.metric,
variant=request.variant,
iteration=request.iteration,
refresh=request.refresh,
state_id=request.scroll_id,
navigate_current_metric=request.navigate_current_metric,
)
call.result.data = attr.asdict(res, recurse=False)
@@ -753,31 +849,65 @@ def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest
@endpoint(
"events.next_debug_image_sample",
min_version="2.12",
request_data_model=NextDebugImageSampleRequest,
request_data_model=NextHistorySampleRequest,
)
def next_debug_image_sample(call, company_id, request: NextDebugImageSampleRequest):
task = task_bll.assert_exists(
company_id, task_ids=[request.task], allow_public=True, only=("company",)
def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest):
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events,
)[0]
res = event_bll.debug_sample_history.get_next_debug_image(
company_id=task.company,
res = event_bll.debug_image_sample_history.get_next_sample(
company_id=task_or_model.get_index_company(),
task=request.task,
state_id=request.scroll_id,
navigate_earlier=request.navigate_earlier,
next_iteration=request.next_iteration,
)
call.result.data = attr.asdict(res, recurse=False)
@endpoint(
"events.get_plot_sample", request_data_model=GetMetricSamplesRequest,
)
def get_plot_sample(call, company_id, request: GetMetricSamplesRequest):
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events,
)[0]
res = event_bll.plot_sample_history.get_samples_for_metric(
company_id=task_or_model.get_index_company(),
task=request.task,
metric=request.metric,
iteration=request.iteration,
refresh=request.refresh,
state_id=request.scroll_id,
navigate_current_metric=request.navigate_current_metric,
)
call.result.data = attr.asdict(res, recurse=False)
@endpoint(
"events.next_plot_sample", request_data_model=NextHistorySampleRequest,
)
def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events,
)[0]
res = event_bll.plot_sample_history.get_next_sample(
company_id=task_or_model.get_index_company(),
task=request.task,
state_id=request.scroll_id,
navigate_earlier=request.navigate_earlier,
next_iteration=request.next_iteration,
)
call.result.data = attr.asdict(res, recurse=False)
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
task = task_bll.assert_exists(
company_id,
task_ids=request.tasks,
allow_public=True,
only=("company", "company_origin"),
)[0]
res = event_bll.metrics.get_tasks_metrics(
task.get_index_company(), task_ids=request.tasks, event_type=request.event_type
def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
company, _ = _get_task_or_model_index_company(
company_id, request.tasks, model_events=request.model_events
)
res = event_bll.metrics.get_task_metrics(
company, task_ids=request.tasks, event_type=request.event_type
)
call.result.data = {
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
@@ -785,7 +915,7 @@ def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, req_model):
def delete_for_task(call, company_id, _):
task_id = call.data["task"]
allow_locked = call.data.get("allow_locked", False)
@@ -797,6 +927,34 @@ def delete_for_task(call, company_id, req_model):
)
@endpoint("events.delete_for_model", required_fields=["model"])
def delete_for_model(call: APICall, company_id: str, _):
model_id = call.data["model"]
allow_locked = call.data.get("allow_locked", False)
model_bll.assert_exists(company_id, model_id, return_models=False)
call.result.data = dict(
deleted=event_bll.delete_task_events(
company_id, model_id, allow_locked=allow_locked, model=True
)
)
@endpoint("events.clear_task_log")
def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
task_id = request.task
task_bll.assert_exists(company_id, task_id, return_tasks=False)
call.result.data = dict(
deleted=event_bll.clear_task_log(
company_id=company_id,
task_id=task_id,
allow_locked=request.allow_locked,
threshold_sec=request.threshold_sec,
)
)
def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
key = itemgetter("metric", "variant", "task", "iter")
@@ -876,17 +1034,13 @@ def scalar_metrics_iter_raw(
request.batch_size = request.batch_size or scroll.request.batch_size
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",),
)[0]
task_or_model = _assert_task_or_model_exists(company_id, task_id, model_events=request.model_events)[0]
metric_variants = _get_metric_variants_from_request([request.metric])
if request.count_total and total is None:
total = event_bll.events_iterator.count_task_events(
event_type=EventType.metrics_scalar,
company_id=task.company,
company_id=task_or_model.get_index_company(),
task_id=task_id,
metric_variants=metric_variants,
)
@@ -902,7 +1056,7 @@ def scalar_metrics_iter_raw(
for iteration in range(0, math.ceil(batch_size / 10_000)):
res = event_bll.events_iterator.get_task_events(
event_type=EventType.metrics_scalar,
company_id=task.company,
company_id=task_or_model.get_index_company(),
task_id=task_id,
batch_size=min(batch_size, 10_000),
navigate_earlier=False,

View File

@@ -20,10 +20,11 @@ from apiserver.apimodels.models import (
AddOrUpdateMetadataRequest,
ModelsPublishManyRequest,
ModelsDeleteManyRequest,
ModelsGetRequest,
)
from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import publish_task
from apiserver.bll.util import run_batch_operation
@@ -50,8 +51,8 @@ from apiserver.services.utils import (
ModelsBackwardsCompatibility,
unescape_metadata,
escape_metadata,
process_include_subprojects,
)
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
org_bll = OrgBLL()
@@ -106,32 +107,31 @@ def get_by_task_id(call: APICall, company_id, _):
call.result.data = {"model": model_dict}
def _process_include_subprojects(call_data: dict):
include_subprojects = call_data.pop("include_subprojects", False)
project_ids = call_data.get("project")
if not project_ids or not include_subprojects:
return
if not isinstance(project_ids, list):
project_ids = [project_ids]
call_data["project"] = project_ids_with_children(project_ids)
@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
@endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
conform_tag_fields(call, call.data)
_process_include_subprojects(call.data)
process_include_subprojects(call.data)
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_all_ex"):
ret_params = {}
models = Model.get_many_with_join(
company=company_id,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
ret_params = {}
models = Model.get_many_with_join(
company=company_id,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
unescape_metadata(call, models)
if not request.include_stats:
call.result.data = {"models": models, **ret_params}
return
model_ids = {model["id"] for model in models}
stats = ModelBLL.get_model_stats(company=company_id, model_ids=list(model_ids),)
for model in models:
model["stats"] = stats.get(model["id"])
call.result.data = {"models": models, **ret_params}
@@ -139,10 +139,9 @@ def get_all_ex(call: APICall, company_id, _):
def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_by_id_ex"):
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
)
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
)
conform_output_tags(call, models)
unescape_metadata(call, models)
call.result.data = {"models": models}
@@ -152,15 +151,14 @@ def get_by_id_ex(call: APICall, company_id, _):
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_all"):
ret_params = {}
models = Model.get_many(
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
ret_params = {}
models = Model.get_many(
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
unescape_metadata(call, models)
call.result.data = {"models": models, **ret_params}
@@ -399,9 +397,7 @@ def validate_task(company_id, fields: dict):
def edit(call: APICall, company_id, _):
model_id = call.data["model"]
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
)
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
fields = parse_model_fields(call, create_fields)
fields = prepare_update_fields(call, company_id, fields)
@@ -409,11 +405,7 @@ def edit(call: APICall, company_id, _):
for key in fields:
field = getattr(model, key, None)
value = fields[key]
if (
field
and isinstance(value, dict)
and isinstance(field, EmbeddedDocument)
):
if field and isinstance(value, dict) and isinstance(field, EmbeddedDocument):
d = field.to_mongo(use_db_field=False).to_dict()
d.update(value)
fields[key] = d
@@ -433,13 +425,9 @@ def edit(call: APICall, company_id, _):
if updated:
new_project = fields.get("project", model.project)
if new_project != model.project:
_reset_cached_tags(
company_id, projects=[new_project, model.project]
)
_reset_cached_tags(company_id, projects=[new_project, model.project])
else:
_update_cached_tags(
company_id, project=model.project, fields=fields
)
_update_cached_tags(company_id, project=model.project, fields=fields)
conform_output_tags(call, fields)
unescape_metadata(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@@ -450,9 +438,7 @@ def edit(call: APICall, company_id, _):
def _update_model(call: APICall, company_id, model_id=None):
model_id = model_id or call.data["model"]
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
)
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
data = prepare_update_fields(call, company_id, call.data)

View File

@@ -1,13 +1,21 @@
from collections import defaultdict
from operator import itemgetter
from typing import Mapping, Type
from apiserver.apimodels.organization import TagsRequest
from mongoengine import Q
from apiserver.apimodels.organization import TagsRequest, EntitiesCountRequest
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.database.model import User
from apiserver.bll.project import ProjectBLL
from apiserver.database.model import User, AttributedDocument, EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.service_repo import endpoint, APICall
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
org_bll = OrgBLL()
project_bll = ProjectBLL()
@endpoint("organization.get_tags", request_data_model=TagsRequest)
@@ -41,3 +49,47 @@ def get_user_companies(call: APICall, company_id: str, _):
}
]
}
@endpoint("organization.get_entities_count")
def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
entity_classes: Mapping[str, Type[AttributedDocument]] = {
"projects": Project,
"tasks": Task,
"models": Model,
"pipelines": Project,
"datasets": Project,
}
ret = {}
for field, entity_cls in entity_classes.items():
data = call.data.get(field)
if data is None:
continue
if request.active_users:
if entity_cls is Project:
requested_ids = data.get("id")
if isinstance(requested_ids, str):
requested_ids = [requested_ids]
ids, _ = project_bll.get_projects_with_active_user(
company=company,
users=request.active_users,
project_ids=requested_ids,
allow_public=True,
)
if not ids:
ret[field] = 0
continue
data["id"] = ids
elif not data.get("user"):
data["user"] = request.active_users
query = Q()
if entity_cls in (Project, Task) and not request.search_hidden:
query &= Q(system_tags__ne=EntityVisibility.hidden.value)
ret[field] = entity_cls.get_count(
company=company, query_dict=data, query=query, allow_public=True,
)
call.result.data = ret

View File

@@ -39,7 +39,6 @@ from apiserver.services.utils import (
get_tags_filter_dictionary,
sort_tags_response,
)
from apiserver.timing_context import TimingContext
org_bll = OrgBLL()
project_bll = ProjectBLL()
@@ -60,11 +59,8 @@ def get_by_id(call):
project_id = call.data["project"]
with translate_errors_context():
with TimingContext("mongo", "projects_by_id"):
query = Q(id=project_id) & get_company_or_none_constraint(
call.identity.company
)
project = Project.objects(query).first()
query = Q(id=project_id) & get_company_or_none_constraint(call.identity.company)
project = Project.objects(query).first()
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
@@ -106,63 +102,70 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
conform_tag_fields(call, data)
allow_public = not request.non_public
requested_ids = data.get("id")
if isinstance(requested_ids, str):
requested_ids = [requested_ids]
_adjust_search_parameters(
data, shallow_search=request.shallow_search,
)
with TimingContext("mongo", "projects_get_all"):
if request.active_users:
ids = project_bll.get_projects_with_active_user(
company=company_id,
users=request.active_users,
project_ids=requested_ids,
allow_public=allow_public,
)
if not ids:
return {"projects": []}
data["id"] = ids
ret_params = {}
projects: Sequence[dict] = Project.get_many_with_join(
user_active_project_ids = None
if request.active_users:
ids, user_active_project_ids = project_bll.get_projects_with_active_user(
company=company_id,
query_dict=data,
query=_hidden_query(search_hidden=request.search_hidden, ids=requested_ids),
users=request.active_users,
project_ids=requested_ids,
allow_public=allow_public,
ret_params=ret_params,
)
if not ids:
return {"projects": []}
data["id"] = ids
if request.check_own_contents and requested_ids:
existing_requested_ids = {
project["id"] for project in projects if project["id"] in requested_ids
}
if existing_requested_ids:
contents = project_bll.calc_own_contents(
company=company_id,
project_ids=list(existing_requested_ids),
filter_=request.include_stats_filter,
)
for project in projects:
project.update(**contents.get(project["id"], {}))
ret_params = {}
projects: Sequence[dict] = Project.get_many_with_join(
company=company_id,
query_dict=data,
query=_hidden_query(search_hidden=request.search_hidden, ids=requested_ids),
allow_public=allow_public,
ret_params=ret_params,
)
if not projects:
return {"projects": projects, **ret_params}
conform_output_tags(call, projects)
if not request.include_stats:
call.result.data = {"projects": projects, **ret_params}
return
project_ids = list({project["id"] for project in projects})
if request.check_own_contents:
contents = project_bll.calc_own_contents(
company=company_id,
project_ids=project_ids,
filter_=request.include_stats_filter,
users=request.active_users,
)
for project in projects:
project.update(**contents.get(project["id"], {}))
project_ids = {project["id"] for project in projects}
conform_output_tags(call, projects)
if request.include_stats:
stats, children = project_bll.get_project_stats(
company=company_id,
project_ids=list(project_ids),
project_ids=project_ids,
specific_state=request.stats_for_state,
include_children=request.stats_with_children,
return_hidden_children=request.search_hidden,
search_hidden=request.search_hidden,
filter_=request.include_stats_filter,
users=request.active_users,
user_active_project_ids=user_active_project_ids,
)
for project in projects:
project["stats"] = stats[project["id"]]
project["sub_projects"] = children[project["id"]]
call.result.data = {"projects": projects, **ret_params}
if request.include_dataset_stats:
dataset_stats = project_bll.get_dataset_stats(
company=company_id, project_ids=project_ids, users=request.active_users,
)
for project in projects:
project["dataset_stats"] = dataset_stats.get(project["id"])
call.result.data = {"projects": projects, **ret_params}
@endpoint("projects.get_all")
@@ -172,20 +175,19 @@ def get_all(call: APICall):
_adjust_search_parameters(
data, shallow_search=data.get("shallow_search", False),
)
with TimingContext("mongo", "projects_get_all"):
ret_params = {}
projects = Project.get_many(
company=call.identity.company,
query_dict=data,
query=_hidden_query(
search_hidden=data.get("search_hidden"), ids=data.get("id")
),
parameters=data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, projects)
call.result.data = {"projects": projects, **ret_params}
ret_params = {}
projects = Project.get_many(
company=call.identity.company,
query_dict=data,
query=_hidden_query(
search_hidden=data.get("search_hidden"), ids=data.get("id")
),
parameters=data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, projects)
call.result.data = {"projects": projects, **ret_params}
@endpoint(
@@ -271,6 +273,7 @@ def validate_delete(call: APICall, company_id: str, request: ProjectRequest):
def delete(call: APICall, company_id: str, request: DeleteRequest):
res, affected_projects = delete_project(
company=company_id,
user=call.identity.user,
project_id=request.project,
force=request.force,
delete_contents=request.delete_contents,

View File

@@ -1,3 +1,5 @@
from mongoengine import Q
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.queues import (
GetDefaultResp,
@@ -14,10 +16,13 @@ from apiserver.apimodels.queues import (
AddOrUpdateMetadataRequest,
DeleteMetadataRequest,
GetNextTaskRequest,
GetByIdRequest,
GetAllRequest,
)
from apiserver.bll.model import Metadata
from apiserver.bll.queue import QueueBLL
from apiserver.bll.workers import WorkerBLL
from apiserver.config_repo import config
from apiserver.database.model.task.task import Task
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
@@ -33,9 +38,11 @@ worker_bll = WorkerBLL()
queue_bll = QueueBLL(worker_bll)
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=QueueRequest)
def get_by_id(call: APICall, company_id, req_model: QueueRequest):
queue = queue_bll.get_by_id(company_id, req_model.queue)
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest)
def get_by_id(call: APICall, company_id, request: GetByIdRequest):
queue = queue_bll.get_by_id(
company_id, request.queue, max_task_entries=request.max_task_entries
)
queue_dict = queue.to_proper_dict()
conform_output_tags(call, queue_dict)
unescape_metadata(call, queue_dict)
@@ -48,14 +55,34 @@ def get_by_id(call: APICall):
call.result.data_model = GetDefaultResp(id=queue.id, name=queue.name)
def _hidden_query(data: dict) -> Q:
"""
1. Add only non-hidden queues search condition (unless specifically specified differently)
"""
hidden_tags = config.get("services.queues.hidden_tags", [])
if (
not hidden_tags
or data.get("search_hidden")
or data.get("id")
or data.get("name")
):
return Q()
return Q(system_tags__nin=hidden_tags)
@endpoint("queues.get_all_ex", min_version="2.4")
def get_all_ex(call: APICall):
def get_all_ex(call: APICall, company: str, request: GetAllRequest):
conform_tag_fields(call, call.data)
ret_params = {}
Metadata.escape_query_parameters(call)
queues = queue_bll.get_queue_infos(
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
company_id=company,
query_dict=call.data,
query=_hidden_query(call.data),
max_task_entries=request.max_task_entries,
ret_params=ret_params,
)
conform_output_tags(call, queues)
unescape_metadata(call, queues)
@@ -63,12 +90,16 @@ def get_all_ex(call: APICall):
@endpoint("queues.get_all", min_version="2.4")
def get_all(call: APICall):
def get_all(call: APICall, company: str, request: GetAllRequest):
conform_tag_fields(call, call.data)
ret_params = {}
Metadata.escape_query_parameters(call)
queues = queue_bll.get_all(
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
company_id=company,
query_dict=call.data,
query=_hidden_query(call.data),
max_task_entries=request.max_task_entries,
ret_params=ret_params,
)
conform_output_tags(call, queues)
unescape_metadata(call, queues)
@@ -126,13 +157,13 @@ def add_task(call: APICall, company_id, req_model: TaskRequest):
@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest)
def get_next_task(call: APICall, company_id, req_model: GetNextTaskRequest):
def get_next_task(call: APICall, company_id, request: GetNextTaskRequest):
entry = queue_bll.get_next_task(
company_id=company_id, queue_id=req_model.queue
company_id=company_id, queue_id=request.queue, task_id=request.task
)
if entry:
data = {"entry": entry.to_proper_dict()}
if req_model.get_task_info:
if request.get_task_info:
task = Task.objects(id=entry.task).first()
if task:
data["task_info"] = {"company": task.company, "user": task.user}
@@ -224,14 +255,15 @@ def move_task_to_back(call: APICall, company_id, req_model: TaskRequest):
response_data_model=GetMetricsResponse,
)
def get_queue_metrics(
call: APICall, company_id, req_model: GetMetricsRequest
call: APICall, company_id, request: GetMetricsRequest
) -> GetMetricsResponse:
ret = queue_bll.metrics.get_queue_metrics(
company_id=company_id,
from_date=req_model.from_date,
to_date=req_model.to_date,
interval=req_model.interval,
queue_ids=req_model.queue_ids,
from_date=request.from_date,
to_date=request.to_date,
interval=request.interval,
queue_ids=request.queue_ids,
refresh=request.refresh,
)
queue_dicts = {
@@ -273,3 +305,17 @@ def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataReque
queue_id = request.queue
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return {"updated": Metadata.delete_metadata(queue, keys=request.keys)}
@endpoint("queues.peek_task", min_version="2.15")
def peek_task(call: APICall, company_id: str, request: QueueRequest):
queue_id = request.queue
queue = queue_bll.get_by_id(
company_id=company_id, queue_id=queue_id, max_task_entries=1
)
return {"task": queue.entries[0].task if queue.entries else None}
@endpoint("queues.get_num_entries", min_version="2.15")
def get_num_entries(call: APICall, company_id: str, request: QueueRequest):
return {"num": queue_bll.count_entries(company=company_id, queue_id=request.queue)}

View File

@@ -62,11 +62,13 @@ from apiserver.apimodels.tasks import (
DequeueManyResponse,
ResetManyResponse,
ResetBatchItem,
CompletedRequest,
CompletedResponse,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.project import ProjectBLL
from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import (
TaskBLL,
@@ -79,7 +81,6 @@ from apiserver.bll.task.artifacts import (
Artifacts,
)
from apiserver.bll.task.hyperparams import HyperParams
from apiserver.bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
from apiserver.bll.task.param_utils import (
params_prepare_for_save,
params_unprepare_from_saved,
@@ -94,6 +95,7 @@ from apiserver.bll.task.task_operations import (
delete_task,
publish_task,
unarchive_task,
move_tasks_to_trash,
)
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
from apiserver.bll.util import SetFieldsResolver, run_batch_operation
@@ -116,8 +118,9 @@ from apiserver.services.utils import (
DockerCmdBackwardsCompatibility,
escape_dict_field,
unescape_dict_field,
process_include_subprojects,
)
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.partial_version import PartialVersion
task_fields = set(Task.get_fields())
@@ -131,8 +134,6 @@ queue_bll = QueueBLL()
org_bll = OrgBLL()
project_bll = ProjectBLL()
NonResponsiveTasksWatchdog.start()
def set_task_status_from_call(
request: UpdateRequest, company_id, new_status=None, **set_fields
@@ -203,17 +204,6 @@ def escape_execution_parameters(call: APICall) -> dict:
return call_data
def _process_include_subprojects(call_data: dict):
include_subprojects = call_data.pop("include_subprojects", False)
project_ids = call_data.get("project")
if not project_ids or not include_subprojects:
return
if not isinstance(project_ids, list):
project_ids = [project_ids]
call_data["project"] = project_ids_with_children(project_ids)
def _hidden_query(data: dict) -> Q:
"""
1. Add only non-hidden tasks search condition (unless specifically specified differently)
@@ -230,16 +220,15 @@ def get_all_ex(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data)
ret_params = {}
tasks = Task.get_many_with_join(
company=company_id,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=True,
ret_params=ret_params,
)
process_include_subprojects(call_data)
ret_params = {}
tasks = Task.get_many_with_join(
company=company_id,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=True,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks, **ret_params}
@@ -250,10 +239,9 @@ def get_by_id_ex(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@@ -265,16 +253,15 @@ def get_all(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with TimingContext("mongo", "task_get_all"):
ret_params = {}
tasks = Task.get_many(
company=company_id,
parameters=call_data,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=True,
ret_params=ret_params,
)
ret_params = {}
tasks = Task.get_many(
company=company_id,
parameters=call_data,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=True,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks, **ret_params}
@@ -486,12 +473,11 @@ def prepare_create_fields(
def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
with translate_errors_context(
field_does_not_exist_cls=errors.bad_request.ValidationError
), TimingContext("code", "parse_call"):
):
fields = prepare_create_fields(call, **kwargs)
task = task_bll.create(call, fields)
with TimingContext("code", "validate"):
task_bll.validate(task)
task_bll.validate(task)
return task, fields
@@ -524,7 +510,7 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
def create(call: APICall, company_id, req_model: CreateRequest):
task, fields = _validate_and_get_task_from_call(call)
with translate_errors_context(), TimingContext("mongo", "save_task"):
with translate_errors_context():
task.save()
_update_cached_tags(company_id, project=task.project, fields=fields)
update_project_time(task.project)
@@ -707,7 +693,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
with translate_errors_context(
field_does_not_exist_cls=errors.bad_request.ValidationError
), TimingContext("code", "parse_and_validate"):
):
fields = prepare_create_fields(
call, valid_fields=edit_fields, output=task.output, previous_task=task
)
@@ -861,8 +847,14 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
queue_name=request.queue_name,
force=request.force,
)
if request.verify_watched_queue:
res_queue = nested_get(res, ("fields", "execution.queue"))
if res_queue:
res["queue_watched"] = queue_bll.check_for_workers(company_id, res_queue)
call.result.data_model = EnqueueResponse(queued=queued, **res)
@@ -879,16 +871,25 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
queue_name=request.queue_name,
validate=request.validate_tasks,
),
ids=request.ids,
)
extra = {}
if request.verify_watched_queue and results:
_id, (queued, res) = results[0]
res_queue = nested_get(res, ("fields", "execution.queue"))
if res_queue:
extra["queue_watched"] = queue_bll.check_for_workers(company_id, res_queue)
call.result.data_model = EnqueueManyResponse(
succeeded=[
EnqueueBatchItem(id=_id, queued=bool(queued), **res)
for _id, (queued, res) in results
],
failed=failures,
**extra,
)
@@ -938,10 +939,12 @@ def reset(call: APICall, company_id, request: ResetRequest):
dequeued, cleanup_res, updates = reset_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
clear_all=request.clear_all,
delete_external_artifacts=request.delete_external_artifacts,
)
res = ResetResponse(**updates, dequeued=dequeued)
@@ -964,10 +967,12 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
func=partial(
reset_task,
company_id=company_id,
user_id=call.identity.user,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
clear_all=request.clear_all,
delete_external_artifacts=request.delete_external_artifacts,
),
ids=request.ids,
)
@@ -1065,14 +1070,18 @@ def delete(call: APICall, company_id, request: DeleteRequest):
deleted, task, cleanup_res = delete_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
status_message=request.status_message,
status_reason=request.status_reason,
delete_external_artifacts=request.delete_external_artifacts,
)
if deleted:
if request.move_to_trash:
move_tasks_to_trash([request.task])
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
@@ -1083,17 +1092,23 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
func=partial(
delete_task,
company_id=company_id,
user_id=call.identity.user,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
status_message=request.status_message,
status_reason=request.status_reason,
delete_external_artifacts=request.delete_external_artifacts,
),
ids=request.ids,
)
if results:
if request.move_to_trash:
task_ids = set(task.id for _, (_, task, _) in results)
if task_ids:
move_tasks_to_trash(list(task_ids))
projects = set(task.project for _, (_, task, _) in results)
_reset_cached_tags(company_id, projects=list(projects))
@@ -1152,11 +1167,11 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
@endpoint(
"tasks.completed",
min_version="2.2",
request_data_model=UpdateRequest,
response_data_model=UpdateResponse,
request_data_model=CompletedRequest,
response_data_model=CompletedResponse,
)
def completed(call: APICall, company_id, request: PublishRequest):
call.result.data_model = UpdateResponse(
def completed(call: APICall, company_id, request: CompletedRequest):
res = CompletedResponse(
**set_task_status_from_call(
request,
company_id,
@@ -1165,6 +1180,22 @@ def completed(call: APICall, company_id, request: PublishRequest):
)
)
if res.updated and request.publish:
publish_res = publish_task(
task_id=request.task,
company_id=company_id,
force=request.force,
publish_model_func=ModelBLL.publish_model,
status_reason=request.status_reason,
status_message=request.status_message,
)
res.published = publish_res.get("updated")
new_status = nested_get(publish_res, ("fields", "status"))
if new_status:
res.fields["status"] = new_status
call.result.data_model = res
@endpoint("tasks.ping", request_data_model=PingRequest)
def ping(_, company_id, request: PingRequest):

View File

@@ -95,26 +95,30 @@ def get_all(call: APICall, company_id, _):
@endpoint("users.get_current_user")
def get_current_user(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
projection = (
{"company.name"}
.union(User.get_fields())
.difference(User.get_exclude_fields())
)
res = User.get_many_with_join(
query=Q(id=call.identity.user),
company=company_id,
override_projection=projection,
)
user_id = call.identity.user
if not res:
raise errors.bad_request.InvalidUser("failed loading user")
projection = (
{"company.name"}
.union(User.get_fields())
.difference(User.get_exclude_fields())
)
res = User.get_many_with_join(
query=Q(id=user_id),
company=company_id,
override_projection=projection,
)
user = res[0]
user["role"] = call.identity.role
if not res:
raise errors.bad_request.InvalidUser("failed loading user")
resp = {"user": user}
call.result.data = resp
user = res[0]
user["role"] = call.identity.role
resp = {
"user": user,
"getting_started": config.get("apiserver.getting_started_info", None),
}
call.result.data = resp
create_fields = {

View File

@@ -3,6 +3,7 @@ from typing import Union, Sequence, Tuple
from apiserver.apierrors import errors
from apiserver.apimodels.organization import Filter
from apiserver.bll.project import project_ids_with_children
from apiserver.database.model.base import GetMixin
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
from apiserver.database.utils import partition_tags
@@ -12,6 +13,17 @@ from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from apiserver.utilities.partial_version import PartialVersion
def process_include_subprojects(call_data: dict):
include_subprojects = call_data.pop("include_subprojects", False)
project_ids = call_data.get("project")
if not project_ids or not include_subprojects:
return
if not isinstance(project_ids, list):
project_ids = [project_ids]
call_data["project"] = project_ids_with_children(project_ids)
def get_tags_filter_dictionary(input_: Filter) -> dict:
if not input_:
return {}

View File

@@ -41,7 +41,12 @@ worker_bll = WorkerBLL()
)
def get_all(call: APICall, company_id: str, request: GetAllRequest):
call.result.data_model = GetAllResponse(
workers=worker_bll.get_all_with_projection(company_id, request.last_seen)
workers=worker_bll.get_all_with_projection(
company_id,
request.last_seen,
tags=request.tags,
system_tags=request.system_tags,
)
)
@@ -64,6 +69,7 @@ def register(call: APICall, company_id, request: RegisterRequest):
queues=queues,
timeout=timeout,
tags=request.tags,
system_tags=request.system_tags,
)
@@ -72,7 +78,9 @@ def unregister(call: APICall, company_id, req_model: WorkerRequest):
worker_bll.unregister_worker(company_id, call.identity.user, req_model.worker)
@endpoint("workers.status_report", min_version="2.4", request_data_model=StatusReportRequest)
@endpoint(
"workers.status_report", min_version="2.4", request_data_model=StatusReportRequest
)
def status_report(call: APICall, company_id, request: StatusReportRequest):
worker_bll.status_report(
company_id=company_id,
@@ -80,6 +88,7 @@ def status_report(call: APICall, company_id, request: StatusReportRequest):
ip=call.real_ip,
report=request,
tags=request.tags,
system_tags=request.system_tags,
)

View File

@@ -1,4 +0,0 @@
{
api_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
secret_key: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
}

View File

@@ -7,9 +7,6 @@ class TestBatchOperations(TestService):
comment = "this is a comment"
delete_params = dict(can_fail=True, force=True)
def setUp(self, version="2.13"):
super().setUp(version=version)
def test_tasks(self):
tasks = [self._temp_task() for _ in range(2)]
models = [
@@ -20,7 +17,7 @@ class TestBatchOperations(TestService):
ids = [*tasks, missing_id]
# enqueue
res = self.api.tasks.enqueue_many(ids=ids)
res = self.api.tasks.enqueue_many(ids=ids, queue_name="test batch")
self._assert_succeeded(res, tasks)
self._assert_failed(res, [missing_id])
data = self.api.tasks.get_all_ex(id=ids).tasks

View File

@@ -77,7 +77,7 @@ class TestModelsService(TestService):
def test_publish_output_model_no_task(self):
model_id = self.create_temp(
service="models", name='test', uri='file:///a', labels={}, ready=False
service="models", name="test", uri="file:///a", labels={}, ready=False
)
self._assert_model_ready(model_id, False)
@@ -109,7 +109,7 @@ class TestModelsService(TestService):
def test_publish_task_no_output_model(self):
task_id = self.create_temp(
service="tasks", type='testing', name='server-test', input=dict(view={})
service="tasks", type="testing", name="server-test", input=dict(view={})
)
self.api.tasks.started(task=task_id)
self.api.tasks.stopped(task=task_id)
@@ -118,31 +118,48 @@ class TestModelsService(TestService):
assert res.updated == 1 # model updated
self._assert_task_status(task_id, PUBLISHED)
def test_get_models_stats(self):
model1 = self._create_model(labels={"hello": 1, "world": 2})
model2 = self._create_model(labels={"foo": 1})
model3 = self._create_model()
# no stats
res = self.api.models.get_all_ex(id=[model1, model2, model3]).models
self.assertEqual(len(res), 3)
self.assertTrue(all("stats" not in m for m in res))
# stats
res = self.api.models.get_all_ex(
id=[model1, model2, model3], include_stats=True
).models
self.assertEqual(len(res), 3)
stats = {m.id: m.stats.labels_count for m in res}
self.assertEqual(stats[model1], 2)
self.assertEqual(stats[model2], 1)
self.assertEqual(stats[model3], 0)
def test_update_model_iteration_with_task(self):
task_id = self._create_task()
model_id = self._create_model()
self.api.models.update(model=model_id, task=task_id, iteration=1000, labels={"foo": 1})
self.api.models.update(
model=model_id, task=task_id, iteration=1000, labels={"foo": 1}
)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
1000
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
)
self.api.models.update(model=model_id, task=task_id, iteration=500)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
1000
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
)
def test_update_model_for_task_iteration(self):
task_id = self._create_task()
res = self.api.models.update_for_task(
task=task_id,
name="test model",
uri="file:///b",
iteration=999,
task=task_id, name="test model", uri="file:///b", iteration=999,
)
model_id = res.id
@@ -150,22 +167,19 @@ class TestModelsService(TestService):
self.defer(self.api.models.delete, can_fail=True, model=model_id, force=True)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
999
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 999
)
self.api.models.update_for_task(task=task_id, uri="file:///c", iteration=1000)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
1000
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
)
self.api.models.update_for_task(task=task_id, uri="file:///d", iteration=888)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
1000
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
)
def test_get_frameworks(self):
@@ -238,8 +252,8 @@ class TestModelsService(TestService):
return self.create_temp(
service="models",
delete_params=dict(can_fail=True, force=True),
name=kwargs.pop("name", 'test'),
uri=kwargs.pop("name", 'file:///a'),
name=kwargs.pop("name", "test"),
uri=kwargs.pop("name", "file:///a"),
labels=kwargs.pop("labels", {}),
**kwargs,
)
@@ -247,8 +261,8 @@ class TestModelsService(TestService):
def _create_task(self, **kwargs):
task_id = self.create_temp(
service="tasks",
type=kwargs.pop("type", 'testing'),
name=kwargs.pop("name", 'server-test'),
type=kwargs.pop("type", "testing"),
name=kwargs.pop("name", "server-test"),
input=kwargs.pop("input", dict(view={})),
**kwargs,
)
@@ -257,21 +271,18 @@ class TestModelsService(TestService):
def _create_task_and_model(self):
execution_model_id = self.create_temp(
service="models",
name='test',
uri='file:///a',
labels={}
service="models", name="test", uri="file:///a", labels={}
)
task_id = self.create_temp(
service="tasks",
type='testing',
name='server-test',
type="testing",
name="server-test",
input=dict(view={}),
execution=dict(model=execution_model_id)
execution=dict(model=execution_model_id),
)
self.api.tasks.started(task=task_id)
output_model_id = self.api.models.update_for_task(
task=task_id, uri='file:///b'
task=task_id, uri="file:///b"
)["id"]
return task_id, output_model_id

View File

@@ -1,4 +1,4 @@
from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apierrors.errors.bad_request import InvalidProjectId, ExpectedUniqueData
from apiserver.apierrors.errors.forbidden import NoWritePermission
from apiserver.config_repo import config
from apiserver.tests.automated import TestService
@@ -32,3 +32,12 @@ class TestProjectsEdit(TestService):
res = self.api.projects.get_all(id=[p1])
self.assertEqual([p.id for p in res.projects], [p1])
self.api.projects.update(project=p1, name="Test public change 2")
def test_project_name_uniqueness(self):
name1 = "Test name1"
p1 = self.create_temp("projects", name=name1, description="test")
with self.api.raises(ExpectedUniqueData):
p2 = self.create_temp("projects", name=name1, description="test")
p2 = self.create_temp("projects", name="Test name2", description="test")
with self.api.raises(ExpectedUniqueData):
self.api.projects.update(project=p2, name=name1)

View File

@@ -9,9 +9,6 @@ from apiserver.tests.automated import TestService, utc_now_tz_aware
class TestQueues(TestService):
def setUp(self, version="2.4"):
super().setUp(version=version)
def test_default_queue(self):
res = self.api.queues.get_default()
self.assertIsNotNone(res.id)
@@ -24,7 +21,7 @@ class TestQueues(TestService):
def test_queue_metrics(self):
queue_id = self._temp_queue("TestTempQueue")
task1 = self._create_temp_queued_task("temp task 1", queue_id)
self._create_temp_queued_task("temp task 1", queue_id)
time.sleep(1)
task2 = self._create_temp_queued_task("temp task 2", queue_id)
self.api.queues.get_next_task(queue=queue_id)
@@ -39,6 +36,27 @@ class TestQueues(TestService):
)
self.assertMetricQueues(res["queues"], queue_id)
def test_hidden_queues(self):
hidden_name = "TestHiddenQueue"
hidden_queue = self._temp_queue(hidden_name, system_tags=["k8s-glue"])
non_hidden_queue = self._temp_queue("TestNonHiddenQueue")
queues = self.api.queues.get_all_ex().queues
ids = {q.id for q in queues}
self.assertFalse(hidden_queue in ids)
self.assertTrue(non_hidden_queue in ids)
queues = self.api.queues.get_all_ex(search_hidden=True).queues
ids = {q.id for q in queues}
self.assertTrue(hidden_queue in ids)
self.assertTrue(non_hidden_queue in ids)
queues = self.api.queues.get_all_ex(name=f"^{hidden_name}$").queues
self.assertEqual(hidden_queue, queues[0].id)
queues = self.api.queues.get_all_ex(id=[hidden_queue]).queues
self.assertEqual(hidden_queue, queues[0].id)
def test_reset_task(self):
queue = self._temp_queue("TestTempQueue")
task = self._temp_task("TempTask", is_development=True)
@@ -63,6 +81,47 @@ class TestQueues(TestService):
self.assertQueueTasks(res.queue, [task])
self.assertTaskTags(task, system_tags=[])
def test_dequeue_from_deleted_queue(self):
queue = self._temp_queue("TestTempQueue")
task_name = "TempDevTask"
task = self._temp_task(task_name)
self.api.tasks.enqueue(task=task, queue=queue)
res = self.api.tasks.get_by_id(task=task)
self.assertEqual(res.task.status, "queued")
self.api.queues.delete(queue=queue, force=True)
res = self.api.tasks.get_by_id(task=task)
self.assertEqual(res.task.status, "created")
def test_max_queue_entries(self):
queue = self._temp_queue("TestTempQueue")
tasks = [
self._create_temp_queued_task(t, queue)["id"]
for t in ("temp task1", "temp task2", "temp task3")
]
num = self.api.queues.get_num_entries(queue=queue).num
self.assertEqual(num, 3)
task_id = self.api.queues.peek_task(queue=queue).task
self.assertEqual(task_id, tasks[0])
res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, tasks)
res = self.api.queues.get_all(id=[queue]).queues[0]
self.assertQueueTasks(res, tasks)
res = self.api.queues.get_all(id=[queue], max_task_entries=2).queues[0]
self.assertQueueTasks(res, tasks[:2])
res = self.api.queues.get_all_ex(id=[queue]).queues[0]
self.assertEqual([e.task.id for e in res.entries], tasks)
res = self.api.queues.get_all_ex(id=[queue], max_task_entries=2).queues[0]
self.assertEqual([e.task.id for e in res.entries], tasks[:2])
def test_move_task(self):
queue = self._temp_queue("TestTempQueue")
tasks = [
@@ -168,8 +227,8 @@ class TestQueues(TestService):
sorted(queue.workers, key=sort_key), sorted(workers, key=sort_key)
)
def _temp_queue(self, queue_name, tags=None):
return self.create_temp("queues", name=queue_name, tags=tags)
def _temp_queue(self, queue_name, **kwargs):
return self.create_temp("queues", name=queue_name, **kwargs)
def _temp_task(self, task_name, is_testing=False, is_development=False):
task_input = dict(

View File

@@ -12,6 +12,27 @@ from apiserver.tests.automated import TestService
class TestSubProjects(TestService):
def test_dataset_stats(self):
project = self._temp_project(name="Dataset test", system_tags=["dataset"])
res = self.api.organization.get_entities_count(
datasets={"system_tags": ["dataset"]}
)
self.assertEqual(res.datasets, 1)
task = self._temp_task(project=project)
data = self.api.projects.get_all_ex(
id=[project], include_dataset_stats=True
).projects[0]
self.assertIsNone(data.dataset_stats)
self.api.tasks.edit(
task=task, runtime={"ds_file_count": 2, "ds_total_size": 1000}
)
data = self.api.projects.get_all_ex(
id=[project], include_dataset_stats=True
).projects[0]
self.assertEqual(data.dataset_stats, {"file_count": 2, "total_size": 1000})
def test_project_aggregations(self):
"""This test requires user with user_auth_only... credentials in db"""
user2_client = APIClient(
@@ -20,11 +41,12 @@ class TestSubProjects(TestService):
base_url=f"http://localhost:8008/v2.13",
)
child = self._temp_project(name="Aggregation/Pr1", client=user2_client)
basename = "Pr1"
child = self._temp_project(name=f"Aggregation/{basename}", client=user2_client)
project = self.api.projects.get_all_ex(name="^Aggregation$").projects[0].id
child_project = self.api.projects.get_all_ex(id=[child]).projects[0]
self.assertEqual(child_project.parent.id, project)
self.assertEqual(child_project.basename, basename)
user = self.api.users.get_current_user().user.id
# test aggregations on project with empty subprojects
@@ -38,22 +60,37 @@ class TestSubProjects(TestService):
self.assertEqual(res.types, [])
res = self.api.projects.get_task_parents(projects=[project])
self.assertEqual(res.parents, [])
res = self.api.organization.get_entities_count(
projects={"id": [project]}, active_users=[user]
)
self.assertEqual(res.projects, 0)
# test aggregations with non-empty subprojects
task1 = self._temp_task(project=child)
self._temp_task(project=child, parent=task1)
user2_task = self._temp_task(project=child, client=user2_client)
framework = "Test framework"
self._temp_model(project=child, framework=framework)
res = self.api.users.get_all_ex(active_in_projects=[project])
self._assert_ids(res.users, [user])
res = self.api.projects.get_all_ex(id=[project], active_users=[user])
res = self.api.projects.get_all_ex(id=[project], include_stats=True)
self._assert_ids(res.projects, [project])
self.assertEqual(res.projects[0].stats.active.total_tasks, 3)
res = self.api.projects.get_all_ex(
id=[project], active_users=[user], include_stats=True
)
self._assert_ids(res.projects, [project])
self.assertEqual(res.projects[0].stats.active.total_tasks, 2)
res = self.api.projects.get_task_parents(projects=[project])
self._assert_ids(res.parents, [task1])
res = self.api.models.get_frameworks(projects=[project])
self.assertEqual(res.frameworks, [framework])
res = self.api.tasks.get_types(projects=[project])
self.assertEqual(res.types, ["testing"])
res = self.api.organization.get_entities_count(
projects={"id": [project]}, active_users=[user]
)
self.assertEqual(res.projects, 1)
def _assert_ids(self, actual: Sequence[dict], expected: Sequence[str]):
self.assertEqual([a["id"] for a in actual], expected)
@@ -70,8 +107,11 @@ class TestSubProjects(TestService):
# update
with self.api.raises(errors.bad_request.CannotUpdateProjectLocation):
self.api.projects.update(project=project1, name="Root2/Pr2")
res = self.api.projects.update(project=project1, name="Root1/Pr2")
new_basename = "Pr2"
res = self.api.projects.update(project=project1, name=f"Root1/{new_basename}")
self.assertEqual(res.updated, 1)
res = self.api.projects.get_by_id(project=project1)
self.assertEqual(res.project.basename, new_basename)
res = self.api.projects.get_by_id(project=project1_child)
self.assertEqual(res.project.name, "Root1/Pr2/Pr2")
@@ -80,6 +120,7 @@ class TestSubProjects(TestService):
self.assertEqual(res.moved, 2)
res = self.api.projects.get_by_id(project=project1_child)
self.assertEqual(res.project.name, "Root2/Pr2/Pr2")
self.assertEqual(res.project.basename, "Pr2")
# merge
project_with_task, (active, archived) = self._temp_project_with_tasks(
@@ -102,6 +143,7 @@ class TestSubProjects(TestService):
self.assertEqual(res.moved_projects, 1)
res = self.api.projects.get_by_id(project=project_with_task)
self.assertEqual(res.project.name, "Root2/Pr2/Pr4")
self.assertEqual(res.project.basename, "Pr4")
with self.api.raises(errors.bad_request.InvalidProjectId):
self.api.projects.get_by_id(project=merge_source)
@@ -156,6 +198,11 @@ class TestSubProjects(TestService):
self.assertEqual([p.id for p in res], [project1])
res = self.api.projects.get_all_ex(name="project1", parent=[project1]).projects
self.assertEqual([p.id for p in res], [project2])
# basename search
res = self.api.projects.get_all_ex(
basename="project2", shallow_search=True
).projects
self.assertEqual(res, [])
# global search finds all or below the specified level
res = self.api.projects.get_all_ex(name="project1").projects
@@ -163,7 +210,9 @@ class TestSubProjects(TestService):
project4 = self._temp_project(name="project1/project2/project1")
res = self.api.projects.get_all_ex(name="project1", parent=[project2]).projects
self.assertEqual([p.id for p in res], [project4])
# basename search
res = self.api.projects.get_all_ex(basename="project2").projects
self.assertEqual([p.id for p in res], [project2])
self.api.projects.delete(project=project1, force=True)
def test_get_all_with_check_own_contents(self):
@@ -249,13 +298,14 @@ class TestSubProjects(TestService):
**kwargs,
)
def _temp_task(self, **kwargs):
def _temp_task(self, client=None, **kwargs):
return self.create_temp(
"tasks",
delete_params=self.delete_params,
type="testing",
name=db_id(),
input=dict(view=dict()),
client=client,
**kwargs,
)

View File

@@ -6,17 +6,24 @@ from typing import Sequence, Optional, Tuple
from boltons.iterutils import first
from apiserver.apierrors import errors
from apiserver.es_factory import es_factory
from apiserver.apierrors.errors.bad_request import EventsNotAdded
from apiserver.tests.automated import TestService
class TestTaskEvents(TestService):
delete_params = dict(can_fail=True, force=True)
def _temp_task(self, name="test task events"):
task_input = dict(
name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
)
return self.create_temp("tasks", **task_input)
return self.create_temp("tasks", delete_paramse=self.delete_params, **task_input)
def _temp_model(self, name="test model events", **kwargs):
self.update_missing(kwargs, name=name, uri="file:///a/b", labels={})
return self.create_temp("models", delete_params=self.delete_params, **kwargs)
@staticmethod
def _create_task_event(type_, task, iteration, **kwargs):
@@ -67,6 +74,86 @@ class TestTaskEvents(TestService):
),
)
def test_task_single_value_metrics(self):
metric = "Metric1"
variant = "Variant1"
iter_count = 10
task = self._temp_task()
special_iteration = -(2 ** 31)
events = [
{
**self._create_task_event(
"training_stats_scalar", task, iteration or special_iteration
),
"metric": metric,
"variant": variant,
"value": iteration,
}
for iteration in range(iter_count)
]
self.send_batch(events)
# special iteration is present in the events retrieval
metric_param = {"metric": metric, "variants": [variant]}
res = self.api.events.scalar_metrics_iter_raw(
task=task, batch_size=100, metric=metric_param, count_total=True
)
self.assertEqual(res.returned, iter_count)
self.assertEqual(res.total, iter_count)
self.assertEqual(
res.variants[variant]["iter"],
[x or special_iteration for x in range(iter_count)],
)
self.assertEqual(
res.variants[variant]["y"], list(range(iter_count))
)
# but not in the histogram
data = self.api.events.scalar_metrics_iter_histogram(task=task)
self.assertEqual(data[metric][variant]["x"], list(range(1, iter_count)))
# new api
res = self.api.events.get_task_single_value_metrics(tasks=[task]).tasks
self.assertEqual(len(res), 1)
data = res[0]
self.assertEqual(data.task, task)
self.assertEqual(len(data["values"]), 1)
value = data["values"][0]
self.assertEqual(value.metric, metric)
self.assertEqual(value.variant, variant)
self.assertEqual(value.value, 0)
# update is working
task_data = self.api.tasks.get_by_id(task=task).task
last_metrics = first(first(task_data.last_metrics.values()).values())
self.assertEqual(last_metrics.value, iter_count - 1)
new_value = 1000
new_event = {
**self._create_task_event("training_stats_scalar", task, special_iteration),
"metric": metric,
"variant": variant,
"value": new_value,
}
self.send(new_event)
res = self.api.events.scalar_metrics_iter_raw(
task=task, batch_size=100, metric=metric_param, count_total=True
)
self.assertEqual(
res.variants[variant]["y"],
[y or new_value for y in range(iter_count)],
)
task_data = self.api.tasks.get_by_id(task=task).task
last_metrics = first(first(task_data.last_metrics.values()).values())
self.assertEqual(last_metrics.value, new_value)
data = self.api.events.get_task_single_value_metrics(tasks=[task]).tasks[0]
self.assertEqual(data.task, task)
self.assertEqual(len(data["values"]), 1)
value = data["values"][0]
self.assertEqual(value.value, new_value)
def test_last_scalar_metrics(self):
metric = "Metric1"
variant = "Variant1"
@@ -92,6 +179,42 @@ class TestTaskEvents(TestService):
self.assertEqual(iter_count - 1, metric_data.max_value)
self.assertEqual(0, metric_data.min_value)
def test_model_events(self):
model = self._temp_model(ready=False)
# task log events are not allowed
log_event = self._create_task_event(
"log",
task=model,
iteration=0,
msg=f"This is a log message",
model_event=True,
)
with self.api.raises(errors.bad_request.EventsNotAdded):
self.send(log_event)
# send metric events and check that model data always have iteration 0 and only last data is saved
events = [
{
**self._create_task_event("training_stats_scalar", model, iteration),
"metric": f"Metric{metric_idx}",
"variant": f"Variant{variant_idx}",
"value": iteration,
"model_event": True,
}
for iteration in range(2)
for metric_idx in range(5)
for variant_idx in range(5)
]
self.send_batch(events)
data = self.api.events.scalar_metrics_iter_histogram(task=model, model_events=True)
self.assertEqual(list(data), [f"Metric{idx}" for idx in range(5)])
metric_data = data.Metric0
self.assertEqual(list(metric_data), [f"Variant{idx}" for idx in range(5)])
variant_data = metric_data.Variant0
self.assertEqual(variant_data.x, [0])
self.assertEqual(variant_data.y, [1.0])
def test_error_events(self):
task = self._temp_task()
events = [
@@ -475,7 +598,8 @@ class TestTaskEvents(TestService):
return data
def send(self, event):
self.api.send("events.add", event)
_, data = self.api.send("events.add", event)
return data
if __name__ == "__main__":

View File

@@ -0,0 +1,337 @@
from functools import partial
from typing import Sequence, Mapping, Optional
from apiserver.es_factory import es_factory
from apiserver.tests.automated import TestService
class TestTaskPlots(TestService):
def _temp_task(self, name="test task events"):
task_input = dict(
name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
)
return self.create_temp("tasks", **task_input)
@staticmethod
def _create_task_event(task, iteration, **kwargs):
return {
"worker": "test",
"type": "plot",
"task": task,
"iter": iteration,
"timestamp": kwargs.get("timestamp") or es_factory.get_timestamp_millis(),
**kwargs,
}
def test_get_plot_sample(self):
task = self._temp_task()
metric = "Metric1"
variants = ["Variant1", "Variant2"]
# test empty
res = self.api.events.get_plot_sample(task=task, metric=metric)
self.assertEqual(res.min_iteration, None)
self.assertEqual(res.max_iteration, None)
self.assertEqual(res.events, [])
# test existing events
iterations = 5
events = [
self._create_task_event(
task=task,
iteration=n // len(variants),
metric=metric,
variant=variants[n % len(variants)],
plot_str=f"Test plot str {n}",
)
for n in range(iterations * len(variants))
]
self.send_batch(events)
# if iteration is not specified then return the event from the last one
res = self.api.events.get_plot_sample(task=task, metric=metric)
self._assertEqualEvents(res.events, events[-len(variants) :])
self.assertEqual(res.max_iteration, iterations - 1)
self.assertEqual(res.min_iteration, 0)
self.assertTrue(res.scroll_id)
# else from the specific iteration
iteration = 3
res = self.api.events.get_plot_sample(
task=task, metric=metric, iteration=iteration, scroll_id=res.scroll_id,
)
self._assertEqualEvents(
res.events,
events[iteration * len(variants) : (iteration + 1) * len(variants)],
)
def test_next_plot_sample(self):
task = self._temp_task()
metric1 = "Metric1"
metric2 = "Metric2"
metrics = [
(metric1, "variant1"),
(metric1, "variant2"),
(metric2, "variant3"),
(metric2, "variant4"),
]
# test existing events
events = [
self._create_task_event(
task=task,
iteration=n,
metric=metric,
variant=variant,
plot_str=f"Test plot str {n}",
)
for n in range(2)
for metric, variant in metrics
]
self.send_batch(events)
# single metric navigation
# init scroll
res = self.api.events.get_plot_sample(task=task, metric=metric1)
self._assertEqualEvents(res.events, events[-4:-2])
# navigate forwards
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, navigate_earlier=False
)
self.assertEqual(res.events, [])
# navigate backwards
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
self._assertEqualEvents(res.events, events[-8:-6])
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
self._assertEqualEvents(res.events, [])
# all metrics navigation
# init scroll
res = self.api.events.get_plot_sample(
task=task, metric=metric1, navigate_current_metric=False
)
self._assertEqualEvents(res.events, events[-4:-2])
# navigate forwards
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, navigate_earlier=False
)
self._assertEqualEvents(res.events, events[-2:])
# navigate backwards
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
self._assertEqualEvents(res.events, events[-4:-2])
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
self._assertEqualEvents(res.events, events[-6:-4])
# next_iteration
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, next_iteration=True
)
self._assertEqualEvents(res.events, [])
res = self.api.events.next_plot_sample(
task=task,
scroll_id=res.scroll_id,
next_iteration=True,
navigate_earlier=False,
)
self._assertEqualEvents(res.events, events[-4:-2])
self.assertTrue(all(ev.iter == 1 for ev in res.events))
res = self.api.events.next_plot_sample(
task=task,
scroll_id=res.scroll_id,
next_iteration=True,
navigate_earlier=False,
)
self._assertEqualEvents(res.events, [])
def _assertEqualEvents(
self, ev_source: Sequence[dict], ev_target: Sequence[Optional[dict]]
):
self.assertEqual(len(ev_source), len(ev_target))
def compare_event(ev1, ev2):
for field in ("iter", "timestamp", "metric", "variant", "plot_str", "task"):
self.assertEqual(ev1[field], ev2[field])
for e1, e2 in zip(ev_source, ev_target):
compare_event(e1, e2)
def test_task_plots(self):
task = self._temp_task()
# test empty
res = self.api.events.plots(metrics=[{"task": task}], iters=5)
self.assertFalse(res.metrics[0].iterations)
res = self.api.events.plots(
metrics=[{"task": task}], iters=5, scroll_id=res.scroll_id, refresh=True
)
self.assertFalse(res.metrics[0].iterations)
# test not empty
metrics = {
"Metric1": ["Variant1", "Variant2"],
"Metric2": ["Variant3", "Variant4"],
}
events = [
self._create_task_event(
task=task,
iteration=1,
metric=metric,
variant=variant,
plot_str=f"Test plot str {metric}_{variant}",
)
for metric, variants in metrics.items()
for variant in variants
]
self.send_batch(events)
scroll_id = self._assertTaskMetrics(
task=task, expected_metrics=metrics, iterations=1
)
# test refresh
update = {
"Metric2": ["Variant3", "Variant4", "Variant5"],
"Metric3": ["VariantA", "VariantB"],
}
events = [
self._create_task_event(
task=task,
iteration=2,
metric=metric,
variant=variant,
plot_str=f"Test plot str {metric}_{variant}_2",
)
for metric, variants in update.items()
for variant in variants
]
self.send_batch(events)
# without refresh the metric states are not updated
scroll_id = self._assertTaskMetrics(
task=task, expected_metrics=metrics, iterations=0, scroll_id=scroll_id
)
# with refresh there are new metrics and existing ones are updated
self._assertTaskMetrics(
task=task,
expected_metrics=update,
iterations=1,
scroll_id=scroll_id,
refresh=True,
)
def _assertTaskMetrics(
self,
task: str,
expected_metrics: Mapping[str, Sequence[str]],
iterations,
scroll_id: str = None,
refresh=False,
) -> str:
res = self.api.events.plots(
metrics=[{"task": task}], iters=1, scroll_id=scroll_id, refresh=refresh
)
if not iterations:
self.assertTrue(all(m.iterations == [] for m in res.metrics))
return res.scroll_id
expected_variants = set(
(m, var) for m, vars_ in expected_metrics.items() for var in vars_
)
for metric_data in res.metrics:
self.assertEqual(len(metric_data.iterations), iterations)
for it_data in metric_data.iterations:
self.assertEqual(
set((e.metric, e.variant) for e in it_data.events),
expected_variants,
)
return res.scroll_id
def test_plots_navigation(self):
task = self._temp_task()
metric = "Metric1"
variants = ["Variant1", "Variant2"]
iterations = 10
# test empty
res = self.api.events.plots(
metrics=[{"task": task, "metric": metric}], iters=5,
)
self.assertFalse(res.metrics[0].iterations)
# create events
events = [
self._create_task_event(
task=task,
iteration=n,
metric=metric,
variant=variant,
plot_str=f"{metric}_{variant}_{n}",
)
for n in range(iterations)
for variant in variants
]
self.send_batch(events)
# init testing
scroll_id = None
assert_plots = partial(
self._assertPlots,
task=task,
metric=metric,
iterations=iterations,
variants=len(variants),
)
# test forward navigation
for page in range(3):
scroll_id = assert_plots(scroll_id=scroll_id, expected_page=page)
# test backwards navigation
scroll_id = assert_plots(
scroll_id=scroll_id, expected_page=0, navigate_earlier=False
)
# beyond the latest iteration and back
res = self.api.events.debug_images(
metrics=[{"task": task, "metric": metric}],
iters=5,
scroll_id=scroll_id,
navigate_earlier=False,
)
self.assertEqual(len(res["metrics"][0]["iterations"]), 0)
assert_plots(scroll_id=scroll_id, expected_page=1)
# refresh
assert_plots(scroll_id=scroll_id, expected_page=0, refresh=True)
def _assertPlots(
self,
task,
metric,
iterations: int,
variants: int,
scroll_id,
expected_page: int,
iters: int = 5,
**extra_params,
) -> str:
res = self.api.events.plots(
metrics=[{"task": task, "metric": metric}],
iters=iters,
scroll_id=scroll_id,
**extra_params,
)
data = res["metrics"][0]
self.assertEqual(data["task"], task)
left_iterations = max(0, iterations - expected_page * iters)
self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
for it in data["iterations"]:
self.assertEqual(len(it["events"]), variants)
return res.scroll_id
def send_batch(self, events):
_, data = self.api.send_batch("events.add_batch", events)
return data

View File

@@ -50,10 +50,11 @@ class TestTasksResetDelete(TestService):
self.assertEqual(res.urls.artifact_urls, [])
task = self.new_task()
published_model_urls, draft_model_urls = self.create_task_models(task)
(_, published_model_urls), (model, draft_model_urls) = self.create_task_models(task)
artifact_urls = self.send_artifacts(task)
event_urls = self.send_debug_image_events(task)
event_urls.update(self.send_plot_events(task))
event_urls.update(self.send_model_events(model))
res = self.assert_delete_task(task, force=True, return_file_urls=True)
self.assertEqual(set(res.urls.model_urls), draft_model_urls)
self.assertEqual(set(res.urls.event_urls), event_urls)
@@ -120,10 +121,11 @@ class TestTasksResetDelete(TestService):
self, **kwargs
) -> Tuple[str, Tuple[Set[str], Set[str]], Set[str], Set[str]]:
task = self.new_task(**kwargs)
published_model_urls, draft_model_urls = self.create_task_models(task, **kwargs)
(_, published_model_urls), (model, draft_model_urls) = self.create_task_models(task, **kwargs)
artifact_urls = self.send_artifacts(task)
event_urls = self.send_debug_image_events(task)
event_urls.update(self.send_plot_events(task))
event_urls.update(self.send_model_events(model))
return task, (published_model_urls, draft_model_urls), artifact_urls, event_urls
def assert_delete_task(self, task_id, force=False, return_file_urls=False):
@@ -137,15 +139,17 @@ class TestTasksResetDelete(TestService):
self.assertEqual(tasks, [])
return res
def create_task_models(self, task, **kwargs) -> Tuple[Set[str], Set[str]]:
def create_task_models(self, task, **kwargs) -> Tuple:
"""
Update models from task and return only non public models
"""
model_ready = self.new_model(uri="ready", **kwargs)
model_not_ready = self.new_model(uri="not_ready", ready=False, **kwargs)
ready_uri = "ready"
not_ready_uri = "not_ready"
model_ready = self.new_model(uri=ready_uri, **kwargs)
model_not_ready = self.new_model(uri=not_ready_uri, ready=False, **kwargs)
self.api.models.edit(model=model_not_ready, task=task)
self.api.models.edit(model=model_ready, task=task)
return {"ready"}, {"not_ready"}
return (model_ready, {ready_uri}), (model_not_ready, {not_ready_uri})
def send_artifacts(self, task) -> Set[str]:
"""
@@ -159,6 +163,20 @@ class TestTasksResetDelete(TestService):
self.api.tasks.add_or_update_artifacts(task=task, artifacts=artifacts)
return {"test2"}
def send_model_events(self, model) -> Set[str]:
url1 = "http://link1"
url2 = "http://link2"
events = [
self.create_event(
model, "training_debug_image", 0, url=url1, model_event=True
),
self.create_event(
model, "plot", 0, plot_str=f'{{"source": "{url2}"}}', model_event=True
)
]
self.send_batch(events)
return {url1, url2}
def send_debug_image_events(self, task) -> Set[str]:
events = [
self.create_event(

View File

@@ -42,6 +42,32 @@ class TestTasksFiltering(TestService):
self.assertEqual(res.total, 0)
self.assertEqual(res["values"], [])
def test_datetime_queries(self):
tasks = [self.temp_task() for _ in range(5)]
now = datetime.utcnow()
for task in tasks:
self.api.tasks.ping(task=task)
# date time syntax
res = self.api.tasks.get_all_ex(last_update=f">={now.isoformat()}").tasks
self.assertTrue(set(tasks).issubset({t.id for t in res}))
res = self.api.tasks.get_all_ex(
last_update=[
f">={(now - timedelta(seconds=60)).isoformat()}",
f"<={now.isoformat()}",
]
).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res}))
# simplified range syntax
res = self.api.tasks.get_all_ex(last_update=[now.isoformat(), None]).tasks
self.assertTrue(set(tasks).issubset({t.id for t in res}))
res = self.api.tasks.get_all_ex(
last_update=[(now - timedelta(seconds=60)).isoformat(), now.isoformat()]
).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res}))
def test_range_queries(self):
tasks = [self.temp_task() for _ in range(5)]
now = datetime.utcnow()

Some files were not shown because too many files have changed in this diff Show More