Compare commits

89 Commits

Author SHA1 Message Date
allegroai
c8e4d9eeac Fix Dockerfile uses deprecated base image 2023-04-18 10:50:13 +03:00
dependabot[bot]
b51aa5c29b Bump redis from 3.5.3 to 4.4.4 in /apiserver (#190)
Bumps [redis](https://github.com/redis/redis-py) from 3.5.3 to 4.4.4.
- [Release notes](https://github.com/redis/redis-py/releases)
- [Changelog](https://github.com/redis/redis-py/blob/master/CHANGES)
- [Commits](https://github.com/redis/redis-py/compare/3.5.3...v4.4.4)

---
updated-dependencies:
- dependency-name: redis
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-01 08:59:02 +03:00
allegroai
e7c9daa42b Fix get_task_events to correctly use last_iters for model events 2023-03-28 16:45:44 +03:00
allegroai
7357654249 Version bump to v1.10 2023-03-23 19:17:00 +02:00
allegroai
a6f671b46a Fix typo 2023-03-23 19:16:38 +02:00
allegroai
17a8b440bd Fix only last event of each type is stored per model (all should be stored) 2023-03-23 19:16:30 +02:00
allegroai
eb2b9cbd9a Fix project count for datasets and pipelines 2023-03-23 19:15:42 +02:00
allegroai
797e503e67 Update ES version 2023-03-23 19:14:33 +02:00
allegroai
30cfdac8f2 Fix project preview completed_tasks_24h should not count tasks that are marked as failed or running 2023-03-23 19:13:52 +02:00
allegroai
24bb87aaee Turn on mongo sorting using disk usage by default for sorting in *.get_all* apis 2023-03-23 19:12:52 +02:00
allegroai
dd49ba180a Improve statistics on projects children 2023-03-23 19:11:45 +02:00
allegroai
bda903d0d8 Set API version to 2.24 2023-03-23 19:11:13 +02:00
allegroai
9739eb2d5a Add report_assets field to report tasks 2023-03-23 19:09:03 +02:00
allegroai
cfbb37238f Add default workers timeout to the server's configuration 2023-03-23 19:08:11 +02:00
allegroai
6664c6237e Support querying by children_type in projects.get_all_ex 2023-03-23 19:07:42 +02:00
allegroai
74200a24bd Add filtering on child projects in projects.get_all_ex 2023-03-23 19:06:49 +02:00
john-zielke-snkeos
2fb9288a6c Add env switch to disable nginx ipv6 bind (#165) 2023-03-13 16:05:43 +02:00
shyallegro
5d014d81af Fix #184 and update docker build to include widgets (#185) 2023-03-07 11:26:12 +02:00
allegroai
3a2675abe1 Version bump to v1.9.2 2023-01-24 16:11:21 +02:00
allegroai
f0d68b1ce9 Make sure model label values are integer 2023-01-24 16:11:12 +02:00
allegroai
15db9cdaef Allow updating comments on published reports 2023-01-24 14:40:32 +02:00
Mal Miller
a45d47f5d7 Fix default value of CLEARML_AGENT_UPDATE_VERSION for agent-services (#114) 2023-01-03 13:45:52 +02:00
allegroai
b1a50c1370 Version bump to v1.9.1 2023-01-03 12:16:07 +02:00
allegroai
22a2a02760 Allow renaming published reports 2023-01-03 12:15:44 +02:00
allegroai
ab798e4170 Allow updating tags on published reports 2023-01-03 12:15:02 +02:00
allegroai
f09ac672d2 Add pipeline test 2023-01-03 12:14:12 +02:00
allegroai
2149b76f63 Fix crash when starting pipeline 2023-01-03 12:13:48 +02:00
allegroai
d96420aa67 Version bump to v1.9 2022-12-21 18:47:03 +02:00
allegroai
ed6c7b7bcb Fix Project time is not updated when moved or merged 2022-12-21 18:46:53 +02:00
allegroai
a392bc0bd7 Bump API version to 2.23 2022-12-21 18:46:12 +02:00
allegroai
7e97ec5555 Fix events.get_task_plots endpoint 2022-12-21 18:45:17 +02:00
allegroai
9c41124b81 Add support for moving objects to projects root 2022-12-21 18:43:45 +02:00
allegroai
14ff639bb0 Removed limit on event comparison for the same company tasks only 2022-12-21 18:42:40 +02:00
allegroai
e66257761a Add support for server-side delete for AWS S3, Google Storage and Azure Blob Storage 2022-12-21 18:41:16 +02:00
allegroai
0ffde24dc2 Add min and max value iteration to last metrics 2022-12-21 18:36:50 +02:00
allegroai
d4fdcd9b32 Upgrade mongoengine version 2022-12-21 18:35:23 +02:00
allegroai
18570bfccb Add project_id response field to reports.create endpoint 2022-12-21 18:35:14 +02:00
allegroai
54ce6c34c6 Fix bad field values might cause ugly server exception to be returned 2022-12-21 18:33:28 +02:00
allegroai
ae4c33fa0e Add support for allow_public flag in get_all_ex endpoint
Add `last_changed_by` field on task updates
Fix reports support
2022-12-21 18:32:56 +02:00
allegroai
c7cd949fd0 Add reports support
Fix schema
2022-12-21 18:30:54 +02:00
allegroai
1ce4058157 Change tasks comparison limit to 100 2022-12-21 18:29:49 +02:00
allegroai
7b6f24b24d Version bump to 1.8.0 2022-11-29 17:50:32 +02:00
allegroai
d03a931d84 Remove buggy debug code 2022-11-29 17:50:17 +02:00
allegroai
5cc7199661 Fix clearing ES scroll 2022-11-29 17:44:31 +02:00
allegroai
6537e9ef69 Add active_users and search_hidden options to get_entities_count endpoint 2022-11-29 17:44:19 +02:00
allegroai
930aaff791 Fix test 2022-11-29 17:43:43 +02:00
allegroai
1999fb2479 Add URL async delete prefixes setting 2022-11-29 17:43:08 +02:00
allegroai
9db14cc31d Update endpoint API versions 2022-11-29 17:42:19 +02:00
allegroai
e3cc689528 Enhance async_urls_delete feature with max_async_deleted_events_per_sec setting and fileserver timeout and prefixes 2022-11-29 17:41:49 +02:00
allegroai
9e0adc77dd Make sure id field is included in ordering 2022-11-29 17:40:09 +02:00
allegroai
58d9a64537 Fix endpoint name in schema 2022-11-29 17:39:35 +02:00
allegroai
d397d2ae20 Support optional async events deletion when deleting tasks 2022-11-29 17:38:41 +02:00
allegroai
2d711e1500 Collect model event URLs during task and project cleanup 2022-11-29 17:38:03 +02:00
allegroai
97992b0d9e Support returning multiple plots during history navigation 2022-11-29 17:37:30 +02:00
allegroai
bc23f1b0cf Add "queue watched" indication for tasks.enqueue and tasks.enqueue_many 2022-11-29 17:36:41 +02:00
allegroai
6b3eff1426 Support workers system tags 2022-11-29 17:35:25 +02:00
allegroai
caaf801cd0 Support plots navigation by iteration 2022-11-29 17:34:57 +02:00
allegroai
c23e8a90d0 Support model events 2022-11-29 17:34:06 +02:00
allegroai
fa5b28ca0e Support HTTP URLs when deleting fileserver files 2022-11-29 17:33:18 +02:00
allegroai
bfb55a9463 Support deleting external artifacts when deleting projects 2022-11-29 17:32:41 +02:00
allegroai
37e485e1f2 Upgrade API version to 2.22 2022-11-29 17:29:57 +02:00
allegroai
3451ff441f Fix projects.get_all_ex with active_users filtering did not work if the project id was passed as a string and not list 2022-11-29 17:27:54 +02:00
allegroai
53c9b5525e Add preliminary support for datasets under projects 2022-11-29 17:27:02 +02:00
PSL
e5230edac3 Add mongo username and password authentication (#162)
Co-authored-by: pangshaoliang <pangshaoliang@megvii.com>
2022-10-08 15:56:02 +03:00
allegroai
a54dd8030c Add async-delete to docker-compose 2022-09-29 19:44:05 +03:00
allegroai
482a5c34bc Version bump to 1.7.0 2022-09-29 19:40:20 +03:00
allegroai
ee2a72c70f Add company/uri index 2022-09-29 19:40:05 +03:00
allegroai
a0d8aaf3b9 Fix urls are not unquoted in batch_delete 2022-09-29 19:39:02 +03:00
allegroai
de1f823213 Removed stub timing context 2022-09-29 19:37:15 +03:00
allegroai
0c9e2f92ee Add server-side support for deleting files from fileserver on task delete 2022-09-29 19:34:24 +03:00
allegroai
6c49e96ff0 Update API version to 2.21 2022-09-29 19:31:42 +03:00
allegroai
81e3fc6577 Improve utilities 2022-09-29 19:30:57 +03:00
allegroai
e6dc4b7557 Set cloned task parent to original task if original task has no parent 2022-09-29 19:30:13 +03:00
allegroai
238a47a197 Add created field to backend.user 2022-09-29 19:29:36 +03:00
allegroai
04e7076628 Add support for specifying a specific task ID in queues.get_next_task 2022-09-29 19:27:42 +03:00
allegroai
0531612bf4 Fix deleting a queue should dequeue all enqueued tasks 2022-09-29 19:27:09 +03:00
allegroai
3ae410a1e9 Remove the ThreadsManager.terminating flag 2022-09-29 19:23:26 +03:00
allegroai
98ed3075dd Added exclude support when converting mongo objects to dictionary 2022-09-29 19:21:28 +03:00
allegroai
b871bf4224 Fix projects.get_all_ex 2022-09-29 19:20:50 +03:00
allegroai
8d4c02fc3c Add support for hidden internal queues 2022-09-29 19:20:24 +03:00
allegroai
b986980c75 Use correct attrs version, remove related dependency 2022-09-29 19:18:38 +03:00
allegroai
a4fa567be2 Fix task stats update 2022-09-29 19:18:22 +03:00
allegroai
ddb91f226a Add Task Unique Metrics to task object 2022-09-29 19:16:56 +03:00
allegroai
7772f47773 Support datetime ranges in field queries 2022-09-29 19:15:50 +03:00
allegroai
9c118d14e0 Add missing last_update field to models.get_all APIs 2022-09-29 19:14:13 +03:00
allegroai
efd56e085e Fix threaded jobs management (invoke only from AppSequence) 2022-09-29 19:13:22 +03:00
allegroai
4dff163af4 Improve examples pre-population code 2022-09-29 19:11:21 +03:00
allegroai
242a78a0fe Add request logging using the CLEARML_SERVER_DEBUG_REQUESTS env var 2022-09-29 19:10:55 +03:00
allegroai
78989fea91 Add better mongodb connection string verbosity 2022-09-29 19:10:13 +03:00
118 changed files with 7368 additions and 2952 deletions

View File

@@ -50,6 +50,9 @@
130: ["task_not_found", "task not found"]
131: ["events_not_added", "events not added"]
# Reports
150: ["operation_supported_on_reports_only", "passed task is not report"]
# Models
200: ["model_error", "general task error"]
201: ["invalid_model_id", "invalid model id"]

View File

@@ -61,12 +61,6 @@ class ListField(fields.ListField):
item.validate()
# since there is no distinction between None and empty DictField
# this value can be used as sentinel in order to distinguish
# between not set and empty DictField
DictFieldNotSet = {}
class DictField(fields.BaseField):
types = (dict,)

View File

@@ -26,6 +26,7 @@ class MetricVariants(Base):
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
task: str = StringField(required=True)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
model_events: bool = BoolField(default=False)
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
@@ -35,11 +36,12 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
Length(
minimum_value=1,
maximum_value=config.get(
"services.tasks.multi_task_histogram_limit", 10
"services.tasks.multi_task_histogram_limit", 100
),
)
],
)
model_events: bool = BoolField(default=False)
class TaskMetric(Base):
@@ -56,25 +58,36 @@ class MetricEventsRequest(Base):
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
model_events: bool = BoolField()
class TaskMetricVariant(Base):
class GetVariantSampleRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
variant: str = StringField(required=True)
class GetHistorySampleRequest(TaskMetricVariant):
iteration: Optional[int] = IntField()
refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()
navigate_current_metric: bool = BoolField(default=True)
model_events: bool = BoolField(default=False)
class GetMetricSamplesRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
iteration: Optional[int] = IntField()
refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()
navigate_current_metric: bool = BoolField(default=True)
model_events: bool = BoolField(default=False)
class NextHistorySampleRequest(Base):
task: str = StringField(required=True)
scroll_id: Optional[str] = StringField()
navigate_earlier: bool = BoolField(default=True)
next_iteration: bool = BoolField(default=False)
model_events: bool = BoolField(default=False)
class LogOrderEnum(StringEnum):
@@ -93,6 +106,7 @@ class TaskEventsRequest(TaskEventsRequestBase):
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
scroll_id: str = StringField()
count_total: bool = BoolField(default=True)
model_events: bool = BoolField(default=False)
class LogEventsRequest(TaskEventsRequestBase):
@@ -108,6 +122,7 @@ class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
count_total: bool = BoolField(default=False)
scroll_id: str = StringField()
model_events: bool = BoolField(default=False)
class IterationEvents(Base):
@@ -129,6 +144,7 @@ class MultiTasksRequestBase(Base):
tasks: Sequence[str] = ListField(
items_types=str, validators=[Length(minimum_value=1)]
)
model_events: bool = BoolField(default=False)
class SingleValueMetricsRequest(MultiTasksRequestBase):
@@ -145,6 +161,7 @@ class TaskPlotsRequest(Base):
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
model_events: bool = BoolField(default=False)
class ClearScrollRequest(Base):

View File

@@ -79,3 +79,4 @@ class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
class ModelsGetRequest(models.Base):
include_stats = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)

View File

@@ -19,3 +19,7 @@ class EntitiesCountRequest(models.Base):
models = DictField()
pipelines = DictField()
datasets = DictField()
reports = DictField()
active_users = fields.ListField(str)
search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)

View File

@@ -1,3 +1,5 @@
from enum import Enum
from jsonmodels import models, fields
from apiserver.apimodels import ListField, ActualEnumField, DictField
@@ -56,14 +58,22 @@ class ProjectModelMetadataValuesRequest(MultiProjectRequest):
allow_public = fields.BoolField(default=True)
class ProjectChildrenType(Enum):
pipeline = "pipeline"
report = "report"
dataset = "dataset"
class ProjectsGetRequest(models.Base):
include_dataset_stats = fields.BoolField(default=False)
include_stats = fields.BoolField(default=False)
include_stats_filter = DictField()
stats_with_children = fields.BoolField(default=True)
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
non_public = fields.BoolField(default=False)
non_public = fields.BoolField(default=False) # legacy, use allow_public instead
active_users = fields.ListField(str)
check_own_contents = fields.BoolField(default=False)
shallow_search = fields.BoolField(default=False)
search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)
children_type = ActualEnumField(ProjectChildrenType)

View File

@@ -30,9 +30,15 @@ class GetByIdRequest(QueueRequest):
max_task_entries = IntField()
class GetAllRequest(Base):
max_task_entries = IntField()
search_hidden = BoolField(default=False)
class GetNextTaskRequest(QueueRequest):
queue = StringField(required=True)
get_task_info = BoolField(default=False)
task = StringField()
class DeleteRequest(QueueRequest):

View File

@@ -0,0 +1,72 @@
from typing import Sequence
from jsonmodels import validators
from jsonmodels.fields import StringField, ListField, BoolField, EmbeddedField, IntField
from jsonmodels.models import Base
from jsonmodels.validators import Length
from apiserver.apimodels.events import MetricVariants, HistogramRequestBase
class UpdateReportRequest(Base):
task = StringField(required=True)
name = StringField(nullable=True, validators=Length(minimum_value=3))
tags = ListField(items_types=[str])
comment = StringField()
report = StringField()
report_assets = ListField(items_types=[str])
class CreateReportRequest(Base):
name = StringField(required=True, validators=Length(minimum_value=3))
tags = ListField(items_types=[str])
comment = StringField()
report = StringField()
project = StringField()
report_assets = ListField(items_types=[str])
class PublishReportRequest(Base):
task = StringField(required=True)
message = StringField(default="")
class ArchiveReportRequest(Base):
task = StringField(required=True)
message = StringField(default="")
class ShareReportRequest(Base):
task = StringField(required=True)
share = BoolField(default=True)
class DeleteReportRequest(Base):
task = StringField(required=True)
force = BoolField(default=False)
class MoveReportRequest(Base):
task = StringField(required=True)
project = StringField()
project_name = StringField()
class EventsRequest(Base):
iters = IntField(default=1, validators=validators.Min(1))
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class ScalarMetricsIterHistogram(HistogramRequestBase):
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class GetTasksDataRequest(Base):
debug_images: EventsRequest = EmbeddedField(EventsRequest)
plots: EventsRequest = EmbeddedField(EventsRequest)
scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(ScalarMetricsIterHistogram)
allow_public = BoolField(default=True)
class GetAllRequest(Base):
allow_public = BoolField(default=True)

View File

@@ -42,6 +42,7 @@ class StartedResponse(UpdateResponse):
class EnqueueResponse(UpdateResponse):
queued = IntField()
queue_watched = BoolField()
class EnqueueBatchItem(UpdateBatchItem):
@@ -50,6 +51,7 @@ class EnqueueBatchItem(UpdateBatchItem):
class EnqueueManyResponse(BatchResponse):
succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem)
queue_watched = BoolField()
class DequeueResponse(UpdateResponse):
@@ -97,12 +99,14 @@ class UpdateRequest(TaskUpdateRequest):
class EnqueueRequest(UpdateRequest):
queue = StringField()
queue_name = StringField()
verify_watched_queue = BoolField(default=False)
class DeleteRequest(UpdateRequest):
move_to_trash = BoolField(default=True)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
delete_external_artifacts = BoolField(default=True)
class SetRequirementsRequest(TaskRequest):
@@ -180,6 +184,7 @@ class ResetRequest(UpdateRequest):
clear_all = BoolField(default=False)
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
delete_external_artifacts = BoolField(default=True)
class MultiTaskRequest(models.Base):
@@ -273,6 +278,7 @@ class EnqueueManyRequest(TaskBatchRequest):
queue = StringField()
queue_name = StringField()
validate_tasks = BoolField(default=False)
verify_watched_queue = BoolField(default=False)
class DeleteManyRequest(TaskBatchRequest):
@@ -280,6 +286,7 @@ class DeleteManyRequest(TaskBatchRequest):
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
force = BoolField(default=False)
delete_external_artifacts = BoolField(default=True)
class ResetManyRequest(TaskBatchRequest):
@@ -287,6 +294,7 @@ class ResetManyRequest(TaskBatchRequest):
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
force = BoolField(default=False)
delete_external_artifacts = BoolField(default=True)
class PublishManyRequest(TaskBatchRequest):
@@ -310,3 +318,8 @@ class DeleteModelsRequest(TaskRequest):
models: Sequence[ModelItemKey] = ListField(
[ModelItemKey], validators=Length(minimum_value=1)
)
class GetAllReq(models.Base):
allow_public = BoolField(default=True)
search_hidden = BoolField(default=False)

View File

@@ -20,12 +20,11 @@ DEFAULT_TIMEOUT = 10 * 60
class WorkerRequest(Base):
worker = StringField(required=True)
tags = ListField(str)
system_tags = ListField(str)
class RegisterRequest(WorkerRequest):
timeout = make_default(
IntField, DEFAULT_TIMEOUT
)() # registration timeout in seconds (default is 10min)
timeout = IntField(default=0) # registration timeout in seconds (if not specified, default is 10min)
queues = ListField(six.string_types) # list of queues this worker listens to
@@ -76,6 +75,7 @@ class WorkerEntry(Base, JsonSerializableMixin):
last_activity_time = DateTimeField(required=True)
last_report_time = DateTimeField()
tags = ListField(str)
system_tags = ListField(str)
class CurrentTaskEntry(IdNameEntry):
@@ -97,6 +97,7 @@ class WorkerResponseEntry(WorkerEntry):
class GetAllRequest(Base):
last_seen = IntField(default=3600)
tags = ListField(str)
system_tags = ListField(str)
class GetAllResponse(Base):

View File

@@ -9,6 +9,7 @@ from operator import attrgetter
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
import elasticsearch
from boltons.iterutils import chunked_iter
from elasticsearch.helpers import BulkIndexError
from mongoengine import Q
from nested_dict import nested_dict
@@ -26,11 +27,12 @@ from apiserver.bll.event.event_common import (
)
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIterator
from apiserver.bll.event.history_plot_iterator import HistoryPlotIterator
from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.database.model.model import Model
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.bll.event.event_metrics import EventMetrics
@@ -39,9 +41,8 @@ from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.redis_manager import redman
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
from apiserver.utilities.dicts import flatten_nested_items, nested_get
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.json import loads
# noinspection PyTypeChecker
@@ -63,12 +64,21 @@ class PlotFields:
class EventBLL(object):
id_fields = ("task", "iter", "metric", "variant", "key")
event_id_fields = ("task", "iter", "metric", "variant", "key")
empty_scroll = "FFFF"
img_source_regex = re.compile(
r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]",
flags=re.IGNORECASE,
)
_task_event_query = {
"bool": {
"should": [
{"term": {"model_event": False}},
{"bool": {"must_not": [{"exists": {"field": "model_event"}}]}},
]
}
}
_model_event_query = {"term": {"model_event": True}}
def __init__(self, events_es=None, redis=None):
self.es = events_es or es_factory.connect("events")
@@ -84,7 +94,7 @@ class EventBLL(object):
es=self.es, redis=self.redis
)
self.plots_iterator = MetricPlotsIterator(es=self.es, redis=self.redis)
self.plot_sample_history = HistoryPlotIterator(es=self.es, redis=self.redis)
self.plot_sample_history = HistoryPlotsIterator(es=self.es, redis=self.redis)
self.events_iterator = EventsIterator(es=self.es)
@property
@@ -97,18 +107,42 @@ class EventBLL(object):
if not task_ids:
return set()
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
with translate_errors_context():
query = Q(id__in=task_ids, company=company_id)
if not allow_locked_tasks:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
res = Task.objects(query).only("id")
return {r.id for r in res}
@staticmethod
def _get_valid_models(company_id, model_ids: Set, allow_locked_models=False) -> Set:
"""Verify that the models exist and can be updated"""
if not model_ids:
return set()
with translate_errors_context():
query = Q(id__in=model_ids, company=company_id)
if not allow_locked_models:
query &= Q(ready__ne=True)
res = Model.objects(query).only("id")
return {r.id for r in res}
def add_events(
self, company_id, events, worker, allow_locked_tasks=False
self, company_id, events, worker, allow_locked=False
) -> Tuple[int, int, dict]:
model_events = events[0].get("model_event", False)
for event in events:
if event.get("model_event", model_events) != model_events:
raise errors.bad_request.ValidationError(
"Inconsistent model_event setting in the passed events"
)
if event.pop("allow_locked", allow_locked) != allow_locked:
raise errors.bad_request.ValidationError(
"Inconsistent allow_locked setting in the passed events"
)
actions: List[dict] = []
task_ids = set()
task_or_model_ids = set()
task_iteration = defaultdict(lambda: 0)
task_last_scalar_events = nested_dict(
3, dict
@@ -118,13 +152,28 @@ class EventBLL(object):
) # task_id -> metric_hash -> event_type -> MetricEvent
errors_per_type = defaultdict(int)
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
valid_tasks = self._get_valid_tasks(
company_id,
task_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_tasks=allow_locked_tasks,
)
if model_events:
for event in events:
model = event.pop("model", None)
if model is not None:
event["task"] = model
valid_entities = self._get_valid_models(
company_id,
model_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_models=allow_locked,
)
entity_name = "model"
else:
valid_entities = self._get_valid_tasks(
company_id,
task_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_tasks=allow_locked,
)
entity_name = "task"
for event in events:
# remove spaces from event type
@@ -138,13 +187,17 @@ class EventBLL(object):
errors_per_type[f"Invalid event type {event_type}"] += 1
continue
task_id = event.get("task")
if task_id is None:
if model_events and event_type == EventType.task_log.value:
errors_per_type[f"Task log events are not supported for models"] += 1
continue
task_or_model_id = event.get("task")
if task_or_model_id is None:
errors_per_type["Event must have a 'task' field"] += 1
continue
if task_id not in valid_tasks:
errors_per_type["Invalid task id"] += 1
if task_or_model_id not in valid_entities:
errors_per_type[f"Invalid {entity_name} id {task_or_model_id}"] += 1
continue
event["type"] = event_type
@@ -179,6 +232,7 @@ class EventBLL(object):
event["metric"] = event.get("metric") or ""
event["variant"] = event.get("variant") or ""
event["model_event"] = model_events
index_name = get_index_name(company_id, event_type)
es_action = {
@@ -193,21 +247,26 @@ class EventBLL(object):
else:
es_action["_id"] = dbutils.id()
task_ids.add(task_id)
task_or_model_ids.add(task_or_model_id)
if (
iter is not None
and not model_events
and event.get("metric") not in self._skip_iteration_for_metric
):
task_iteration[task_id] = max(iter, task_iteration[task_id])
self._update_last_metric_events_for_task(
last_events=task_last_events[task_id], event=event,
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_id], event=event
task_iteration[task_or_model_id] = max(
iter, task_iteration[task_or_model_id]
)
if not model_events:
self._update_last_metric_events_for_task(
last_events=task_last_events[task_or_model_id], event=event,
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_or_model_id],
event=event,
)
actions.append(es_action)
plot_actions = [
@@ -228,39 +287,41 @@ class EventBLL(object):
with translate_errors_context():
if actions:
chunk_size = 500
with TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
elasticsearch.helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += 1
else:
errors_per_type["Error when indexing events batch"] += 1
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
elasticsearch.helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += 1
else:
errors_per_type["Error when indexing events batch"] += 1
if not model_events:
remaining_tasks = set()
now = datetime.utcnow()
for task_id in task_ids:
for task_or_model_id in task_or_model_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those whose events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
task_id=task_or_model_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
iter_max=task_iteration.get(task_or_model_id),
last_scalar_events=task_last_scalar_events.get(
task_or_model_id
),
last_events=task_last_events.get(task_or_model_id),
)
if not updated:
remaining_tasks.add(task_id)
remaining_tasks.add(task_or_model_id)
continue
if remaining_tasks:
@@ -361,8 +422,22 @@ class EventBLL(object):
for k in ("value", "metric", "variant", "iter", "timestamp")
if k in event
}
event_data["min_value"] = min(value, last_event.get("min_value", value))
event_data["max_value"] = max(value, last_event.get("max_value", value))
last_event_min_value = last_event.get("min_value", value)
last_event_min_value_iter = last_event.get("min_value_iter", event_iter)
if value < last_event_min_value:
event_data["min_value"] = value
event_data["min_value_iter"] = event_iter
else:
event_data["min_value"] = last_event_min_value
event_data["min_value_iter"] = last_event_min_value_iter
last_event_max_value = last_event.get("max_value", value)
last_event_max_value_iter = last_event.get("max_value_iter", event_iter)
if value > last_event_max_value:
event_data["max_value"] = value
event_data["max_value_iter"] = event_iter
else:
event_data["max_value"] = last_event_max_value
event_data["max_value_iter"] = last_event_max_value_iter
last_events[metric_hash][variant_hash] = event_data
def _update_last_metric_events_for_task(self, last_events, event):
@@ -396,36 +471,20 @@ class EventBLL(object):
as the latest metric/variant scalar values reported (according to the report timestamp) and the task's last
update time.
"""
fields = {}
if iter_max is not None:
fields["last_iteration_max"] = iter_max
if last_scalar_events:
fields["last_scalar_values"] = list(
flatten_nested_items(
last_scalar_events,
nesting=2,
include_leaves=[
"value",
"min_value",
"max_value",
"metric",
"variant",
],
)
)
if last_events:
fields["last_events"] = last_events
if not fields:
if iter_max is None and not last_events and not last_scalar_events:
return False
return TaskBLL.update_statistics(task_id, company_id, last_update=now, **fields)
return TaskBLL.update_statistics(
task_id,
company_id,
last_update=now,
last_iteration_max=iter_max,
last_scalar_events=last_scalar_events,
last_events=last_events,
)
def _get_event_id(self, event):
id_values = (str(event[field]) for field in self.id_fields if field in event)
id_values = (str(event[field]) for field in self.event_id_fields if field in event)
return hashlib.md5("-".join(id_values).encode()).hexdigest()
def scroll_task_events(
@@ -441,7 +500,7 @@ class EventBLL(object):
return [], scroll_id, 0
if scroll_id:
with translate_errors_context(), TimingContext("es", "task_log_events"):
with translate_errors_context():
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
size = min(batch_size, 10000)
@@ -454,7 +513,7 @@ class EventBLL(object):
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
}
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
@@ -469,31 +528,46 @@ class EventBLL(object):
return events, next_scroll_id, total_events
def get_last_iterations_per_event_metric_variant(
def get_task_plots(
self,
company_id: str,
task_id: str,
num_last_iterations: int,
event_type: EventType,
last_iterations_per_plot: int,
metric_variants: MetricVariants = None,
):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return []
event_type = EventType.metrics_plot
if check_empty_data(self.es, company_id, event_type):
return TaskEventsResult()
must = [{"term": {"task": task_id}}]
plot_valid_condition = {
"bool": {
"should": [
{"term": {PlotFields.valid_plot: True}},
{
"bool": {
"must_not": {"exists": {"field": PlotFields.valid_plot}}
}
},
]
}
}
must = [plot_valid_condition, {"term": {"task": task_id}}]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
es=self.es, company_id=company_id, event_type=event_type,
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
max_variants = int(max_variants // num_last_iterations)
max_variants = int(max_variants // last_iterations_per_plot)
es_req: dict = {
es_req = {
"sort": [{"iter": {"order": "desc"}}],
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
@@ -509,11 +583,10 @@ class EventBLL(object):
"order": {"_key": "asc"},
},
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": num_last_iterations,
"order": {"_key": "desc"},
"events": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": last_iterations_per_plot
}
}
},
@@ -521,117 +594,28 @@ class EventBLL(object):
},
}
},
"query": query,
}
with translate_errors_context(), TimingContext(
"es", "task_last_iter_metric_variant"
):
es_res = search_company_events(body=es_req, **search_args)
with translate_errors_context():
es_response = search_company_events(
body=es_req,
ignore=404,
**search_args,
)
if "aggregations" not in es_res:
return []
return [
(metric["key"], variant["key"], iter["key"])
for metric in es_res["aggregations"]["metrics"]["buckets"]
for variant in metric["variants"]["buckets"]
for iter in variant["iters"]["buckets"]
]
def get_task_plots(
self,
company_id: str,
tasks: Sequence[str],
last_iterations_per_plot: int = None,
sort=None,
size: int = 500,
scroll_id: str = None,
no_scroll: bool = False,
metric_variants: MetricVariants = None,
):
if scroll_id == self.empty_scroll:
aggs_result = es_response.get("aggregations")
if not aggs_result:
return TaskEventsResult()
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
event_type = EventType.metrics_plot
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return TaskEventsResult()
plot_valid_condition = {
"bool": {
"should": [
{"term": {PlotFields.valid_plot: True}},
{
"bool": {
"must_not": {"exists": {"field": PlotFields.valid_plot}}
}
},
]
}
}
must = [plot_valid_condition]
if last_iterations_per_plot is None:
must.append({"terms": {"task": tasks}})
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
else:
should = []
for i, task_id in enumerate(tasks):
last_iters = self.get_last_iterations_per_event_metric_variant(
company_id=company_id,
task_id=task_id,
num_last_iterations=last_iterations_per_plot,
event_type=event_type,
metric_variants=metric_variants,
)
if not last_iters:
continue
for metric, variant, iter in last_iters:
should.append(
{
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"term": {"iter": iter}},
]
}
}
)
if not should:
return TaskEventsResult()
must.append({"bool": {"should": should}})
if sort is None:
sort = [{"timestamp": {"order": "asc"}}]
es_req = {
"sort": sort,
"size": min(size, 10000),
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext("es", "get_task_plots"):
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
ignore=404,
**({} if no_scroll else {"scroll": "1h"}),
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
events = [
hit["_source"]
for metrics_bucket in aggs_result["metrics"]["buckets"]
for variants_bucket in metrics_bucket["variants"]["buckets"]
for hit in variants_bucket["events"]["hits"]["hits"]
]
self.uncompress_plots(events)
return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events
events=events, total_events=len(events)
)
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
@@ -643,7 +627,7 @@ class EventBLL(object):
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
next_scroll_id = es_res.get("_scroll_id")
if next_scroll_id and not events:
self.es.clear_scroll(scroll_id=next_scroll_id)
self.clear_scroll(next_scroll_id)
next_scroll_id = self.empty_scroll
return events, total_events, next_scroll_id
@@ -721,11 +705,10 @@ class EventBLL(object):
def get_task_events(
self,
company_id: str,
task_id: str,
company_id: Union[str, Sequence[str]],
task_id: Union[str, Sequence[str]],
event_type: EventType,
metric=None,
variant=None,
metrics: MetricVariants = None,
last_iter_count=None,
sort=None,
size=500,
@@ -736,28 +719,38 @@ class EventBLL(object):
return TaskEventsResult()
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
with translate_errors_context():
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
company_ids = [company_id] if isinstance(company_id, str) else company_id
company_ids = [
c_id
for c_id in set(company_ids)
if not check_empty_data(self.es, c_id, event_type)
]
if not company_ids:
return TaskEventsResult()
task_ids = [task_id] if isinstance(task_id, str) else task_id
task_ids = (
[task_id]
if isinstance(task_id, str)
else task_id
)
must = []
if metric:
must.append({"term": {"metric": metric}})
if variant:
must.append({"term": {"variant": variant}})
if metrics:
must.append(get_metric_variants_condition(metrics))
if last_iter_count is None:
must.append({"terms": {"task": task_ids}})
else:
tasks_iters = self.get_last_iters(
company_id=company_id,
company_id=company_ids,
event_type=event_type,
task_id=task_ids,
iters=last_iter_count,
metrics=metrics,
)
should = [
{
@@ -784,10 +777,10 @@ class EventBLL(object):
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext("es", "get_task_events"):
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
company_id=company_ids,
event_type=event_type,
body=es_req,
ignore=404,
@@ -809,9 +802,7 @@ class EventBLL(object):
return {}
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
@@ -838,9 +829,7 @@ class EventBLL(object):
"query": query,
}
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
metrics = {}
@@ -867,9 +856,7 @@ class EventBLL(object):
]
}
}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
@@ -915,9 +902,7 @@ class EventBLL(object):
},
"_source": {"excludes": []},
}
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
metrics = []
@@ -963,7 +948,7 @@ class EventBLL(object):
"_source": ["iter", "value"],
"sort": ["iter"],
}
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
with translate_errors_context():
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
@@ -978,15 +963,26 @@ class EventBLL(object):
def get_last_iters(
self,
company_id: str,
company_id: Union[str, Sequence[str]],
event_type: EventType,
task_id: Union[str, Sequence[str]],
iters: int,
metrics: MetricVariants = None
) -> Mapping[str, Sequence]:
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
company_ids = [company_id] if isinstance(company_id, str) else company_id
company_ids = [
c_id
for c_id in set(company_ids)
if not check_empty_data(self.es, c_id, event_type)
]
if not company_ids:
return {}
task_ids = [task_id] if isinstance(task_id, str) else task_id
must = [{"terms": {"task": task_ids}}]
if metrics:
must.append(get_metric_variants_condition(metrics))
es_req: dict = {
"size": 0,
"aggs": {
@@ -1003,12 +999,12 @@ class EventBLL(object):
},
}
},
"query": {"bool": {"must": [{"terms": {"task": task_ids}}]}},
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext("es", "task_last_iter"):
with translate_errors_context():
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
self.es, company_id=company_ids, event_type=event_type, body=es_req,
)
if "aggregations" not in es_res:
@@ -1019,6 +1015,21 @@ class EventBLL(object):
for tb in es_res["aggregations"]["tasks"]["buckets"]
}
@staticmethod
def _validate_model_state(
    company_id: str, model_id: str, allow_locked: bool = False
):
    """
    Make sure that the model exists in the company and can have its events modified

    Unless allow_locked is set, models with ready == True are not matched
    :raise errors.bad_request.InvalidModelId: if no matching model is found
    """
    criteria = Q(id=model_id, company=company_id)
    failure_hint = None
    if not allow_locked:
        # models marked as ready are excluded from the lookup
        criteria = criteria & Q(ready__ne=True)
        failure_hint = "or model published"

    if Model.objects(criteria).only("id").first():
        return

    raise errors.bad_request.InvalidModelId(
        failure_hint, company=company_id, id=model_id
    )
@staticmethod
def _validate_task_state(company_id: str, task_id: str, allow_locked: bool = False):
extra_msg = None
@@ -1032,22 +1043,42 @@ class EventBLL(object):
extra_msg, company=company_id, id=task_id
)
def delete_task_events(self, company_id, task_id, allow_locked=False):
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
)
@staticmethod
def _get_events_deletion_params(async_delete: bool) -> dict:
    """
    Build the extra keyword arguments that are passed to the events deletion call

    For a synchronous delete only "refresh" is requested.
    For an async delete the call does not wait for completion and is throttled
    by the "services.events.max_async_deleted_events_per_sec" setting (default 1000)
    """
    if not async_delete:
        return {"refresh": True}

    throttle = config.get(
        "services.events.max_async_deleted_events_per_sec", 1000
    )
    return {"wait_for_completion": False, "requests_per_second": throttle}
def delete_task_events(
self, company_id, task_id, allow_locked=False, model=False, async_delete=False,
):
if model:
self._validate_model_state(
company_id=company_id, model_id=task_id, allow_locked=allow_locked,
)
else:
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
)
es_req = {"query": {"term": {"task": task_id}}}
with translate_errors_context(), TimingContext("es", "delete_task_events"):
with translate_errors_context():
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.all,
body=es_req,
refresh=True,
**self._get_events_deletion_params(async_delete),
)
return es_res.get("deleted", 0)
if not async_delete:
return es_res.get("deleted", 0)
def clear_task_log(
self,
@@ -1064,7 +1095,7 @@ class EventBLL(object):
):
return 0
with translate_errors_context(), TimingContext("es", "clear_task_log"):
with translate_errors_context():
must = [{"term": {"task": task_id}}]
sort = None
if threshold_sec:
@@ -1092,24 +1123,29 @@ class EventBLL(object):
)
return es_res.get("deleted", 0)
def delete_multi_task_events(
    self, company_id: str, task_ids: Sequence[str], async_delete=False
):
    """
    Delete multiple task events. No check is done for tasks write access
    so it should be checked by the calling code

    The task ids are sent to the events storage in chunks of 100 ids per request

    :param company_id: the company whose event indices are targeted
    :param task_ids: ids of the tasks whose events should be deleted
    :param async_delete: if set then the deletion requests do not wait
        for completion and nothing is returned
    :return: the number of deleted events summed over all the chunks
        (0 if task_ids is empty) or None if async_delete is set
    """
    deleted = 0
    with translate_errors_context():
        for tasks in chunked_iter(task_ids, 100):
            es_req = {"query": {"terms": {"task": tasks}}}
            es_res = delete_company_events(
                es=self.es,
                company_id=company_id,
                event_type=EventType.all,
                body=es_req,
                **self._get_events_deletion_params(async_delete),
            )
            if not async_delete:
                deleted += es_res.get("deleted", 0)

    if async_delete:
        return None

    # Return the accumulated total and not the count of the last chunk only.
    # This also avoids touching es_res when task_ids is empty and no request was made
    return deleted
def clear_scroll(self, scroll_id: str):
if scroll_id == self.empty_scroll:

View File

@@ -8,7 +8,7 @@ from elasticsearch import Elasticsearch
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
@@ -21,8 +21,9 @@ class EventType(Enum):
all = "*"
SINGLE_SCALAR_ITERATION = -2**31
SINGLE_SCALAR_ITERATION = -(2 ** 31)
MetricVariants = Mapping[str, Sequence[str]]
TaskCompanies = Mapping[str, Sequence[Task]]
class EventSettings:
@@ -53,9 +54,12 @@ class EventSettings:
return int(self._max_es_allowed_aggregation_buckets * percentage)
def get_index_name(company_id: str, event_type: str):
def get_index_name(company_id: Union[str, Sequence[str]], event_type: str):
event_type = event_type.lower().replace(" ", "_")
return f"events-{event_type}-{company_id.lower()}"
if isinstance(company_id, str):
company_id = [company_id]
return ",".join(f"events-{event_type}-{(c_id or '').lower()}" for c_id in company_id)
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
@@ -80,9 +84,7 @@ def delete_company_events(
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
) -> dict:
es_index = get_index_name(company_id, event_type.value)
return es.delete_by_query(
index=es_index, body=body, conflicts="proceed", **kwargs
)
return es.delete_by_query(index=es_index, body=body, conflicts="proceed", **kwargs)
def count_company_events(
@@ -116,9 +118,7 @@ def get_max_metric_and_variant_counts(
"query": query,
"aggs": {"metrics_count": {"cardinality": {"field": "metric"}}},
}
with translate_errors_context(), TimingContext(
"es", "get_max_metric_and_variant_counts"
):
with translate_errors_context():
es_res = search_company_events(
es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
)

View File

@@ -8,9 +8,7 @@ from typing import Sequence, Tuple, Mapping
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.bll.event.event_common import (
EventType,
EventSettings,
@@ -20,12 +18,11 @@ from apiserver.bll.event.event_common import (
get_metric_variants_condition,
get_max_metric_and_variant_counts,
SINGLE_SCALAR_ITERATION,
TaskCompanies,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
log = config.logger(__file__)
@@ -111,57 +108,51 @@ class EventMetrics:
def compare_scalar_metrics_average_per_iter(
self,
company_id,
task_ids: Sequence[str],
companies: TaskCompanies,
samples,
key: ScalarKeyEnum,
allow_public=True,
metric_variants: MetricVariants = None,
):
"""
Compare scalar metrics for different tasks per metric and variant
The amount of points in each histogram should not exceed the requested samples
"""
task_name_by_id = {}
with translate_errors_context():
task_objs = Task.get_many(
company=company_id,
query=Q(id__in=task_ids),
allow_public=allow_public,
override_projection=("id", "name", "company", "company_origin"),
return_dicts=False,
)
if len(task_objs) < len(task_ids):
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
task_name_by_id = {t.id: t.name for t in task_objs}
companies = {t.get_index_company() for t in task_objs}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
event_type = EventType.metrics_scalar
company_id = next(iter(companies))
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
companies = {
company_id: tasks
for company_id, tasks in companies.items()
if not check_empty_data(
self.es, company_id=company_id, event_type=event_type
)
}
if not companies:
return {}
get_scalar_average_per_iter = partial(
self._get_scalar_average_per_iter_core,
company_id=company_id,
event_type=event_type,
samples=samples,
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
run_parallel=False,
)
task_ids, company_ids = zip(
*(
(t.id, t.company)
for t in itertools.chain.from_iterable(companies.values())
)
)
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
task_metrics = zip(
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
task_ids, pool.map(get_scalar_average_per_iter, task_ids, company_ids)
)
task_names = {
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
}
res = defaultdict(lambda: defaultdict(dict))
for task_id, task_data in task_metrics:
task_name = task_name_by_id[task_id]
task_name = task_names[task_id]
for metric_key, metric_data in task_data.items():
for variant_key, variant_data in metric_data.items():
variant_data["name"] = task_name
@@ -170,18 +161,27 @@ class EventMetrics:
return res
def get_task_single_value_metrics(
self, company_id: str, task_ids: Sequence[str]
self, companies: TaskCompanies
) -> Mapping[str, dict]:
"""
For the requested tasks return all the events delivered for the single iteration (-2**31)
"""
if check_empty_data(
self.es, company_id=company_id, event_type=EventType.metrics_scalar
):
companies = {
company_id: [t.id for t in tasks]
for company_id, tasks in companies.items()
if not check_empty_data(
self.es, company_id=company_id, event_type=EventType.metrics_scalar
)
}
if not companies:
return {}
with TimingContext("es", "get_task_single_value_metrics"):
task_events = self._get_task_single_value_metrics(company_id, task_ids)
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
task_events = list(
itertools.chain.from_iterable(
pool.map(self._get_task_single_value_metrics, companies.items())
),
)
def _get_value(event: dict):
return {
@@ -195,8 +195,9 @@ class EventMetrics:
}
def _get_task_single_value_metrics(
self, company_id: str, task_ids: Sequence[str]
self, tasks: Tuple[str, Sequence[str]]
) -> Sequence[dict]:
company_id, task_ids = tasks
es_req = {
"size": 10000,
"query": {
@@ -277,9 +278,7 @@ class EventMetrics:
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
@@ -312,8 +311,7 @@ class EventMetrics:
},
}
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
es_res = search_company_events(body=es_req, **search_args)
es_res = search_company_events(body=es_req, **search_args)
aggs_result = es_res.get("aggregations")
if not aggs_result:
@@ -366,9 +364,7 @@ class EventMetrics:
interval, metrics = metrics_interval
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
@@ -493,10 +489,9 @@ class EventMetrics:
},
}
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
return [
metric["key"]

View File

@@ -17,7 +17,6 @@ from apiserver.bll.event.event_common import (
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@attr.s(auto_attribs=True)
@@ -76,7 +75,7 @@ class EventsIterator:
"query": query,
}
with translate_errors_context(), TimingContext("es", "count_task_events"):
with translate_errors_context():
es_result = count_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
@@ -113,7 +112,7 @@ class EventsIterator:
if from_key_value:
es_req["search_after"] = [from_key_value]
with translate_errors_context(), TimingContext("es", "get_task_events"):
with translate_errors_context():
es_result = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)

View File

@@ -1,56 +1,455 @@
from typing import Sequence, Tuple, Callable
import operator
from operator import attrgetter
from typing import Sequence, Tuple, Optional, Mapping
import attr
from boltons.iterutils import first, bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, IntField, BoolField, ListField
from jsonmodels.models import Base
from redis.client import StrictRedis
from apiserver.utilities.dicts import nested_get
from .event_common import EventType
from .history_sample_iterator import HistorySampleIterator, VariantState
from .event_common import (
EventType,
EventSettings,
check_empty_data,
search_company_events,
get_max_metric_and_variant_counts,
)
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.apierrors import errors
class HistoryDebugImageIterator(HistorySampleIterator):
class VariantState(Base):
    """
    Valid iterations range of a single metric/variant pair
    as kept inside the debug image navigation state
    """

    # variant name, matched against the "variant" field of the events
    name: str = StringField(required=True)
    # metric name; may be None in states that were stored without it
    metric: str = StringField(default=None)
    # lower bound used in the "iter" range conditions for this variant
    min_iteration: int = IntField()
    # the highest iteration found for this variant
    max_iteration: int = IntField()
class DebugImageSampleState(Base, JsonSerializableMixin):
    """
    Navigation state of the debug image samples for a task.
    Handled through the cache manager and identified by its id
    that is returned to the caller as scroll_id
    """

    # state id
    id: str = StringField(required=True)
    # iteration, variant and metric of the last sample that was returned
    iteration: int = IntField()
    variant: str = StringField()
    # the task that the state was created for
    task: str = StringField()
    metric: str = StringField()
    # valid iteration ranges per metric/variant
    variant_states: Sequence[VariantState] = ListField([VariantState])
    warning: str = StringField()
    # if True then only the variants of the current metric are navigated
    navigate_current_metric = BoolField(default=True)
@attr.s(auto_attribs=True)
class VariantSampleResult:
    """
    Result of a debug image sample request
    All the fields are None unless filled by the iterator
    """

    # id of the navigation state to be passed in the following calls
    scroll_id: str = None
    # the "_source" of the found event
    event: dict = None
    # valid iterations range of the variant that the event belongs to
    min_iteration: int = None
    max_iteration: int = None
class HistoryDebugImageIterator:
event_type = EventType.metrics_image
def __init__(self, redis: StrictRedis, es: Elasticsearch):
super().__init__(redis, es, EventType.metrics_image)
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugImageSampleState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def _get_extra_conditions(self) -> Sequence[dict]:
return [{"exists": {"field": "url"}}]
def get_next_sample(
self,
company_id: str,
task: str,
state_id: str,
navigate_earlier: bool,
next_iteration: bool,
) -> VariantSampleResult:
"""
Get the sample for next/prev variant on the current iteration
If does not exist then try getting sample for the first/last variant from next/prev iteration
"""
res = VariantSampleResult(scroll_id=state_id)
state = self.cache_manager.get_state(state_id)
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
variants_conditions = [
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
if next_iteration:
event = self._get_next_for_another_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
)
else:
# noinspection PyArgumentList
event = first(
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
for f in (
self._get_next_for_current_iteration,
self._get_next_for_another_iteration,
)
)
if not event:
return res
self._fill_res_and_update_state(event=event, res=res, state=state)
self.cache_manager.set_state(state=state)
return res
@staticmethod
def _fill_res_and_update_state(
    event: dict, res: VariantSampleResult, state: DebugImageSampleState
):
    """
    Point the state to the passed event and put the event into the result
    If the state holds the iterations range for the event metric/variant
    then the range is copied to the result too
    """
    state.variant = event["variant"]
    state.metric = event["metric"]
    state.iteration = event["iter"]
    res.event = event

    matching_states = (
        vs
        for vs in state.variant_states
        if vs.name == state.variant and vs.metric == state.metric
    )
    current = first(matching_states)
    if not current:
        return

    res.min_iteration = current.min_iteration
    res.max_iteration = current.max_iteration
@staticmethod
def _get_metric_conditions(variants: Sequence[VariantState]) -> dict:
metrics = bucketize(variants, key=attrgetter("metric"))
def _get_variants_conditions(metric_variants: Sequence[VariantState]) -> dict:
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"gte": v.min_iteration}}},
]
}
}
for v in metric_variants
]
return {"bool": {"should": variants_conditions}}
metrics_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"gte": v.min_iteration}}},
{"term": {"metric": metric}},
_get_variants_conditions(metric_variants),
]
}
}
for v in variants
for metric, metric_variants in metrics.items()
]
return {"bool": {"should": variants_conditions}}
return {"bool": {"should": metrics_conditions}}
def _process_event(self, event: dict) -> dict:
return event
def _get_next_for_current_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
) -> Optional[dict]:
"""
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
Only variants for which the iteration falls into their valid range are considered
Return None if no such variant or sample is found
"""
if state.navigate_current_metric:
variants = [
var_state
for var_state in state.variant_states
if var_state.metric == state.metric
]
else:
variants = state.variant_states
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
# The min iteration is the lowest iteration that contains non-recycled image url
aggs = {
"last_iter": {"max": {"field": "iter"}},
"urls": {
# group by urls and choose the minimal iteration
# from all the maximal iterations per url
"terms": {"field": "url", "order": {"max_iter": "asc"}, "size": 1},
"aggs": {
# find max iteration for each url
"max_iter": {"max": {"field": "iter"}}
},
cmp = operator.lt if navigate_earlier else operator.gt
variants = [
var_state
for var_state in variants
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
and var_state.min_iteration <= state.iteration
]
if not variants:
return
must_conditions = [
{"term": {"task": state.task}},
{"term": {"iter": state.iteration}},
self._get_metric_conditions(variants),
{"exists": {"field": "url"}},
]
order = "desc" if navigate_earlier else "asc"
es_req = {
"size": 1,
"sort": [{"metric": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def _get_next_for_another_iteration(
    self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
) -> Optional[dict]:
    """
    Look for a sample outside of the current state iteration:
    the first variant of the closest later iteration (navigate_earlier is False)
    or the last variant of the closest earlier iteration (navigate_earlier is True)
    When navigating earlier, variants whose min iteration is not below
    the current iteration are dropped
    Return the event source or None if nothing suitable is found
    """
    candidates = state.variant_states
    if state.navigate_current_metric:
        candidates = [vs for vs in candidates if vs.metric == state.metric]

    if navigate_earlier:
        range_operator, order = "lt", "desc"
        candidates = [
            vs for vs in candidates if vs.min_iteration < state.iteration
        ]
    else:
        range_operator, order = "gt", "asc"

    if not candidates:
        return None

    es_req = {
        "size": 1,
        "sort": [{"iter": order}, {"metric": order}, {"variant": order}],
        "query": {
            "bool": {
                "must": [
                    {"term": {"task": state.task}},
                    self._get_metric_conditions(candidates),
                    {"range": {"iter": {range_operator: state.iteration}}},
                    {"exists": {"field": "url"}},
                ]
            }
        },
    }
    es_res = search_company_events(
        self.es,
        company_id=company_id,
        event_type=self.event_type,
        body=es_req,
    )

    hits = nested_get(es_res, ("hits", "hits"))
    return hits[0]["_source"] if hits else None
def get_sample_for_variant(
    self,
    company_id: str,
    task: str,
    metric: str,
    variant: str,
    iteration: Optional[int] = None,
    refresh: bool = False,
    state_id: str = None,
    navigate_current_metric: bool = True,
) -> VariantSampleResult:
    """
    Return the sample of the metric/variant for the passed iteration
    or for the closest iteration before it
    Without iteration the latest sample is returned
    The navigation state is obtained through the cache manager
    (validated if state_id is passed) and its id is returned as scroll_id
    """
    res = VariantSampleResult()
    if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
        return res

    def _initialize(new_state: DebugImageSampleState):
        # bind a newly created state to the request and load the variant ranges
        new_state.task = task
        new_state.metric = metric
        new_state.navigate_current_metric = navigate_current_metric
        self._reset_variant_states(company_id=company_id, state=new_state)

    def _validate(stored: DebugImageSampleState):
        # the stored state should have been created for the same task and navigation mode
        matches_request = (
            stored.task == task
            and stored.navigate_current_metric == navigate_current_metric
            and (not stored.navigate_current_metric or stored.metric == metric)
        )
        if not matches_request:
            raise errors.bad_request.InvalidScrollId(
                "Task and metric stored in the state do not match the passed ones",
                scroll_id=stored.id,
            )
        # fix old variant states:
        for stored_variant in stored.variant_states:
            if stored_variant.metric is None:
                stored_variant.metric = metric
        if refresh:
            self._reset_variant_states(company_id=company_id, state=stored)

    state: DebugImageSampleState
    with self.cache_manager.get_or_create_state(
        state_id=state_id, init_state=_initialize, validate_state=_validate,
    ) as state:
        res.scroll_id = state.id

        var_state = first(
            vs
            for vs in state.variant_states
            if vs.name == variant and vs.metric == metric
        )
        if not var_state:
            return res

        res.min_iteration = var_state.min_iteration
        res.max_iteration = var_state.max_iteration

        # never go below the variant min iteration, cap from above only on request
        if iteration is not None:
            iter_range = {"lte": iteration, "gte": var_state.min_iteration}
        else:
            iter_range = {"gte": var_state.min_iteration}

        es_req = {
            "size": 1,
            "sort": {"iter": "desc"},
            "query": {
                "bool": {
                    "must": [
                        {"term": {"task": task}},
                        {"term": {"metric": metric}},
                        {"term": {"variant": variant}},
                        {"exists": {"field": "url"}},
                        {"range": {"iter": iter_range}},
                    ]
                }
            },
        }
        es_res = search_company_events(
            self.es,
            company_id=company_id,
            event_type=self.event_type,
            body=es_req,
        )

        hits = nested_get(es_res, ("hits", "hits"))
        if hits:
            self._fill_res_and_update_state(
                event=hits[0]["_source"], res=res, state=state
            )
        return res
def _reset_variant_states(self, company_id: str, state: DebugImageSampleState):
    """
    Rebuild the state variant_states from the iterations currently found for the state task
    The lookup is limited to the state metric if the state navigates the current metric only
    """
    iterations_by_metric = self._get_metric_variant_iterations(
        company_id=company_id,
        task=state.task,
        metric=state.metric if state.navigate_current_metric else None,
    )

    rebuilt = []
    for metric_name, variant_ranges in iterations_by_metric.items():
        for variant_name, lowest_iter, highest_iter in variant_ranges:
            rebuilt.append(
                VariantState(
                    metric=metric_name,
                    name=variant_name,
                    min_iteration=lowest_iter,
                    max_iteration=highest_iter,
                )
            )
    state.variant_states = rebuilt
def _get_metric_variant_iterations(
self, company_id: str, task: str, metric: str,
) -> Mapping[str, Sequence[Tuple[str, int, int]]]:
"""
Return valid min and max iterations that the task reported events of the required type
"""
must = [
{"term": {"task": task}},
{"exists": {"field": "url"}},
]
if metric is not None:
must.append({"term": {"metric": metric}})
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=self.event_type,
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args
)
max_variants = int(max_variants // 2)
es_req: dict = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
"last_iter": {"max": {"field": "iter"}},
"urls": {
# group by urls and choose the minimal iteration
# from all the maximal iterations per url
"terms": {
"field": "url",
"order": {"max_iter": "asc"},
"size": 1,
},
"aggs": {
# find max iteration for each url
"max_iter": {"max": {"field": "iter"}}
},
},
},
}
},
}
},
}
def get_min_max_data(variant_bucket: dict) -> Tuple[int, int]:
es_res = search_company_events(body=es_req, **search_args)
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
variant = variant_bucket["key"]
urls = nested_get(variant_bucket, ("urls", "buckets"))
min_iter = int(urls[0]["max_iter"]["value"])
max_iter = int(variant_bucket["last_iter"]["value"])
return min_iter, max_iter
return variant, min_iter, max_iter
return aggs, get_min_max_data
return {
metric_bucket["key"]: [
get_variant_data(variant_bucket)
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
]
for metric_bucket in nested_get(
es_res, ("aggregations", "metrics", "buckets")
)
}

View File

@@ -1,36 +0,0 @@
from typing import Sequence, Tuple, Callable
from elasticsearch import Elasticsearch
from redis.client import StrictRedis
from .event_common import EventType, uncompress_plot
from .history_sample_iterator import HistorySampleIterator, VariantState
class HistoryPlotIterator(HistorySampleIterator):
    """
    History sample iterator bound to the plot events (EventType.metrics_plot)
    """

    def __init__(self, redis: StrictRedis, es: Elasticsearch):
        super().__init__(redis, es, EventType.metrics_plot)

    def _get_extra_conditions(self) -> Sequence[dict]:
        # plot events need no additional query conditions
        return []

    def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
        # match any of the passed variants by name
        variant_names = [variant.name for variant in variants]
        return {"terms": {"variant": variant_names}}

    def _process_event(self, event: dict) -> dict:
        # the event is passed through uncompress_plot and then returned
        uncompress_plot(event)
        return event

    def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
        """
        Return the aggregations that calculate the lowest and the highest
        iteration per variant together with the function
        that extracts (min_iter, max_iter) from a variant bucket
        """

        def extract_range(variant_bucket: dict) -> Tuple[int, int]:
            lowest = int(variant_bucket["first_iter"]["value"])
            highest = int(variant_bucket["last_iter"]["value"])
            return lowest, highest

        range_aggs = {
            "last_iter": {"max": {"field": "iter"}},
            "first_iter": {"min": {"field": "iter"}},
        }
        return range_aggs, extract_range

View File

@@ -0,0 +1,316 @@
from typing import Sequence, Tuple, Optional, Mapping
import attr
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, IntField, ListField, BoolField
from jsonmodels.models import Base
from redis.client import StrictRedis
from .event_common import (
EventType,
uncompress_plot,
EventSettings,
check_empty_data,
search_company_events,
)
from apiserver.apimodels import JsonSerializableMixin
from apiserver.utilities.dicts import nested_get
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.apierrors import errors
# Iteration range of a single metric, kept inside PlotsSampleState.metric_states
class MetricState(Base):
# metric name; compared with the requested/current metric when looking up the range
name: str = StringField(default=None)
# lowest and highest iterations of the metric; filled by
# HistoryPlotsIterator._reset_metric_states from the min/max "iter" aggregations
min_iteration: int = IntField()
max_iteration: int = IntField()
# Scroll state of the plots navigation; HistoryPlotsIterator loads and stores it
# through its RedisCacheManager
class PlotsSampleState(Base, JsonSerializableMixin):
# state id; returned to the caller as the scroll id
id: str = StringField(required=True)
# iteration of the events that were returned last
iteration: int = IntField()
# task that the state was created for
task: str = StringField()
# metric that the state was created for; updated to the metric of the last returned events
metric: str = StringField()
# iteration ranges per metric, rebuilt by HistoryPlotsIterator._reset_metric_states
metric_states: Sequence[MetricState] = ListField([MetricState])
# NOTE(review): not read or written by the iterator code in this file
warning: str = StringField()
# when set the navigation is limited to the events of the state metric
navigate_current_metric = BoolField(default=True)
@attr.s(auto_attribs=True)
class MetricSamplesResult(object):
    """
    Result of a plots navigation call: the scroll id, the found events
    and the iteration range of the metric that the events belong to
    """

    scroll_id: str = None
    # a factory is used so that every result gets its own list;
    # a literal [] default would be one list object shared by all the instances
    events: list = attr.Factory(list)
    min_iteration: int = None
    max_iteration: int = None
class HistoryPlotsIterator:
event_type = EventType.metrics_plot
def __init__(self, redis: StrictRedis, es: Elasticsearch):
    """
    Keep the Elasticsearch client and create the cache manager
    that persists PlotsSampleState objects in redis
    """
    state_cache = RedisCacheManager(
        state_class=PlotsSampleState,
        redis=redis,
        expiration_interval=EventSettings.state_expiration_sec,
    )
    self.es = es
    self.cache_manager = state_cache
def get_next_sample(
    self,
    company_id: str,
    task: str,
    state_id: str,
    navigate_earlier: bool,
    next_iteration: bool,
) -> MetricSamplesResult:
    """
    Return the plot events that follow (or precede, if navigate_earlier is set)
    the position stored in the scroll state

    Unless next_iteration is set or the state navigates inside one metric,
    the next/previous metric of the current iteration is also a candidate
    Raise InvalidScrollId if the state is not found or belongs to another task
    If nothing is found the result contains only the scroll id
    """
    result = MetricSamplesResult(scroll_id=state_id)
    state = self.cache_manager.get_state(state_id)
    if not state or state.task != task:
        raise errors.bad_request.InvalidScrollId(scroll_id=state_id)

    if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
        return result

    range_operator, order = ("lt", "desc") if navigate_earlier else ("gt", "asc")

    must_conditions = [{"term": {"task": state.task}}]
    if state.navigate_current_metric:
        must_conditions.append({"term": {"metric": state.metric}})

    other_iteration = {"range": {"iter": {range_operator: state.iteration}}}
    if next_iteration or state.navigate_current_metric:
        must_conditions.append(other_iteration)
    else:
        # either another metric on the same iteration or any metric on another iteration
        same_iteration_other_metric = {
            "bool": {
                "must": [
                    {"term": {"iter": state.iteration}},
                    {"range": {"metric": {range_operator: state.metric}}},
                ]
            }
        }
        must_conditions.append(
            {"bool": {"should": [same_iteration_other_metric, other_iteration]}}
        )

    events = self._get_metric_events_for_condition(
        company_id=company_id,
        task=state.task,
        order=order,
        must_conditions=must_conditions,
    )
    if events:
        self._fill_res_and_update_state(events=events, res=result, state=state)
        self.cache_manager.set_state(state=state)

    return result
def get_samples_for_metric(
    self,
    company_id: str,
    task: str,
    metric: str,
    iteration: Optional[int] = None,
    refresh: bool = False,
    state_id: str = None,
    navigate_current_metric: bool = True,
) -> MetricSamplesResult:
    """
    Return the plot events of the metric from the requested iteration
    or from the latest iteration before it
    If the iteration is not passed then the latest events are returned

    The scroll state is created or, when state_id points to an existing one,
    validated against the passed parameters and refreshed on request
    Raise InvalidScrollId if the stored state does not match the passed parameters
    """
    result = MetricSamplesResult()
    if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
        return result

    def init_state(state_: PlotsSampleState):
        # a new state starts from the requested task and metric
        state_.task, state_.metric = task, metric
        state_.navigate_current_metric = navigate_current_metric
        self._reset_metric_states(company_id=company_id, state=state_)

    def validate_state(state_: PlotsSampleState):
        mismatch = (
            state_.task != task
            or state_.navigate_current_metric != navigate_current_metric
            or (state_.navigate_current_metric and state_.metric != metric)
        )
        if mismatch:
            raise errors.bad_request.InvalidScrollId(
                "Task and metric stored in the state do not match the passed ones",
                scroll_id=state_.id,
            )
        if refresh:
            self._reset_metric_states(company_id=company_id, state=state_)

    state: PlotsSampleState
    with self.cache_manager.get_or_create_state(
        state_id=state_id, init_state=init_state, validate_state=validate_state,
    ) as state:
        result.scroll_id = state.id

        requested = first(ms for ms in state.metric_states if ms.name == metric)
        if not requested:
            # the metric has no known iteration range
            return result

        result.min_iteration = requested.min_iteration
        result.max_iteration = requested.max_iteration

        must_conditions = [
            {"term": {"task": task}},
            {"term": {"metric": metric}},
        ]
        if iteration is not None:
            must_conditions.append({"range": {"iter": {"lte": iteration}}})

        events = self._get_metric_events_for_condition(
            company_id=company_id,
            task=state.task,
            order="desc",
            must_conditions=must_conditions,
        )
        if events:
            self._fill_res_and_update_state(events=events, res=result, state=state)

        return result
def _reset_metric_states(self, company_id: str, state: PlotsSampleState):
    """
    Recalculate the iteration range per metric and store it on the state

    Only the state metric is queried when the state navigates inside
    the current metric, otherwise no metric filter is applied
    """
    metric_filter = state.metric if state.navigate_current_metric else None
    ranges = self._get_metric_iterations(
        company_id=company_id, task=state.task, metric=metric_filter,
    )

    rebuilt_states = []
    for metric_name, (first_iter, last_iter) in ranges.items():
        rebuilt_states.append(
            MetricState(
                name=metric_name, min_iteration=first_iter, max_iteration=last_iter
            )
        )
    state.metric_states = rebuilt_states
def _get_metric_iterations(
    self,
    company_id: str,
    task: str,
    metric: Optional[str],
    max_metrics: int = 5000,
) -> Mapping[str, Tuple[int, int]]:
    """
    Return valid min and max iterations that the task reported events of the required type

    :param company_id: company whose events are searched
    :param task: id of the task that reported the events
    :param metric: limit the search to this metric; None means all the task metrics
    :param max_metrics: the maximal amount of metric buckets requested from the aggregation
    :return: mapping from the metric name to its (min iteration, max iteration)
    """
    must = [
        {"term": {"task": task}},
    ]
    if metric is not None:
        must.append({"term": {"metric": metric}})
    query = {"bool": {"must": must}}

    es_req: dict = {
        # only the aggregation results are needed, no hits
        "size": 0,
        "query": query,
        "aggs": {
            "metrics": {
                "terms": {
                    "field": "metric",
                    "size": max_metrics,
                    "order": {"_key": "asc"},
                },
                "aggs": {
                    "last_iter": {"max": {"field": "iter"}},
                    "first_iter": {"min": {"field": "iter"}},
                },
            }
        },
    }

    es_res = search_company_events(
        body=es_req,
        es=self.es,
        company_id=company_id,
        event_type=self.event_type,
    )

    return {
        metric_bucket["key"]: (
            int(metric_bucket["first_iter"]["value"]),
            int(metric_bucket["last_iter"]["value"]),
        )
        for metric_bucket in nested_get(
            es_res, ("aggregations", "metrics", "buckets")
        )
    }
@staticmethod
def _fill_res_and_update_state(
    events: Sequence[dict], res: MetricSamplesResult, state: PlotsSampleState
):
    """
    Store the found events in the result and move the state position
    to the metric and the iteration of the first event

    Every event is passed through uncompress_plot before it is returned
    The iteration range of the result is taken from the matching metric state, if there is one
    """
    for plot_event in events:
        uncompress_plot(plot_event)

    state.metric = events[0]["metric"]
    state.iteration = events[0]["iter"]
    res.events = events

    current = first(ms for ms in state.metric_states if ms.name == state.metric)
    if not current:
        return
    res.min_iteration = current.min_iteration
    res.max_iteration = current.max_iteration
def _get_metric_events_for_condition(
    self, company_id: str, task: str, order: str, must_conditions: Sequence
) -> Sequence:
    """
    Return the events (up to 100, sorted by variant) of the first metric
    of the first iteration that match the passed conditions
    The "first" iteration and metric are chosen according to the passed order
    An empty list is returned if nothing matches
    """
    events_agg = {
        "events": {
            "top_hits": {"sort": {"variant": {"order": "asc"}}, "size": 100}
        }
    }
    metrics_agg = {
        "metrics": {
            "terms": {"field": "metric", "size": 1, "order": {"_key": order}},
            "aggs": events_agg,
        }
    }
    es_req = {
        "size": 0,
        "query": {"bool": {"must": must_conditions}},
        "aggs": {
            "iters": {
                "terms": {"field": "iter", "size": 1, "order": {"_key": order}},
                "aggs": metrics_agg,
            }
        },
    }
    es_res = search_company_events(
        self.es,
        company_id=company_id,
        event_type=self.event_type,
        body=es_req,
    )

    aggregations = es_res.get("aggregations")
    if not aggregations:
        return []

    # descend into the single iteration bucket and then into its single metric bucket
    iter_buckets = aggregations["iters"]["buckets"]
    if not iter_buckets:
        return []
    metric_buckets = iter_buckets[0]["metrics"]["buckets"]
    if not metric_buckets:
        return []

    found_hits = nested_get(metric_buckets[0], ("events", "hits", "hits"))
    return [found["_source"] for found in found_hits]

View File

@@ -1,442 +0,0 @@
import abc
import operator
from operator import attrgetter
from typing import Sequence, Tuple, Optional, Callable, Mapping
import attr
from boltons.iterutils import first, bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField, BoolField
from jsonmodels.models import Base
from redis import StrictRedis
from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
EventSettings,
EventType,
check_empty_data,
search_company_events,
get_max_metric_and_variant_counts,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
# Valid iteration range of a single metric variant, kept inside HistorySampleState.variant_states
class VariantState(Base):
# variant name
name: str = StringField(required=True)
# metric that the variant belongs to; None in old states
# (HistorySampleIterator.get_sample_for_variant fills it in on validation)
metric: str = StringField(default=None)
# lowest and highest valid iterations, as calculated by
# HistorySampleIterator._get_metric_variant_iterations
min_iteration: int = IntField()
max_iteration: int = IntField()
# Scroll state of the history samples navigation; HistorySampleIterator loads and stores it
# through its RedisCacheManager
class HistorySampleState(Base, JsonSerializableMixin):
# state id; returned to the caller as the scroll id
id: str = StringField(required=True)
# iteration and variant of the event that was returned last
iteration: int = IntField()
variant: str = StringField()
# task that the state was created for
task: str = StringField()
# metric that the state was created for; updated to the metric of the last returned event
metric: str = StringField()
# NOTE(review): reached_first/reached_last are not read or written by the iterator code in this file
reached_first: bool = BoolField()
reached_last: bool = BoolField()
# valid iteration ranges per metric variant, rebuilt by HistorySampleIterator._reset_variant_states
variant_states: Sequence[VariantState] = ListField([VariantState])
# NOTE(review): not read or written by the iterator code in this file
warning: str = StringField()
# when set the navigation is limited to the variants of the state metric
navigate_current_metric = BoolField(default=True)
# Result of a history sample navigation call
@attr.s(auto_attribs=True)
class HistorySampleResult(object):
# id of the scroll state that should be passed to the following calls
scroll_id: str = None
# the found event or None if no event was found
event: dict = None
# iteration range of the variant state that matches the returned event (or the requested variant)
min_iteration: int = None
max_iteration: int = None
class HistorySampleIterator(abc.ABC):
def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
    """
    Keep the Elasticsearch client and the event type that the iterator serves
    and create the cache manager that persists HistorySampleState objects in redis
    """
    state_cache = RedisCacheManager(
        state_class=HistorySampleState,
        redis=redis,
        expiration_interval=EventSettings.state_expiration_sec,
    )
    self.es, self.event_type = es, event_type
    self.cache_manager = state_cache
def get_next_sample(
    self, company_id: str, task: str, state_id: str, navigate_earlier: bool
) -> HistorySampleResult:
    """
    Return the sample of the next/previous variant on the current iteration
    If there is no such sample then the sample of the first/last variant
    from the next/previous iteration is returned

    Raise InvalidScrollId if the state is not found or belongs to another task
    If nothing is found the result contains only the scroll id
    """
    result = HistorySampleResult(scroll_id=state_id)
    state = self.cache_manager.get_state(state_id)
    if not state or state.task != task:
        raise errors.bad_request.InvalidScrollId(scroll_id=state_id)

    if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
        return result

    navigation = dict(
        company_id=company_id, navigate_earlier=navigate_earlier, state=state
    )
    event = self._get_next_for_current_iteration(**navigation)
    if not event:
        # nothing left on the current iteration, try the neighbouring one
        event = self._get_next_for_another_iteration(**navigation)

    if event:
        self._fill_res_and_update_state(event=event, res=result, state=state)
        self.cache_manager.set_state(state=state)
    return result
def _fill_res_and_update_state(
    self, event: dict, res: HistorySampleResult, state: HistorySampleState
):
    """
    Store the found event in the result and move the state position
    to the variant, the metric and the iteration of the event

    The event is passed through _process_event first
    The iteration range of the result is taken from the matching variant state, if there is one
    """
    self._process_event(event)
    state.variant = event["variant"]
    state.metric = event["metric"]
    state.iteration = event["iter"]
    res.event = event

    current = first(
        candidate
        for candidate in state.variant_states
        if candidate.name == state.variant and candidate.metric == state.metric
    )
    if not current:
        return
    res.min_iteration = current.min_iteration
    res.max_iteration = current.max_iteration
@abc.abstractmethod
def _get_extra_conditions(self) -> Sequence[dict]:
    """
    Return the additional query conditions of the concrete iterator
    The returned items are added to the "must" conditions of the events queries
    """
@abc.abstractmethod
def _process_event(self, event: dict) -> dict:
    """
    Process a found event before it is placed into the result
    The base class calls it for the event and ignores the returned value
    """
@abc.abstractmethod
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
    """
    Return the query condition that selects the events of the passed variants
    The base class combines it with a term condition on the variants metric
    """
def _get_metric_variants_condition(self, variants: Sequence[VariantState]) -> dict:
    """
    Return the query condition that matches the events of any of the passed variants
    The variants are grouped by their metric and every group gets its own
    "metric and variants" clause; the clauses are combined with "should"
    """
    variants_by_metric = bucketize(variants, key=attrgetter("metric"))

    per_metric_clauses = []
    for metric_name, metric_variants in variants_by_metric.items():
        metric_term = {"term": {"metric": metric_name}}
        variants_clause = self._get_variants_conditions(metric_variants)
        per_metric_clauses.append({"bool": {"must": [metric_term, variants_clause]}})

    return {"bool": {"should": per_metric_clauses}}
def _get_next_for_current_iteration(
    self, company_id: str, navigate_earlier: bool, state: HistorySampleState
) -> Optional[dict]:
    """
    Return the sample of the next variant (or of the previous one if navigate_earlier is set)
    on the iteration stored in the state. The variants are ordered by (metric, name)
    Only the variants whose min iteration does not exceed the state iteration are considered
    Return None if there is no such variant or no sample is found
    """
    current_key = (state.metric, state.variant)
    is_on_requested_side = operator.lt if navigate_earlier else operator.gt

    candidates = []
    for var_state in state.variant_states:
        if state.navigate_current_metric and var_state.metric != state.metric:
            continue
        if not is_on_requested_side((var_state.metric, var_state.name), current_key):
            continue
        if var_state.min_iteration <= state.iteration:
            candidates.append(var_state)
    if not candidates:
        return None

    order = "desc" if navigate_earlier else "asc"
    es_req = {
        "size": 1,
        "sort": [{"metric": order}, {"variant": order}],
        "query": {
            "bool": {
                "must": [
                    {"term": {"task": state.task}},
                    {"term": {"iter": state.iteration}},
                    self._get_metric_variants_condition(candidates),
                    *self._get_extra_conditions(),
                ]
            }
        },
    }
    with translate_errors_context(), TimingContext(
        "es", "get_next_for_current_iteration"
    ):
        es_res = search_company_events(
            self.es,
            company_id=company_id,
            event_type=self.event_type,
            body=es_req,
        )

    found_hits = nested_get(es_res, ("hits", "hits"))
    return found_hits[0]["_source"] if found_hits else None
def _get_next_for_another_iteration(
    self, company_id: str, navigate_earlier: bool, state: HistorySampleState
) -> Optional[dict]:
    """
    Return the sample of the first variant from the next iteration
    or, if navigate_earlier is set, of the last variant from the previous iteration
    When navigating earlier the variants whose min iteration is not lower
    than the state iteration are discarded
    Return None if no suitable variant or sample is found
    """
    candidates = state.variant_states
    if state.navigate_current_metric:
        candidates = [
            var_state for var_state in candidates if var_state.metric == state.metric
        ]

    if navigate_earlier:
        range_operator, order = "lt", "desc"
        candidates = [
            var_state
            for var_state in candidates
            if var_state.min_iteration < state.iteration
        ]
    else:
        range_operator, order = "gt", "asc"

    if not candidates:
        return None

    es_req = {
        "size": 1,
        "sort": [{"iter": order}, {"metric": order}, {"variant": order}],
        "query": {
            "bool": {
                "must": [
                    {"term": {"task": state.task}},
                    self._get_metric_variants_condition(candidates),
                    {"range": {"iter": {range_operator: state.iteration}}},
                    *self._get_extra_conditions(),
                ]
            }
        },
    }
    with translate_errors_context(), TimingContext(
        "es", "get_next_for_another_iteration"
    ):
        es_res = search_company_events(
            self.es,
            company_id=company_id,
            event_type=self.event_type,
            body=es_req,
        )

    found_hits = nested_get(es_res, ("hits", "hits"))
    return found_hits[0]["_source"] if found_hits else None
def get_sample_for_variant(
    self,
    company_id: str,
    task: str,
    metric: str,
    variant: str,
    iteration: Optional[int] = None,
    refresh: bool = False,
    state_id: str = None,
    navigate_current_metric: bool = True,
) -> HistorySampleResult:
    """
    Return the sample of the metric variant from the requested iteration
    or from the latest iteration before it
    If the iteration is not passed then the latest sample is returned
    Iterations below the min iteration of the variant are never returned

    The scroll state is created or, when state_id points to an existing one,
    validated against the passed parameters and refreshed on request
    Raise InvalidScrollId if the stored state does not match the passed parameters
    """
    result = HistorySampleResult()
    if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
        return result

    def init_state(state_: HistorySampleState):
        # a new state starts from the requested task and metric
        state_.task, state_.metric = task, metric
        state_.navigate_current_metric = navigate_current_metric
        self._reset_variant_states(company_id=company_id, state=state_)

    def validate_state(state_: HistorySampleState):
        mismatch = (
            state_.task != task
            or state_.navigate_current_metric != navigate_current_metric
            or (state_.navigate_current_metric and state_.metric != metric)
        )
        if mismatch:
            raise errors.bad_request.InvalidScrollId(
                "Task and metric stored in the state do not match the passed ones",
                scroll_id=state_.id,
            )
        # variant states stored without a metric get the requested one
        for stored in state_.variant_states:
            if stored.metric is None:
                stored.metric = metric
        if refresh:
            self._reset_variant_states(company_id=company_id, state=state_)

    state: HistorySampleState
    with self.cache_manager.get_or_create_state(
        state_id=state_id, init_state=init_state, validate_state=validate_state,
    ) as state:
        result.scroll_id = state.id

        var_state = first(
            candidate
            for candidate in state.variant_states
            if candidate.name == variant and candidate.metric == metric
        )
        if not var_state:
            # the variant has no known iteration range
            return result

        result.min_iteration = var_state.min_iteration
        result.max_iteration = var_state.max_iteration

        if iteration is not None:
            iter_range = {"lte": iteration, "gte": var_state.min_iteration}
        else:
            iter_range = {"gte": var_state.min_iteration}

        es_req = {
            "size": 1,
            "sort": {"iter": "desc"},
            "query": {
                "bool": {
                    "must": [
                        {"term": {"task": task}},
                        {"term": {"metric": metric}},
                        {"term": {"variant": variant}},
                        *self._get_extra_conditions(),
                        {"range": {"iter": iter_range}},
                    ]
                }
            },
        }
        with translate_errors_context(), TimingContext(
            "es", "get_history_sample_for_variant"
        ):
            es_res = search_company_events(
                self.es,
                company_id=company_id,
                event_type=self.event_type,
                body=es_req,
            )

        found_hits = nested_get(es_res, ("hits", "hits"))
        if found_hits:
            self._fill_res_and_update_state(
                event=found_hits[0]["_source"], res=result, state=state
            )
        return result
def _reset_variant_states(self, company_id: str, state: HistorySampleState):
    """
    Recalculate the valid iteration range of every metric variant
    and store the resulting variant states on the passed scroll state

    Only the state metric is queried when the state navigates inside
    the current metric, otherwise no metric filter is applied
    """
    metric_filter = state.metric if state.navigate_current_metric else None
    ranges_by_metric = self._get_metric_variant_iterations(
        company_id=company_id, task=state.task, metric=metric_filter,
    )

    rebuilt_states = []
    for metric_name, variant_ranges in ranges_by_metric.items():
        for variant_name, first_iter, last_iter in variant_ranges:
            rebuilt_states.append(
                VariantState(
                    metric=metric_name,
                    name=variant_name,
                    min_iteration=first_iter,
                    max_iteration=last_iter,
                )
            )
    state.variant_states = rebuilt_states
@abc.abstractmethod
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
    """
    Return a tuple of two items:
    - the aggregations that are calculated inside every variant bucket
    - the function that receives a variant bucket and returns its (min_iter, max_iter)
    """
def _get_metric_variant_iterations(
    self, company_id: str, task: str, metric: Optional[str],
) -> Mapping[str, Sequence[Tuple[str, int, int]]]:
    """
    Return valid min and max iterations that the task reported events of the required type

    :param company_id: company whose events are searched
    :param task: id of the task that reported the events
    :param metric: limit the search to this metric; None means all the task metrics
    :return: mapping from the metric name to the list of
        (variant name, min iteration, max iteration) tuples of its variants
    """
    must = [
        {"term": {"task": task}},
        *self._get_extra_conditions(),
    ]
    if metric is not None:
        must.append({"term": {"metric": metric}})
    query = {"bool": {"must": must}}

    search_args = dict(
        es=self.es, company_id=company_id, event_type=self.event_type
    )
    max_metrics, max_variants = get_max_metric_and_variant_counts(
        query=query, **search_args
    )
    max_variants = int(max_variants // 2)

    # the concrete iterator decides how the min/max iterations are aggregated
    # and how they are read back from a variant bucket
    min_max_aggs, get_min_max_data = self._get_min_max_aggs()
    es_req: dict = {
        # only the aggregation results are needed, no hits
        "size": 0,
        "query": query,
        "aggs": {
            "metrics": {
                "terms": {
                    "field": "metric",
                    "size": max_metrics,
                    "order": {"_key": "asc"},
                },
                "aggs": {
                    "variants": {
                        "terms": {
                            "field": "variant",
                            "size": max_variants,
                            "order": {"_key": "asc"},
                        },
                        "aggs": min_max_aggs,
                    }
                },
            }
        },
    }

    with translate_errors_context(), TimingContext(
        "es", "get_history_sample_iterations"
    ):
        es_res = search_company_events(body=es_req, **search_args)

    def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
        variant = variant_bucket["key"]
        min_iter, max_iter = get_min_max_data(variant_bucket)
        return variant, min_iter, max_iter

    return {
        metric_bucket["key"]: [
            get_variant_data(variant_bucket)
            for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
        ]
        for metric_bucket in nested_get(
            es_res, ("aggregations", "metrics", "buckets")
        )
    }

View File

@@ -19,14 +19,14 @@ from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
get_metric_variants_condition, get_max_metric_and_variant_counts,
get_metric_variants_condition,
get_max_metric_and_variant_counts,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
class VariantState(Base):
@@ -75,18 +75,25 @@ class MetricEventsIterator:
def get_task_events(
self,
company_id: str,
companies: Mapping[str, str],
task_metrics: Mapping[str, dict],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
) -> MetricEventsResult:
if check_empty_data(self.es, company_id, self.event_type):
companies = {
task_id: company_id
for task_id, company_id in companies.items()
if not check_empty_data(
self.es, company_id=company_id, event_type=EventType.metrics_scalar
)
}
if not companies:
return MetricEventsResult()
def init_state(state_: MetricEventsScrollState):
state_.tasks = self._init_task_states(company_id, task_metrics)
state_.tasks = self._init_task_states(companies, task_metrics)
def validate_state(state_: MetricEventsScrollState):
"""
@@ -95,7 +102,7 @@ class MetricEventsIterator:
Refresh the state if requested
"""
if refresh:
self._reinit_outdated_task_states(company_id, state_, task_metrics)
self._reinit_outdated_task_states(companies, state_, task_metrics)
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
@@ -112,7 +119,7 @@ class MetricEventsIterator:
pool.map(
partial(
self._get_task_metric_events,
company_id=company_id,
companies=companies,
iter_count=iter_count,
navigate_earlier=navigate_earlier,
specific_variants_requested=specific_variants_requested,
@@ -125,7 +132,7 @@ class MetricEventsIterator:
def _reinit_outdated_task_states(
self,
company_id,
companies: Mapping[str, str],
state: MetricEventsScrollState,
task_metrics: Mapping[str, dict],
):
@@ -133,9 +140,7 @@ class MetricEventsIterator:
Determine the metrics for which new event_type events were added
since their states were initialized and re-init these states
"""
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
"id", "metric_stats"
)
tasks = Task.objects(id__in=list(task_metrics)).only("id", "metric_stats")
def get_last_update_times_for_task_metrics(
task: Task,
@@ -175,7 +180,7 @@ class MetricEventsIterator:
if metrics_to_recalc:
task_metrics_to_recalc[task] = metrics_to_recalc
updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
updated_task_states = self._init_task_states(companies, task_metrics_to_recalc)
def merge_with_updated_task_states(
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
@@ -205,14 +210,14 @@ class MetricEventsIterator:
]
def _init_task_states(
self, company_id: str, task_metrics: Mapping[str, dict]
self, companies: Mapping[str, str], task_metrics: Mapping[str, dict]
) -> Sequence[TaskScrollState]:
"""
Returned initialized metric scroll stated for the requested task metrics
"""
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
task_metric_states = pool.map(
partial(self._init_metric_states_for_task, company_id=company_id),
partial(self._init_metric_states_for_task, companies=companies),
task_metrics.items(),
)
@@ -226,17 +231,20 @@ class MetricEventsIterator:
pass
@abc.abstractmethod
def _get_variant_state_aggs(self) -> Tuple[dict, Callable[[dict, VariantState], None]]:
def _get_variant_state_aggs(
self,
) -> Tuple[dict, Callable[[dict, VariantState], None]]:
pass
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, dict], company_id: str
self, task_metrics: Tuple[str, dict], companies: Mapping[str, str]
) -> Sequence[MetricState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any event_type events
"""
task, metrics = task_metrics
company_id = companies[task]
must = [{"term": {"task": task}}, *self._get_extra_conditions()]
if metrics:
must.append(get_metric_variants_condition(metrics))
@@ -268,14 +276,18 @@ class MetricEventsIterator:
"size": max_variants,
"order": {"_key": "asc"},
},
**({"aggs": variant_state_aggs} if variant_state_aggs else {}),
**(
{"aggs": variant_state_aggs}
if variant_state_aggs
else {}
),
},
},
}
},
}
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
if "aggregations" not in es_res:
return []
@@ -313,7 +325,7 @@ class MetricEventsIterator:
def _get_task_metric_events(
self,
task_state: TaskScrollState,
company_id: str,
companies: Mapping[str, str],
iter_count: int,
navigate_earlier: bool,
specific_variants_requested: bool,
@@ -383,10 +395,10 @@ class MetricEventsIterator:
}
},
}
with translate_errors_context(), TimingContext("es", "_get_task_metric_events"):
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
company_id=companies[task_state.task],
event_type=self.event_type,
body=es_req,
)

View File

@@ -1,5 +1,7 @@
from datetime import datetime
from typing import Callable, Tuple, Sequence, Dict
from typing import Callable, Tuple, Sequence, Dict, Optional
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse
@@ -24,13 +26,41 @@ class ModelBLL:
raise errors.bad_request.InvalidModelId(**query)
return model
@staticmethod
def assert_exists(
    company_id,
    model_ids,
    only=None,
    allow_public=False,
    return_models=True,
) -> Optional[Sequence[Model]]:
    """
    Verify that all the passed model ids are found by Model.get_many for the company

    :param company_id: company passed to the models query
    :param model_ids: a single model id or a sequence of model ids
    :param only: if passed then only these fields are loaded
    :param allow_public: passed as is to Model.get_many
    :param return_models: if set then the found models are returned
    :raise InvalidModelId: if the amount of the found models differs
        from the amount of the unique requested ids
    :return: the list of the found models or None if return_models is not set
    """
    if isinstance(model_ids, str):
        model_ids = [model_ids]
    unique_ids = set(model_ids)

    found = Model.get_many(
        company=company_id,
        query=Q(id__in=unique_ids),
        allow_public=allow_public,
        return_dicts=False,
    )
    if only:
        found = found.only(*only)

    if found.count() != len(unique_ids):
        raise errors.bad_request.InvalidModelId(ids=model_ids)

    return list(found) if return_models else None
@classmethod
def publish_model(
cls,
model_id: str,
company_id: str,
user_id: str,
force_publish_task: bool = False,
publish_task_func: Callable[[str, str, bool], dict] = None,
publish_task_func: Callable[[str, str, str, bool], dict] = None,
) -> Tuple[int, ModelTaskPublishResponse]:
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
if model.ready:
@@ -45,7 +75,7 @@ class ModelBLL:
)
if task and task.status != TaskStatus.published:
task_publish_res = publish_task_func(
model.task, company_id, force_publish_task
model.task, company_id, user_id, force_publish_task
)
published_task = ModelTaskPublishResponse(
id=model.task, data=task_publish_res

View File

@@ -11,7 +11,6 @@ from apiserver.utilities.parameter_key_escaper import (
mongoengine_safe,
)
from apiserver.config_repo import config
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
@@ -42,27 +41,25 @@ class Metadata:
replace_metadata: bool,
**more_updates,
) -> int:
with TimingContext("mongo", "edit_metadata"):
update_cmds = dict()
metadata = cls.metadata_from_api(items)
if replace_metadata:
update_cmds["set__metadata"] = metadata
else:
for key, value in metadata.items():
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
update_cmds = dict()
metadata = cls.metadata_from_api(items)
if replace_metadata:
update_cmds["set__metadata"] = metadata
else:
for key, value in metadata.items():
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
return obj.update(**update_cmds, **more_updates)
return obj.update(**update_cmds, **more_updates)
@classmethod
def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
with TimingContext("mongo", "delete_metadata"):
return obj.update(
**{
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
for key in set(keys)
},
**more_updates,
)
return obj.update(
**{
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
for key in set(keys)
},
**more_updates,
)
@staticmethod
def _process_path(path: str):

View File

@@ -1,8 +1,7 @@
import itertools
from collections import defaultdict
from datetime import datetime, timedelta
from functools import reduce
from itertools import groupby
from itertools import groupby, chain
from operator import itemgetter
from typing import (
Sequence,
@@ -22,6 +21,7 @@ from mongoengine import Q, Document
from apiserver import database
from apiserver.apierrors import errors
from apiserver.apimodels.projects import ProjectChildrenType
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility, AttributedDocument
from apiserver.database.model.base import GetMixin
@@ -29,7 +29,6 @@ from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
from apiserver.database.utils import get_options, get_company_or_none_constraint
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
from .sub_projects import (
_reposition_project_with_children,
@@ -41,13 +40,22 @@ from .sub_projects import (
_ids_with_children,
_ids_with_parents,
_get_project_depth,
ProjectsChildren,
)
log = config.logger(__file__)
max_depth = config.get("services.projects.sub_projects.max_depth", 10)
reports_project_name = ".reports"
datasets_project_name = ".datasets"
pipelines_project_name = ".pipelines"
reports_tag = "reports"
dataset_tag = "dataset"
pipeline_tag = "pipeline"
class ProjectBLL:
child_classes = (Task, Model)
@classmethod
def merge_project(
cls, company, source_id: str, destination_id: str
@@ -57,54 +65,53 @@ class ProjectBLL:
Remove the source project
Return the amounts of moved entities and subprojects + set of all the affected project ids
"""
with TimingContext("mongo", "move_project"):
if source_id == destination_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
source=source_id
if source_id == destination_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
source=source_id
)
source = Project.get(company, source_id)
if destination_id:
destination = Project.get(company, destination_id)
if source_id in destination.path:
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
source=source_id, destination=destination_id
)
source = Project.get(company, source_id)
if destination_id:
destination = Project.get(company, destination_id)
if source_id in destination.path:
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
source=source_id, destination=destination_id
)
else:
destination = None
else:
destination = None
children = _get_sub_projects(
[source.id], _only=("id", "name", "parent", "path")
)[source.id]
if destination:
cls.validate_projects_depth(
projects=children,
old_parent_depth=len(source.path) + 1,
new_parent_depth=len(destination.path) + 1,
)
children = _get_sub_projects(
[source.id], _only=("id", "name", "parent", "path")
)[source.id]
if destination:
cls.validate_projects_depth(
projects=children,
old_parent_depth=len(source.path) + 1,
new_parent_depth=len(destination.path) + 1,
)
moved_entities = 0
for entity_type in (Task, Model):
moved_entities += entity_type.objects(
company=company,
project=source_id,
system_tags__nin=[EntityVisibility.archived.value],
).update(upsert=False, project=destination_id)
moved_entities = 0
for entity_type in cls.child_classes:
moved_entities += entity_type.objects(
company=company,
project=source_id,
system_tags__nin=[EntityVisibility.archived.value],
).update(upsert=False, project=destination_id)
moved_sub_projects = 0
for child in Project.objects(company=company, parent=source_id):
_reposition_project_with_children(
project=child,
children=[c for c in children if c.parent == child.id],
parent=destination,
)
moved_sub_projects += 1
moved_sub_projects = 0
for child in Project.objects(company=company, parent=source_id):
_reposition_project_with_children(
project=child,
children=[c for c in children if c.parent == child.id],
parent=destination,
)
moved_sub_projects += 1
affected = {source.id, *(source.path or [])}
source.delete()
affected = {source.id, *(source.path or [])}
source.delete()
if destination:
destination.update(last_update=datetime.utcnow())
affected.update({destination.id, *(destination.path or [])})
if destination:
destination.update(last_update=datetime.utcnow())
affected.update({destination.id, *(destination.path or [])})
return moved_entities, moved_sub_projects, affected
@@ -127,78 +134,76 @@ class ProjectBLL:
it should be writable. The source location should be writable too.
Return the number of moved projects + set of all the affected project ids
"""
with TimingContext("mongo", "move_project"):
project = Project.get(company, project_id)
old_parent_id = project.parent
old_parent = (
Project.get_for_writing(company=project.company, id=old_parent_id)
if old_parent_id
else None
project = Project.get(company, project_id)
old_parent_id = project.parent
old_parent = (
Project.get_for_writing(company=project.company, id=old_parent_id)
if old_parent_id
else None
)
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
project.id
]
cls.validate_projects_depth(
projects=[project, *children],
old_parent_depth=len(project.path),
new_parent_depth=_get_project_depth(new_location),
)
new_parent = _ensure_project(company=company, user=user, name=new_location)
new_parent_id = new_parent.id if new_parent else None
if old_parent_id == new_parent_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
location=new_parent.name if new_parent else ""
)
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
project.id
]
cls.validate_projects_depth(
projects=[project, *children],
old_parent_depth=len(project.path),
new_parent_depth=_get_project_depth(new_location),
if new_parent and (
project_id == new_parent.id or project_id in new_parent.path
):
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
project=project_id, parent=new_parent.id
)
moved = _reposition_project_with_children(
project, children=children, parent=new_parent
)
new_parent = _ensure_project(company=company, user=user, name=new_location)
new_parent_id = new_parent.id if new_parent else None
if old_parent_id == new_parent_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
location=new_parent.name if new_parent else ""
)
if new_parent and (
project_id == new_parent.id or project_id in new_parent.path
):
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
project=project_id, parent=new_parent.id
)
moved = _reposition_project_with_children(
project, children=children, parent=new_parent
)
now = datetime.utcnow()
affected = set()
for p in filter(None, (old_parent, new_parent)):
p.update(last_update=now)
affected.update({p.id, *(p.path or [])})
now = datetime.utcnow()
affected = set()
for p in filter(None, (old_parent, new_parent)):
p.update(last_update=now)
affected.update({p.id, *(p.path or [])})
return moved, affected
return moved, affected
@classmethod
def update(cls, company: str, project_id: str, **fields):
with TimingContext("mongo", "projects_update"):
project = Project.get_for_writing(company=company, id=project_id)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
project = Project.get_for_writing(company=company, id=project_id)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
new_name = fields.pop("name", None)
if new_name:
new_name, new_location = _validate_project_name(new_name)
old_name, old_location = _validate_project_name(project.name)
if new_location != old_location:
raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
fields["name"] = new_name
fields["basename"] = new_name.split("/")[-1]
new_name = fields.pop("name", None)
if new_name:
new_name, new_location = _validate_project_name(new_name)
old_name, old_location = _validate_project_name(project.name)
if new_location != old_location:
raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
fields["name"] = new_name
fields["basename"] = new_name.split("/")[-1]
fields["last_update"] = datetime.utcnow()
updated = project.update(upsert=False, **fields)
fields["last_update"] = datetime.utcnow()
updated = project.update(upsert=False, **fields)
if new_name:
old_name = project.name
project.name = new_name
children = _get_sub_projects(
[project.id], _only=("id", "name", "path")
)[project.id]
_update_subproject_names(
project=project, children=children, old_name=old_name
)
if new_name:
old_name = project.name
project.name = new_name
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
project.id
]
_update_subproject_names(
project=project, children=children, old_name=old_name
)
return updated
return updated
@classmethod
def create(
@@ -301,7 +306,7 @@ class ProjectBLL:
"""
Move a batch of entities to `project` or a project named `project_name` (create if does not exist)
"""
with TimingContext("mongo", "move_under_project"):
if project_name or project:
project = cls.find_or_create(
user=user,
company=company,
@@ -309,16 +314,17 @@ class ProjectBLL:
project_name=project_name,
description="",
)
extra = (
{"set__last_change": datetime.utcnow()}
if hasattr(entity_cls, "last_change")
else {}
)
entity_cls.objects(company=company, id__in=ids).update(
set__project=project, **extra
)
return project
extra = (
{"set__last_change": datetime.utcnow()}
if hasattr(entity_cls, "last_change")
else {}
)
entity_cls.objects(company=company, id__in=ids).update(
set__project=project, **extra
)
return project
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
visibility_states = [EntityVisibility.archived, EntityVisibility.active]
@@ -399,6 +405,18 @@ class ProjectBLL:
"$completed",
{"$gt": ["$completed", time_thresh]},
additional_cond,
{
"$not": {
"$in": [
"$status",
[
TaskStatus.queued,
TaskStatus.in_progress,
TaskStatus.failed,
],
]
}
},
]
},
"then": 1,
@@ -512,7 +530,7 @@ class ProjectBLL:
def aggregate_project_data(
func: Callable[[T, T], T],
project_ids: Sequence[str],
child_projects: Mapping[str, Sequence[Project]],
child_projects: ProjectsChildren,
data: Mapping[str, T],
) -> Dict[str, T]:
"""
@@ -564,6 +582,136 @@ class ProjectBLL:
for r in Task.aggregate(task_runtime_pipeline)
}
@staticmethod
def _get_projects_children(
    project_ids: Sequence[str], search_hidden: bool, allowed_ids: Sequence[str],
) -> Tuple[ProjectsChildren, Set[str]]:
    """
    Look up the sub-projects of the passed projects.

    Returns a pair:
    - the mapping returned by _get_sub_projects for project_ids (only the "id" and
      "name" fields are requested; search_hidden and allowed_ids are passed through)
    - the set of ids of all the child projects found in that mapping
    """
    children_by_parent = _get_sub_projects(
        project_ids,
        _only=("id", "name"),
        search_hidden=search_hidden,
        allowed_ids=allowed_ids,
    )

    children_ids = set()
    for children in children_by_parent.values():
        children_ids.update(child.id for child in children)

    return children_by_parent, children_ids
@staticmethod
def _get_children_info(
    project_ids: Sequence[str], child_projects: ProjectsChildren
) -> dict:
    """
    Build the children description for each of the requested projects.

    For every id in project_ids the result holds a list of {"id", "name"} dicts,
    one per child found for that id in child_projects, sorted by name.
    Projects that have no entry in child_projects get an empty list.
    """
    info = {}
    for project in project_ids:
        entries = [
            {"id": child.id, "name": child.name}
            for child in child_projects.get(project, [])
        ]
        entries.sort(key=itemgetter("name"))
        info[project] = entries

    return info
@classmethod
def _get_project_dataset_stats_core(
    cls,
    company: str,
    project_ids: Sequence[str],
    project_field: str,
    entity_class: Type[AttributedDocument],
    include_children: bool = True,
    filter_: Mapping[str, Any] = None,
    users: Sequence[str] = None,
    selected_project_ids: Sequence[str] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
    """
    Count the entities of entity_class per project and collect their tags.

    :param company: company id passed to get_match_conditions
    :param project_ids: the projects for which the stats are returned
    :param project_field: name of the entity field that references the project.
        It is used both for matching and as the grouping key
    :param entity_class: the document class on which the aggregation is run
    :param include_children: if set then the sub-projects are looked up and
        their entities are matched too
    :param filter_: passed as is to get_match_conditions
    :param users: passed as is to get_match_conditions
    :param selected_project_ids: passed as allowed_ids to the sub-projects lookup
    :return: a pair of dictionaries keyed by project id:
        - {"datasets": {"count": <int>, "tags": <sorted list of tags>}}
        - the children info as built by _get_children_info
        If project_ids is empty then two empty dictionaries are returned
    """
    if not project_ids:
        return {}, {}

    child_projects = {}
    project_ids_with_children = set(project_ids)
    if include_children:
        child_projects, children_ids = cls._get_projects_children(
            project_ids, search_hidden=True, allowed_ids=selected_project_ids,
        )
        project_ids_with_children |= children_ids

    pipeline = [
        {
            "$match": cls.get_match_conditions(
                company=company,
                project_ids=list(project_ids_with_children),
                filter_=filter_,
                users=users,
                project_field=project_field,
            )
        },
        {"$project": {project_field: 1, "tags": 1}},
        {
            "$group": {
                "_id": f"${project_field}",
                "count": {"$sum": 1},
                "tags": {"$push": "$tags"},
            }
        },
    ]
    project_stats = {
        result["_id"]: {
            "count": result.get("count", 0),
            # each pushed element is the tags list of one document; empty or None
            # entries contribute nothing so they are skipped instead of failing
            # the flattening (NOTE(review): confirm whether null tags can occur)
            "tags": set(
                chain.from_iterable(
                    tags for tags in result.get("tags", []) if tags
                )
            ),
        }
        for result in entity_class.aggregate(pipeline)
    }

    def concat_dataset_stats(a: dict, b: dict) -> dict:
        # the defaults must be sets: "set | dict" raises TypeError
        return {
            "count": a.get("count", 0) + b.get("count", 0),
            "tags": a.get("tags", set()) | b.get("tags", set()),
        }

    top_project_stats = cls.aggregate_project_data(
        func=concat_dataset_stats,
        project_ids=project_ids,
        child_projects=child_projects,
        data=project_stats,
    )
    for stat in top_project_stats.values():
        stat["tags"] = sorted(stat.get("tags", ()))

    stats = {
        project: {
            # a fresh empty record per project so that the returned
            # dictionaries do not share one mutable object
            "datasets": top_project_stats[project]
            if project in top_project_stats
            else {"count": 0, "tags": []}
        }
        for project in project_ids
    }
    return stats, cls._get_children_info(project_ids, child_projects)
@classmethod
def get_project_dataset_stats(
    cls,
    company: str,
    project_ids: Sequence[str],
    include_children: bool = True,
    filter_: Mapping[str, Any] = None,
    users: Sequence[str] = None,
    selected_project_ids: Sequence[str] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
    """
    Return the dataset stats and the children info for the requested projects.

    The passed filter is extended so that its "system_tags" list contains
    dataset_tag and the calculation is delegated to _get_project_dataset_stats_core
    with the Project class grouped by the "parent" field.
    The filter_ passed by the caller is not modified.
    """
    # Work on copies: the caller's mapping and its system_tags list
    # may be reused for other queries and should not get the dataset tag injected
    filter_ = dict(filter_ or {})
    filter_system_tags = filter_.get("system_tags")
    filter_system_tags = (
        list(filter_system_tags) if isinstance(filter_system_tags, list) else []
    )
    if dataset_tag not in filter_system_tags:
        filter_system_tags.append(dataset_tag)
    filter_["system_tags"] = filter_system_tags

    return cls._get_project_dataset_stats_core(
        company=company,
        project_ids=project_ids,
        project_field="parent",
        entity_class=Project,
        include_children=include_children,
        filter_=filter_,
        users=users,
        selected_project_ids=selected_project_ids,
    )
@classmethod
def get_project_stats(
cls,
@@ -574,24 +722,21 @@ class ProjectBLL:
search_hidden: bool = False,
filter_: Mapping[str, Any] = None,
users: Sequence[str] = None,
user_active_project_ids: Sequence[str] = None,
selected_project_ids: Sequence[str] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
if not project_ids:
return {}, {}
child_projects = (
_get_sub_projects(
child_projects = {}
project_ids_with_children = set(project_ids)
if include_children:
child_projects, children_ids = cls._get_projects_children(
project_ids,
_only=("id", "name"),
search_hidden=search_hidden,
allowed_ids=user_active_project_ids,
allowed_ids=selected_project_ids,
)
if include_children
else {}
)
project_ids_with_children = set(project_ids) | {
c.id for c in itertools.chain.from_iterable(child_projects.values())
}
project_ids_with_children |= children_ids
status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines(
company,
project_ids=list(project_ids_with_children),
@@ -695,14 +840,7 @@ class ProjectBLL:
for project in project_ids
}
children = {
project: sorted(
[{"id": c.id, "name": c.name} for c in child_projects.get(project, [])],
key=itemgetter("name"),
)
for project in project_ids
}
return stats, children
return stats, cls._get_children_info(project_ids, child_projects)
@classmethod
def get_active_users(
@@ -716,22 +854,21 @@ class ProjectBLL:
If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned
"""
with TimingContext("mongo", "active_users_in_projects"):
query = Q(company=company)
if user_ids:
query &= Q(user__in=user_ids)
query = Q(company=company)
if user_ids:
query &= Q(user__in=user_ids)
projects_query = query
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
projects_query &= Q(id__in=project_ids)
projects_query = query
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
projects_query &= Q(id__in=project_ids)
res = set(Project.objects(projects_query).distinct(field="user"))
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="user"))
res = set(Project.objects(projects_query).distinct(field="user"))
for cls_ in cls.child_classes:
res |= set(cls_.objects(query).distinct(field="user"))
return res
return res
@classmethod
def get_project_tags(
@@ -741,63 +878,97 @@ class ProjectBLL:
projects: Sequence[str] = None,
filter_: Dict[str, Sequence[str]] = None,
) -> Tuple[Sequence[str], Sequence[str]]:
with TimingContext("mongo", "get_tags_from_db"):
query = Q(company=company_id)
if filter_:
for name, vals in filter_.items():
if vals:
query &= GetMixin.get_list_field_query(name, vals)
query = Q(company=company_id)
if filter_:
for name, vals in filter_.items():
if vals:
query &= GetMixin.get_list_field_query(name, vals)
if projects:
query &= Q(id__in=_ids_with_children(projects))
if projects:
query &= Q(id__in=_ids_with_children(projects))
tags = Project.objects(query).distinct("tags")
system_tags = (
Project.objects(query).distinct("system_tags") if include_system else []
)
return tags, system_tags
tags = Project.objects(query).distinct("tags")
system_tags = (
Project.objects(query).distinct("system_tags") if include_system else []
)
return tags, system_tags
@classmethod
def get_projects_with_active_user(
def get_projects_with_selected_children(
cls,
company: str,
users: Sequence[str],
users: Sequence[str] = None,
project_ids: Optional[Sequence[str]] = None,
allow_public: bool = True,
children_type: ProjectChildrenType = None,
) -> Tuple[Sequence[str], Sequence[str]]:
"""
Get the projects ids where user created any tasks including all the parents of these projects
Get the projects ids matching children_condition (if passed) or where the passed user created any tasks
including all the parents of these projects
If project ids are specified then filter the results by these project ids
"""
query = Q(user__in=users)
if not (users or children_type):
raise errors.bad_request.ValidationError(
"Either active users or children_condition should be specified"
)
if allow_public:
query &= get_company_or_none_constraint(company)
query = (
get_company_or_none_constraint(company)
if allow_public
else Q(company=company)
)
if users:
query &= Q(user__in=users)
project_query = None
if children_type == ProjectChildrenType.dataset:
child_queries = {
Project: query
& Q(system_tags__in=[dataset_tag], basename__ne=datasets_project_name)
}
elif children_type == ProjectChildrenType.pipeline:
child_queries = {Task: query & Q(system_tags__in=[pipeline_tag])}
elif children_type == ProjectChildrenType.report:
child_queries = {Task: query & Q(system_tags__in=[reports_tag])}
else:
query &= Q(company=company)
project_query = query
child_queries = {entity_cls: query for entity_cls in cls.child_classes}
user_projects_query = query
if project_ids:
ids_with_children = _ids_with_children(project_ids)
query &= Q(project__in=ids_with_children)
user_projects_query &= Q(id__in=ids_with_children)
if project_query:
project_query &= Q(id__in=ids_with_children)
for child_cls in child_queries:
child_queries[child_cls] &= (
Q(parent__in=ids_with_children)
if child_cls is Project
else Q(project__in=ids_with_children)
)
res = {p.id for p in Project.objects(user_projects_query).only("id")}
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="project"))
res = (
{p.id for p in Project.objects(project_query).only("id")}
if project_query
else set()
)
for cls_, query_ in child_queries.items():
res |= set(
cls_.objects(query_).distinct(
field="parent" if cls_ is Project else "project"
)
)
res = list(res)
if not res:
return res, res
user_active_project_ids = _ids_with_parents(res)
selected_project_ids = _ids_with_parents(res)
filtered_ids = (
list(set(user_active_project_ids) & set(project_ids))
list(set(selected_project_ids) & set(project_ids))
if project_ids
else list(user_active_project_ids)
else list(selected_project_ids)
)
return filtered_ids, user_active_project_ids
return filtered_ids, selected_project_ids
@classmethod
def get_task_parents(
@@ -870,10 +1041,11 @@ class ProjectBLL:
project_ids: Sequence[str],
filter_: Mapping[str, Any],
users: Sequence[str],
project_field: str = "project",
):
conditions = {
"company": {"$in": [None, "", company]},
"project": {"$in": project_ids},
project_field: {"$in": project_ids},
}
if users:
conditions["user"] = {"$in": users}
@@ -898,6 +1070,69 @@ class ProjectBLL:
return conditions
@classmethod
def _calc_own_datasets_core(
    cls,
    company: str,
    project_ids: Sequence[str],
    project_field: str,
    entity_class: Type[AttributedDocument],
    filter_: Mapping[str, Any] = None,
    users: Sequence[str] = None,
) -> Dict[str, dict]:
    """
    Returns the amount of matching entity_class documents per requested project
    as {project id: {"own_datasets": count}}. Projects without matches get 0.
    An empty project_ids results in an empty dictionary
    """
    if not project_ids:
        return {}

    match_stage = {
        "$match": cls.get_match_conditions(
            company=company,
            project_ids=project_ids,
            filter_=filter_,
            users=users,
            project_field=project_field,
        )
    }
    projection_stage = {"$project": {project_field: 1}}
    group_stage = {"$group": {"_id": f"${project_field}", "count": {"$sum": 1}}}

    counts = {}
    for row in entity_class.aggregate([match_stage, projection_stage, group_stage]):
        counts[row["_id"]] = row["count"]

    result = {}
    for pid in project_ids:
        result[pid] = {"own_datasets": counts.get(pid, 0)}
    return result
@classmethod
def calc_own_datasets(
    cls,
    company: str,
    project_ids: Sequence[str],
    filter_: Mapping[str, Any] = None,
    users: Sequence[str] = None,
) -> Dict[str, dict]:
    """
    Returns the amount of datasets per requested project

    The passed filter is extended so that its "system_tags" list contains
    dataset_tag and the counting is delegated to _calc_own_datasets_core
    with the Project class grouped by the "parent" field.
    The filter_ passed by the caller is not modified.
    """
    # Work on copies: the caller's mapping and its system_tags list
    # may be reused for other queries and should not get the dataset tag injected
    filter_ = dict(filter_ or {})
    filter_system_tags = filter_.get("system_tags")
    filter_system_tags = (
        list(filter_system_tags) if isinstance(filter_system_tags, list) else []
    )
    if dataset_tag not in filter_system_tags:
        filter_system_tags.append(dataset_tag)
    filter_["system_tags"] = filter_system_tags

    return cls._calc_own_datasets_core(
        company=company,
        project_ids=project_ids,
        project_field="parent",
        entity_class=Project,
        filter_=filter_,
        users=users,
    )
@classmethod
def calc_own_contents(
cls,
@@ -930,10 +1165,9 @@ class ProjectBLL:
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
with TimingContext("mongo", "get_security_groups"):
tasks = get_agrregate_res(Task)
models = get_agrregate_res(Model)
return {
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
for pid in project_ids
}
tasks = get_agrregate_res(Task)
models = get_agrregate_res(Model)
return {
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
for pid in project_ids
}

View File

@@ -1,3 +1,4 @@
from collections import defaultdict
from typing import Tuple, Set, Sequence
import attr
@@ -8,17 +9,19 @@ from apiserver.bll.task.task_cleanup import (
collect_debug_image_urls,
collect_plot_image_urls,
TaskUrls,
_schedule_for_delete,
)
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType
from apiserver.timing_context import TimingContext
from .project_bll import ProjectBLL
from .sub_projects import _ids_with_children
log = config.logger(__file__)
event_bll = EventBLL()
async_events_delete = config.get("services.tasks.async_events_delete", False)
@attr.s(auto_attribs=True)
@@ -39,9 +42,9 @@ def validate_project_delete(company: str, project_id: str):
is_pipeline = "pipeline" in (project.system_tags or [])
project_ids = _ids_with_children([project_id])
ret = {}
for cls in (Task, Model):
for cls in ProjectBLL.child_classes:
ret[f"{cls.__name__.lower()}s"] = cls.objects(project__in=project_ids).count()
for cls in (Task, Model):
for cls in ProjectBLL.child_classes:
query = dict(
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
)
@@ -59,13 +62,22 @@ def validate_project_delete(company: str, project_id: str):
def delete_project(
company: str, project_id: str, force: bool, delete_contents: bool
company: str,
user: str,
project_id: str,
force: bool,
delete_contents: bool,
delete_external_artifacts=True,
) -> Tuple[DeleteProjectResult, Set[str]]:
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path", "system_tags")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
delete_external_artifacts = delete_external_artifacts and config.get(
"services.async_urls_delete.enabled", False
)
is_pipeline = "pipeline" in (project.system_tags or [])
project_ids = _ids_with_children([project_id])
if not force:
@@ -88,17 +100,28 @@ def delete_project(
)
if not delete_contents:
with TimingContext("mongo", "update_children"):
for cls in (Model, Task):
updated_count = cls.objects(project__in=project_ids).update(
project=None
)
res = DeleteProjectResult(disassociated_tasks=updated_count)
disassociated = defaultdict(int)
for cls in ProjectBLL.child_classes:
disassociated[cls] = cls.objects(project__in=project_ids).update(project=None)
res = DeleteProjectResult(disassociated_tasks=disassociated[Task])
else:
deleted_models, model_urls = _delete_models(projects=project_ids)
deleted_tasks, event_urls, artifact_urls = _delete_tasks(
deleted_models, model_event_urls, model_urls = _delete_models(
company=company, projects=project_ids
)
deleted_tasks, task_event_urls, artifact_urls = _delete_tasks(
company=company, projects=project_ids
)
event_urls = task_event_urls | model_event_urls
if delete_external_artifacts:
scheduled = _schedule_for_delete(
task_id=project_id,
company=company,
user=user,
urls=event_urls | model_urls | artifact_urls,
can_delete_folders=True,
)
for urls in (event_urls, model_urls, artifact_urls):
urls.difference_update(scheduled)
res = DeleteProjectResult(
deleted_tasks=deleted_tasks,
deleted_models=deleted_models,
@@ -127,9 +150,8 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
return 0, set(), set()
task_ids = {t.id for t in tasks}
with TimingContext("mongo", "delete_tasks_update_children"):
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
event_urls, artifact_urls = set(), set()
for task in tasks:
@@ -144,46 +166,58 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
}
)
event_bll.delete_multi_task_events(company, list(task_ids))
event_bll.delete_multi_task_events(
company, list(task_ids), async_delete=async_events_delete
)
deleted = tasks.delete()
return deleted, event_urls, artifact_urls
def _delete_models(projects: Sequence[str]) -> Tuple[int, Set[str]]:
def _delete_models(
company: str, projects: Sequence[str]
) -> Tuple[int, Set[str], Set[str]]:
"""
Delete project models and update the tasks from other projects
that reference them to reference None.
"""
with TimingContext("mongo", "delete_models"):
models = Model.objects(project__in=projects).only("task", "id", "uri")
if not models:
return 0, set()
models = Model.objects(project__in=projects).only("task", "id", "uri")
if not models:
return 0, set(), set()
model_ids = list({m.id for m in models})
model_ids = list({m.id for m in models})
Task._get_collection().update_many(
filter={
"project": {"$nin": projects},
"models.input.model": {"$in": model_ids},
},
update={"$set": {"models.input.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
model_tasks = list({m.task for m in models if m.task})
if model_tasks:
Task._get_collection().update_many(
filter={
"_id": {"$in": model_tasks},
"project": {"$nin": projects},
"models.input.model": {"$in": model_ids},
"models.output.model": {"$in": model_ids},
},
update={"$set": {"models.input.$[elem].model": None}},
update={"$set": {"models.output.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
model_tasks = list({m.task for m in models if m.task})
if model_tasks:
Task._get_collection().update_many(
filter={
"_id": {"$in": model_tasks},
"project": {"$nin": projects},
"models.output.model": {"$in": model_ids},
},
update={"$set": {"models.output.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
event_urls, model_urls = set(), set()
for m in models:
event_urls.update(collect_debug_image_urls(company, m.id))
event_urls.update(collect_plot_image_urls(company, m.id))
if m.uri:
model_urls.add(m.uri)
urls = {m.uri for m in models if m.uri}
deleted = models.delete()
return deleted, urls
event_bll.delete_multi_task_events(
company, model_ids, async_delete=async_events_delete
)
deleted = models.delete()
return deleted, event_urls, model_urls

View File

@@ -14,14 +14,16 @@ def _get_project_depth(project_name: str) -> int:
return len(list(filter(None, project_name.split(name_separator))))
def _validate_project_name(project_name: str, raise_if_empty=True) -> Tuple[str, str]:
    """
    Remove redundant '/' characters and the whitespace around each name part.
    Ensure that the project name is not empty
    Return the cleaned up project name and location
    If the name is empty then InvalidProjectName is raised unless raise_if_empty
    is False, in which case a pair of empty strings is returned
    """
    # Strip first and filter afterwards: a part that consists only of whitespace
    # should be dropped and not kept as an empty path element ("a/ /b" -> "a/b")
    name_parts = [
        part
        for part in (p.strip() for p in project_name.split(name_separator))
        if part
    ]
    if not name_parts:
        if raise_if_empty:
            raise errors.bad_request.InvalidProjectName(name=project_name)
        return "", ""

    return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
@@ -34,7 +36,7 @@ def _ensure_project(
If needed auto-create the project and all the missing projects in the path to it
Return the project
"""
name = name.strip(name_separator)
name, location = _validate_project_name(name, raise_if_empty=False)
if not name:
return None
@@ -43,7 +45,6 @@ def _ensure_project(
return project
now = datetime.utcnow()
name, location = _validate_project_name(name)
project = Project(
id=database.utils.id(),
user=user,
@@ -101,12 +102,15 @@ def _get_writable_project_from_name(
return qs.first()
ProjectsChildren = Mapping[str, Sequence[Project]]
def _get_sub_projects(
project_ids: Sequence[str],
_only: Sequence[str] = ("id", "path"),
search_hidden=True,
allowed_ids: Sequence[str] = None,
) -> Mapping[str, Sequence[Project]]:
) -> ProjectsChildren:
"""
Return the list of child projects of all the levels for the parent project ids
"""
@@ -156,13 +160,17 @@ def _update_subproject_names(
Optionally update the paths
"""
updated = 0
now = datetime.utcnow()
for child in children:
child_suffix = name_separator.join(
child.name.split(name_separator)[len(old_name.split(name_separator)) :]
child.name.split(name_separator)[len(old_name.split(name_separator)):]
)
updates = {"name": name_separator.join((project.name, child_suffix))}
updates = {
"name": name_separator.join((project.name, child_suffix)),
"last_update": now,
}
if update_path:
updates["path"] = project.path + child.path[len(old_path) :]
updates["path"] = project.path + child.path[len(old_path):]
updated += child.update(upsert=False, **updates)
return updated
@@ -177,6 +185,7 @@ def _reposition_project_with_children(
project.name = name_separator.join(
filter(None, (new_location, project.name.split(name_separator)[-1]))
)
project.last_update = datetime.utcnow()
_save_under_parent(project, parent=parent)
moved = 1 + _update_subproject_names(

View File

@@ -3,11 +3,13 @@ from datetime import datetime
from typing import Callable, Sequence, Optional, Tuple
from elasticsearch import Elasticsearch
from mongoengine import Q
from apiserver import database
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.bll.queue.queue_metrics import QueueMetrics, MetricsRefresher
from apiserver.bll.queue.queue_metrics import QueueMetrics
from apiserver.bll.workers import WorkerBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
@@ -131,7 +133,7 @@ class QueueBLL(object):
self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return Queue.safe_update(company_id, queue_id, update_fields)
def delete(self, company_id: str, queue_id: str, force: bool) -> None:
def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> None:
"""
Delete the queue
:raise errors.bad_request.InvalidQueueId: if the queue is not found
@@ -139,16 +141,42 @@ class QueueBLL(object):
"""
with translate_errors_context():
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
if queue.entries and not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
if queue.entries:
if not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
from apiserver.bll.task import ChangeStatusRequest
for item in queue.entries:
try:
task = Task.get_for_writing(
company=company_id,
id=item.task,
_only=["id", "status", "enqueue_status", "project"],
)
if not task:
continue
ChangeStatusRequest(
task=task,
new_status=task.enqueue_status or TaskStatus.created,
status_reason="Queue was deleted",
status_message="",
user_id=user_id,
).execute(enqueue_status=None)
except Exception as ex:
log.exception(
f"Failed dequeuing task {item.task} from queue: {queue_id}"
)
queue.delete()
def get_all(
self,
company_id: str,
query_dict: dict,
query: Q = None,
max_task_entries: int = None,
ret_params: dict = None,
) -> Sequence[dict]:
@@ -158,16 +186,25 @@ class QueueBLL(object):
company=company_id,
parameters=query_dict,
query_dict=query_dict,
query=query,
projection_fields=self._get_task_entries_projection(max_task_entries)
if max_task_entries
else None,
ret_params=ret_params,
)
def check_for_workers(self, company_id: str, queue_id: str) -> bool:
    """
    Return True if any of the workers returned by worker_bll.get_all
    for the company lists queue_id among its queues, otherwise False
    """
    return any(
        queue_id in worker.queues for worker in self.worker_bll.get_all(company_id)
    )
def get_queue_infos(
self,
company_id: str,
query_dict: dict,
query: Q = None,
max_task_entries: int = None,
ret_params: dict = None,
) -> Sequence[dict]:
@@ -179,6 +216,7 @@ class QueueBLL(object):
res = Queue.get_many_with_join(
company=company_id,
query_dict=query_dict,
query=query,
override_projection=projection,
projection_fields=self._get_task_entries_projection(max_task_entries)
if max_task_entries
@@ -230,16 +268,22 @@ class QueueBLL(object):
return res
def get_next_task(self, company_id: str, queue_id: str) -> Optional[Entry]:
def get_next_task(
self, company_id: str, queue_id: str, task_id: str = None
) -> Optional[Entry]:
"""
Atomically pop and return the first task from the queue (or None)
:raise errors.bad_request.InvalidQueueId: if the queue does not exist
"""
with translate_errors_context():
query = dict(id=queue_id, company=company_id)
queue = Queue.objects(**query).modify(pop__entries=-1, upsert=False)
queue = Queue.objects(
**query, **({"entries__0__task": task_id} if task_id else {})
).modify(pop__entries=-1, upsert=False)
if not queue:
raise errors.bad_request.InvalidQueueId(**query)
if not task_id or not Queue.objects(**query).first():
raise errors.bad_request.InvalidQueueId(**query)
return
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
@@ -334,6 +378,3 @@ class QueueBLL(object):
if res is None:
raise errors.bad_request.InvalidQueueId(queue_id=queue_id)
return int(res.get("count"))
MetricsRefresher.start(queue_metrics=QueueBLL().metrics)

View File

@@ -14,7 +14,6 @@ from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.queue import Queue, Entry
from apiserver.redis_manager import redman
from apiserver.timing_context import TimingContext
from apiserver.utilities.threads_manager import ThreadsManager
log = config.logger(__file__)
@@ -182,7 +181,7 @@ class QueueMetrics:
"aggs": self._get_dates_agg(interval),
}
with translate_errors_context(), TimingContext("es", "get_queue_metrics"):
with translate_errors_context():
res = self._search_company_metrics(company_id, es_req)
if "aggregations" not in res:
@@ -279,12 +278,17 @@ class MetricsRefresher:
@classmethod
@threads.register("queue_metrics_refresh_watchdog", daemon=True)
def start(cls, queue_metrics: QueueMetrics):
def start(cls, queue_metrics: QueueMetrics = None):
if not cls.watch_interval_sec:
return
if not queue_metrics:
from .queue_bll import QueueBLL
queue_metrics = QueueBLL().metrics
sleep(10)
while not ThreadsManager.terminating:
while True:
try:
for queue in Queue.objects():
timestamp = es_factory.get_timestamp_millis()

View File

@@ -4,7 +4,6 @@ from typing import Optional, TypeVar, Generic, Type, Callable
from redis import StrictRedis
from apiserver import database
from apiserver.timing_context import TimingContext
T = TypeVar("T")
@@ -31,20 +30,17 @@ class RedisCacheManager(Generic[T]):
def set_state(self, state: T) -> None:
redis_key = self._get_redis_key(state.id)
with TimingContext("redis", "cache_set_state"):
self.redis.set(redis_key, state.to_json())
self.redis.expire(redis_key, self.expiration_interval)
self.redis.set(redis_key, state.to_json())
self.redis.expire(redis_key, self.expiration_interval)
def get_state(self, state_id) -> Optional[T]:
redis_key = self._get_redis_key(state_id)
with TimingContext("redis", "cache_get_state"):
response = self.redis.get(redis_key)
response = self.redis.get(redis_key)
if response:
return self.state_class.from_json(response)
def delete_state(self, state_id) -> None:
with TimingContext("redis", "cache_delete_state"):
self.redis.delete(self._get_redis_key(state_id))
self.redis.delete(self._get_redis_key(state_id))
def _get_redis_key(self, state_id):
return f"{self.state_class}/{state_id}"

View File

@@ -1,6 +1,6 @@
from datetime import datetime
import operator
from threading import Thread, Lock
from threading import Lock
from time import sleep
import attr
@@ -9,76 +9,83 @@ import psutil
from apiserver.utilities.threads_manager import ThreadsManager
class ResourceMonitor(Thread):
@attr.s(auto_attribs=True)
class Sample:
cpu_usage: float = 0.0
mem_used_gb: float = 0
mem_free_gb: float = 0
stat_threads = ThreadsManager("Statistics")
@classmethod
def _apply(cls, op, *samples):
return cls(
**{
field: op(*(getattr(sample, field) for sample in samples))
for field in attr.fields_dict(cls)
}
)
def min(self, sample):
return self._apply(min, self, sample)
def max(self, sample):
return self._apply(max, self, sample)
def avg(self, sample, count):
res = self._apply(lambda x: x * count, self)
res = self._apply(operator.add, res, sample)
res = self._apply(lambda x: x / (count + 1), res)
return res
def __init__(self, sample_interval_sec=5):
super(ResourceMonitor, self).__init__(daemon=True)
self.sample_interval_sec = sample_interval_sec
self._lock = Lock()
self._clear()
def _clear(self):
sample = self._get_sample()
self._avg = sample
self._min = sample
self._max = sample
self._clear_time = datetime.utcnow()
self._count = 1
@attr.s(auto_attribs=True)
class Sample:
cpu_usage: float = 0.0
mem_used_gb: float = 0
mem_free_gb: float = 0
@classmethod
def _get_sample(cls) -> Sample:
return cls.Sample(
def _apply(cls, op, *samples):
return cls(
**{
field: op(*(getattr(sample, field) for sample in samples))
for field in attr.fields_dict(cls)
}
)
def min(self, sample):
return self._apply(min, self, sample)
def max(self, sample):
return self._apply(max, self, sample)
def avg(self, sample, count):
res = self._apply(lambda x: x * count, self)
res = self._apply(operator.add, res, sample)
res = self._apply(lambda x: x / (count + 1), res)
return res
@classmethod
def get_current_sample(cls) -> "Sample":
return cls(
cpu_usage=psutil.cpu_percent(),
mem_used_gb=psutil.virtual_memory().used / (1024 ** 3),
mem_free_gb=psutil.virtual_memory().free / (1024 ** 3),
)
def run(self):
while not ThreadsManager.terminating:
sleep(self.sample_interval_sec)
sample = self._get_sample()
class ResourceMonitor:
class Accumulator:
def __init__(self):
sample = Sample.get_current_sample()
self.avg = sample
self.min = sample
self.max = sample
self.time = datetime.utcnow()
self.count = 1
with self._lock:
self._min = self._min.min(sample)
self._max = self._max.max(sample)
self._avg = self._avg.avg(sample, self._count)
self._count += 1
def add_sample(self, sample: Sample):
self.min = self.min.min(sample)
self.max = self.max.max(sample)
self.avg = self.avg.avg(sample, self.count)
self.count += 1
def get_stats(self) -> dict:
sample_interval_sec = 5
_lock = Lock()
accumulator = Accumulator()
@classmethod
@stat_threads.register("resource_monitor", daemon=True)
def start(cls):
while True:
sleep(cls.sample_interval_sec)
sample = Sample.get_current_sample()
with cls._lock:
cls.accumulator.add_sample(sample)
@classmethod
def get_stats(cls) -> dict:
""" Returns current resource statistics and clears internal resource statistics """
with self._lock:
min_ = attr.asdict(self._min)
max_ = attr.asdict(self._max)
avg = attr.asdict(self._avg)
interval = datetime.utcnow() - self._clear_time
self._clear()
with cls._lock:
min_ = attr.asdict(cls.accumulator.min)
max_ = attr.asdict(cls.accumulator.max)
avg = attr.asdict(cls.accumulator.avg)
interval = datetime.utcnow() - cls.accumulator.time
cls.accumulator = cls.Accumulator()
return {
"interval_sec": interval.total_seconds(),

View File

@@ -21,9 +21,8 @@ from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
from apiserver.utilities.json import dumps
from apiserver.utilities.threads_manager import ThreadsManager
from apiserver.version import __version__ as current_version
from .resource_monitor import ResourceMonitor
from .resource_monitor import ResourceMonitor, stat_threads
log = config.logger(__file__)
@@ -31,17 +30,19 @@ worker_bll = WorkerBLL()
class StatisticsReporter:
threads = ThreadsManager("Statistics", resource_monitor=ResourceMonitor)
send_queue = queue.Queue()
supported = config.get("apiserver.statistics.supported", True)
@classmethod
def start(cls):
if not cls.supported:
return
ResourceMonitor.start()
cls.start_sender()
cls.start_reporter()
@classmethod
@threads.register("reporter", daemon=True)
@stat_threads.register("reporter", daemon=True)
def start_reporter(cls):
"""
Periodically send statistics reports for companies who have opted in.
@@ -54,7 +55,7 @@ class StatisticsReporter:
hours=config.get("apiserver.statistics.report_interval_hours", 24)
)
sleep(report_interval.total_seconds())
while not ThreadsManager.terminating:
while True:
try:
for company in Company.objects(
defaults__stats_option__enabled=True
@@ -68,7 +69,7 @@ class StatisticsReporter:
sleep(report_interval.total_seconds())
@classmethod
@threads.register("sender", daemon=True)
@stat_threads.register("sender", daemon=True)
def start_sender(cls):
if not cls.supported:
return
@@ -85,7 +86,7 @@ class StatisticsReporter:
WarningFilter.attach()
while not ThreadsManager.terminating:
while True:
try:
report = cls.send_queue.get()
@@ -111,7 +112,7 @@ class StatisticsReporter:
"uuid": get_server_uuid(),
"queues": {"count": Queue.objects(company=company_id).count()},
"users": {"count": User.objects(company=company_id).count()},
"resources": cls.threads.resource_monitor.get_stats(),
"resources": ResourceMonitor.get_stats(),
"experiments": next(
iter(cls._get_experiments_stats(company_id).values()), {}
),

View File

@@ -0,0 +1,48 @@
from copy import copy
from boltons.cacheutils import cachedproperty
from clearml.backend_config.bucket_config import (
S3BucketConfigurations,
AzureContainerConfigurations,
GSBucketConfigurations,
)
from apiserver.config_repo import config
log = config.logger(__file__)
class StorageBLL:
    """
    Access point for the default storage credentials (AWS S3, Azure storage, Google storage)
    that are read from the "services.storage_credentials" configuration section
    """

    default_aws_configs: S3BucketConfigurations = None
    conf = config.get("services.storage_credentials")

    @cachedproperty
    def _default_aws_configs(self) -> S3BucketConfigurations:
        """S3 bucket configurations built from the "aws.s3" key of the credentials section"""
        aws_section = self.conf.get("aws.s3")
        return S3BucketConfigurations.from_config(aws_section)

    @cachedproperty
    def _default_azure_configs(self) -> AzureContainerConfigurations:
        """Azure container configurations built from the "azure.storage" key of the credentials section"""
        azure_section = self.conf.get("azure.storage")
        return AzureContainerConfigurations.from_config(azure_section)

    @cachedproperty
    def _default_gs_configs(self) -> GSBucketConfigurations:
        """Google storage bucket configurations built from the "google.storage" key of the credentials section"""
        gs_section = self.conf.get("google.storage")
        return GSBucketConfigurations.from_config(gs_section)

    def get_aws_settings_for_company(
        self,
        company_id: str,
    ) -> S3BucketConfigurations:
        """
        Return a shallow copy of the default S3 configurations.
        The company_id is accepted but not used by this implementation
        """
        defaults = self._default_aws_configs
        return copy(defaults)

    def get_azure_settings_for_company(
        self,
        company_id: str,
    ) -> AzureContainerConfigurations:
        """
        Return a shallow copy of the default Azure configurations.
        The company_id is accepted but not used by this implementation
        """
        defaults = self._default_azure_configs
        return copy(defaults)

    def get_gs_settings_for_company(
        self,
        company_id: str,
    ) -> GSBucketConfigurations:
        """
        Return a shallow copy of the default Google storage configurations.
        The company_id is accepted but not used by this implementation
        """
        defaults = self._default_gs_configs
        return copy(defaults)

View File

@@ -5,7 +5,6 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.database.utils import hash_field_name
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
@@ -49,49 +48,41 @@ class Artifacts:
def add_or_update_artifacts(
cls,
company_id: str,
user_id: str,
task_id: str,
artifacts: Sequence[ApiArtifact],
force: bool,
) -> int:
with TimingContext("mongo", "update_artifacts"):
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
force=force,
)
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
artifacts = {
get_artifact_id(a): Artifact(**a)
for a in (api_artifact.to_struct() for api_artifact in artifacts)
}
artifacts = {
get_artifact_id(a): Artifact(**a)
for a in (api_artifact.to_struct() for api_artifact in artifacts)
}
update_cmds = {
f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items()
}
return update_task(task, update_cmds=update_cmds)
update_cmds = {
f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items()
}
return update_task(task, user_id=user_id, update_cmds=update_cmds)
@classmethod
def delete_artifacts(
cls,
company_id: str,
user_id: str,
task_id: str,
artifact_ids: Sequence[ArtifactId],
force: bool,
) -> int:
with TimingContext("mongo", "delete_artifacts"):
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
force=force,
)
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
artifact_ids = [
get_artifact_id(a)
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
]
delete_cmds = {
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
}
artifact_ids = [
get_artifact_id(a)
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
]
delete_cmds = {
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
}
return update_task(task, update_cmds=delete_cmds)
return update_task(task, user_id=user_id, update_cmds=delete_cmds)

View File

@@ -15,7 +15,6 @@ from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.config_repo import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
@@ -64,78 +63,82 @@ class HyperParams:
def delete_params(
cls,
company_id: str,
user_id: str,
task_id: str,
hyperparams: Sequence[HyperParamKey],
force: bool,
) -> int:
with TimingContext("mongo", "delete_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
)
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
)
with_param, without_param = iterutils.partition(
hyperparams, key=lambda p: bool(p.name)
)
sections_to_delete = {p.section for p in without_param}
delete_cmds = {
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
for section in sections_to_delete
}
with_param, without_param = iterutils.partition(
hyperparams, key=lambda p: bool(p.name)
)
sections_to_delete = {p.section for p in without_param}
delete_cmds = {
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
for section in sections_to_delete
}
for item in with_param:
section = ParameterKeyEscaper.escape(item.section)
if item.section in sections_to_delete:
raise errors.bad_request.FieldsConflict(
"Cannot delete section field if the whole section was scheduled for deletion"
)
name = ParameterKeyEscaper.escape(item.name)
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
for item in with_param:
section = ParameterKeyEscaper.escape(item.section)
if item.section in sections_to_delete:
raise errors.bad_request.FieldsConflict(
"Cannot delete section field if the whole section was scheduled for deletion"
)
name = ParameterKeyEscaper.escape(item.name)
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
return update_task(
task, update_cmds=delete_cmds, set_last_update=not properties_only
)
return update_task(
task,
user_id=user_id,
update_cmds=delete_cmds,
set_last_update=not properties_only,
)
@classmethod
def edit_params(
cls,
company_id: str,
user_id: str,
task_id: str,
hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str,
force: bool,
) -> int:
with TimingContext("mongo", "edit_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
)
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
)
update_cmds = dict()
hyperparams = cls._db_dicts_from_list(hyperparams)
if replace_hyperparams == ReplaceHyperparams.all:
update_cmds["set__hyperparams"] = hyperparams
elif replace_hyperparams == ReplaceHyperparams.section:
for section, value in hyperparams.items():
update_cmds = dict()
hyperparams = cls._db_dicts_from_list(hyperparams)
if replace_hyperparams == ReplaceHyperparams.all:
update_cmds["set__hyperparams"] = hyperparams
elif replace_hyperparams == ReplaceHyperparams.section:
for section, value in hyperparams.items():
update_cmds[f"set__hyperparams__{mongoengine_safe(section)}"] = value
else:
for section, section_params in hyperparams.items():
for name, value in section_params.items():
update_cmds[
f"set__hyperparams__{mongoengine_safe(section)}"
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
] = value
else:
for section, section_params in hyperparams.items():
for name, value in section_params.items():
update_cmds[
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
] = value
return update_task(
task, update_cmds=update_cmds, set_last_update=not properties_only
)
return update_task(
task,
user_id=user_id,
update_cmds=update_cmds,
set_last_update=not properties_only,
)
@classmethod
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
@@ -191,57 +194,56 @@ class HyperParams:
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
with TimingContext("mongo", "get_configuration_names"):
tasks = Task.aggregate(pipeline)
tasks = Task.aggregate(pipeline)
return {
task["_id"]: {
"names": sorted(
ParameterKeyEscaper.unescape(name) for name in task["names"]
)
}
for task in tasks
return {
task["_id"]: {
"names": sorted(
ParameterKeyEscaper.unescape(name) for name in task["names"]
)
}
for task in tasks
}
@classmethod
def edit_configuration(
cls,
company_id: str,
user_id: str,
task_id: str,
configuration: Sequence[Configuration],
replace_configuration: bool,
force: bool,
) -> int:
with TimingContext("mongo", "edit_configuration"):
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force
)
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
update_cmds = dict()
configuration = {
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
for c in configuration
}
if replace_configuration:
update_cmds["set__configuration"] = configuration
else:
for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
update_cmds = dict()
configuration = {
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
for c in configuration
}
if replace_configuration:
update_cmds["set__configuration"] = configuration
else:
for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
return update_task(task, update_cmds=update_cmds)
return update_task(task, user_id=user_id, update_cmds=update_cmds)
@classmethod
def delete_configuration(
cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
cls,
company_id: str,
user_id: str,
task_id: str,
configuration: Sequence[str],
force: bool,
) -> int:
with TimingContext("mongo", "delete_configuration"):
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force
)
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
return update_task(task, update_cmds=delete_cmds)
return update_task(task, user_id=user_id, update_cmds=delete_cmds)

View File

@@ -39,7 +39,7 @@ class NonResponsiveTasksWatchdog:
@threads.register("non_responsive_tasks_watchdog", daemon=True)
def start(cls):
sleep(cls.settings.watch_interval_sec)
while not ThreadsManager.terminating:
while True:
watch_interval = cls.settings.watch_interval_sec
if cls.settings.enabled:
try:

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
from typing import Collection, Sequence, Tuple, Optional, Dict
import six
from mongoengine import Q
@@ -33,9 +33,7 @@ from apiserver.database.model import EntityVisibility
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from apiserver.timing_context import TimingContext
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
@@ -66,11 +64,10 @@ class TaskBLL:
"""
with translate_errors_context():
query = dict(id=task_id, company=company_id)
with TimingContext("mongo", "task_with_access"):
if requires_write_access:
task = Task.get_for_writing(_only=only, **query)
else:
task = Task.get(_only=only, **query, include_public=allow_public)
if requires_write_access:
task = Task.get_for_writing(_only=only, **query)
else:
task = Task.get(_only=only, **query, include_public=allow_public)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
@@ -88,15 +85,14 @@ class TaskBLL:
only_fields = list(only_fields)
only_fields = only_fields + ["status"]
with TimingContext("mongo", "task_by_id_all"):
tasks = Task.get_many(
company=company_id,
query=Q(id=task_id),
allow_public=allow_public,
override_projection=only_fields,
return_dicts=False,
)
task = None if not tasks else tasks[0]
tasks = Task.get_many(
company=company_id,
query=Q(id=task_id),
allow_public=allow_public,
override_projection=only_fields,
return_dicts=False,
)
task = None if not tasks else tasks[0]
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
@@ -111,7 +107,7 @@ class TaskBLL:
company_id, task_ids, only=None, allow_public=False, return_tasks=True
) -> Optional[Sequence[Task]]:
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
with translate_errors_context(), TimingContext("mongo", "task_exists"):
with translate_errors_context():
ids = set(task_ids)
q = Task.get_many(
company=company_id,
@@ -131,16 +127,16 @@ class TaskBLL:
return list(q)
@staticmethod
def create(call: APICall, fields: dict):
identity = call.identity
def create(company: str, user: str, fields: dict):
now = datetime.utcnow()
return Task(
id=create_id(),
user=identity.user,
company=identity.company,
user=user,
company=company,
created=now,
last_update=now,
last_change=now,
last_changed_by=user,
**fields,
)
@@ -260,58 +256,66 @@ class TaskBLL:
not in [TaskSystemTags.development, EntityVisibility.archived.value]
]
with TimingContext("mongo", "clone task"):
parent_task = (
task.parent
if task.parent and not task.parent.startswith(deleted_prefix)
else None
)
new_task = Task(
id=create_id(),
user=user_id,
company=company_id,
created=now,
last_update=now,
last_change=now,
name=name or task.name,
comment=comment or task.comment,
parent=parent or parent_task,
project=project or task.project,
tags=tags or task.tags,
system_tags=system_tags or clean_system_tags(task.system_tags),
type=task.type,
script=task.script,
output=Output(destination=task.output.destination)
if task.output
else None,
models=Models(input=input_models or task.models.input),
container=escape_dict(container) or task.container,
execution=execution_dict,
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)
cls.validate(
new_task,
validate_models=validate_references or input_models,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
new_task.save()
def ensure_int_labels(execution: dict) -> dict:
if not execution:
return execution
if task.project == new_task.project:
updated_tags = tags
updated_system_tags = system_tags
else:
updated_tags = new_task.tags
updated_system_tags = new_task.system_tags
org_bll.update_tags(
company_id,
Tags.Task,
project=new_task.project,
tags=updated_tags,
system_tags=updated_system_tags,
)
update_project_time(new_task.project)
model_labels = execution.get("model_labels")
if model_labels:
execution["model_labels"] = {k: int(v) for k, v in model_labels.items()}
return execution
parent_task = (
task.parent
if task.parent and not task.parent.startswith(deleted_prefix)
else task.id
)
new_task = Task(
id=create_id(),
user=user_id,
company=company_id,
created=now,
last_update=now,
last_change=now,
last_changed_by=user_id,
name=name or task.name,
comment=comment or task.comment,
parent=parent or parent_task,
project=project or task.project,
tags=tags or task.tags,
system_tags=system_tags or clean_system_tags(task.system_tags),
type=task.type,
script=task.script,
output=Output(destination=task.output.destination) if task.output else None,
models=Models(input=input_models or task.models.input),
container=escape_dict(container) or task.container,
execution=ensure_int_labels(execution_dict),
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)
cls.validate(
new_task,
validate_models=validate_references or input_models,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
new_task.save()
if task.project == new_task.project:
updated_tags = tags
updated_system_tags = system_tags
else:
updated_tags = new_task.tags
updated_system_tags = new_task.system_tags
org_bll.update_tags(
company_id,
Tags.Task,
project=new_task.project,
tags=updated_tags,
system_tags=updated_system_tags,
)
update_project_time(new_task.project)
return new_task, new_project_data
@@ -381,7 +385,7 @@ class TaskBLL:
last_update: datetime = None,
last_iteration: int = None,
last_iteration_max: int = None,
last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
last_scalar_events: Dict[str, Dict[str, dict]] = None,
last_events: Dict[str, Dict[str, dict]] = None,
**extra_updates,
):
@@ -406,25 +410,90 @@ class TaskBLL:
elif last_iteration_max is not None:
extra_updates.update(max__last_iteration=last_iteration_max)
if last_scalar_values is not None:
raw_updates = {}
if last_scalar_events is not None:
max_values = config.get("services.tasks.max_last_metrics", 2000)
total_metrics = set()
if max_values:
query = dict(id=task_id)
to_add = sum(len(v) for m, v in last_scalar_events.items())
if to_add <= max_values:
query[f"unique_metrics__{max_values-to_add}__exists"] = True
task = Task.objects(**query).only("unique_metrics").first()
if task and task.unique_metrics:
total_metrics = set(task.unique_metrics)
def op_path(op, *path):
return "__".join((op, "last_metrics") + path)
new_metrics = []
for path, value in last_scalar_values:
if path[-1] == "min_value":
extra_updates[op_path("min", *path[:-1], "min_value")] = value
elif path[-1] == "max_value":
extra_updates[op_path("max", *path[:-1], "max_value")] = value
def add_last_metric_conditional_update(
metric_path: str, metric_value, iter_value: int, is_min: bool
):
"""
Build an aggregation for an atomic update of the min or max value and the corresponding iteration
"""
if is_min:
field_prefix = "min"
op = "$gt"
else:
extra_updates[op_path("set", *path)] = value
field_prefix = "max"
op = "$lt"
value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
condition = {
"$or": [
{"$lte": [f"${value_field}", None]},
{op: [f"${value_field}", metric_value]},
]
}
raw_updates[value_field] = {
"$cond": [condition, metric_value, f"${value_field}"]
}
value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
"__", "."
)
raw_updates[value_iteration_field] = {
"$cond": [
condition,
iter_value,
f"${value_iteration_field}",
]
}
for metric_key, metric_data in last_scalar_events.items():
for variant_key, variant_data in metric_data.items():
metric = (
f"{variant_data.get('metric')}/{variant_data.get('variant')}"
)
if max_values:
if (
len(total_metrics) >= max_values
and metric not in total_metrics
):
continue
total_metrics.add(metric)
new_metrics.append(metric)
path = f"last_metrics__{metric_key}__{variant_key}"
for key, value in variant_data.items():
if key in ("min_value", "max_value"):
add_last_metric_conditional_update(
metric_path=path,
metric_value=value,
iter_value=variant_data.get(f"{key}_iter", 0),
is_min=(key == "min_value"),
)
elif key in ("metric", "variant", "value"):
extra_updates[f"set__{path}__{key}"] = value
if new_metrics:
extra_updates["add_to_set__unique_metrics"] = new_metrics
if last_events is not None:
def events_per_type(metric_data: Dict[str, dict]) -> Dict[str, EventStats]:
def events_per_type(metric_data_: Dict[str, dict]) -> Dict[str, EventStats]:
return {
event_type: EventStats(last_update=event["timestamp"])
for event_type, event in metric_data.items()
for event_type, event in metric_data_.items()
}
metric_stats = {
@@ -435,24 +504,38 @@ class TaskBLL:
}
extra_updates["metric_stats"] = metric_stats
return TaskBLL.set_last_update(
ret = TaskBLL.set_last_update(
task_ids=[task_id],
company_id=company_id,
last_update=last_update,
**extra_updates,
)
if ret and raw_updates:
Task.objects(id=task_id).update_one(__raw__=[{"$set": raw_updates}])
return ret
@classmethod
def dequeue_and_change_status(
cls, task: Task, company_id: str, status_message: str, status_reason: str,
cls,
task: Task,
company_id: str,
user_id: str,
status_message: str,
status_reason: str,
):
cls.dequeue(task, company_id)
try:
cls.dequeue(task, company_id)
except errors.bad_request.InvalidQueueOrTaskNotQueued:
# dequeue may fail if the queue was deleted
pass
return ChangeStatusRequest(
task=task,
new_status=task.enqueue_status or TaskStatus.created,
status_reason=status_reason,
status_message=status_message,
user_id=user_id,
).execute(enqueue_status=None)
@classmethod

View File

@@ -1,19 +1,32 @@
from datetime import datetime
from itertools import chain
from operator import attrgetter
from typing import Sequence, Set, Tuple
import attr
from boltons.iterutils import partition
from boltons.iterutils import partition, bucketize, first
from furl import furl
from mongoengine import NotUniqueError
from pymongo.errors import DuplicateKeyError
from apiserver.apierrors import errors
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_bll import PlotFields
from apiserver.bll.task.utils import deleted_prefix
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
from apiserver.timing_context import TimingContext
from apiserver.database.model.url_to_delete import (
StorageType,
UrlToDelete,
FileType,
DeletionStatus,
)
from apiserver.database.utils import id as db_id
log = config.logger(__file__)
event_bll = EventBLL()
async_events_delete = config.get("services.tasks.async_events_delete", False)
@attr.s(auto_attribs=True)
@@ -56,25 +69,24 @@ class CleanupResult:
)
def collect_plot_image_urls(company: str, task: str) -> Set[str]:
def collect_plot_image_urls(company: str, task_or_model: str) -> Set[str]:
urls = set()
next_scroll_id = None
with TimingContext("es", "collect_plot_image_urls"):
while True:
events, next_scroll_id = event_bll.get_plot_image_urls(
company_id=company, task_id=task, scroll_id=next_scroll_id
)
if not events:
break
for event in events:
event_urls = event.get(PlotFields.source_urls)
if event_urls:
urls.update(set(event_urls))
while True:
events, next_scroll_id = event_bll.get_plot_image_urls(
company_id=company, task_id=task_or_model, scroll_id=next_scroll_id
)
if not events:
break
for event in events:
event_urls = event.get(PlotFields.source_urls)
if event_urls:
urls.update(set(event_urls))
return urls
def collect_debug_image_urls(company: str, task: str) -> Set[str]:
def collect_debug_image_urls(company: str, task_or_model: str) -> Set[str]:
"""
Return the set of unique image urls
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
@@ -83,9 +95,7 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
urls = set()
while True:
res, after_key = event_bll.get_debug_image_urls(
company_id=company,
task_id=task,
after_key=after_key,
company_id=company, task_id=task_or_model, after_key=after_key,
)
urls.update(res)
if not after_key:
@@ -94,12 +104,87 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
return urls
# Maps a url prefix to the storage type that is recorded for urls scheduled for deletion.
# Urls that start with none of these prefixes are not scheduled by _schedule_for_delete
supported_storage_types = {
"https://": StorageType.fileserver,
"http://": StorageType.fileserver,
"s3://": StorageType.s3,
"azure://": StorageType.azure,
"gs://": StorageType.gs,
}
def _schedule_for_delete(
    company: str, user: str, task_id: str, urls: Set[str], can_delete_folders: bool,
) -> Set[str]:
    """
    Register the passed urls for deletion by creating UrlToDelete documents

    :param company: company id stored on the created documents
    :param user: user id stored on the created documents
    :param task_id: id of the task that the urls belong to
    :param urls: the urls to schedule. Urls that do not start with any of the prefixes
        in supported_storage_types are skipped
    :param can_delete_folders: if set then for the fileserver urls the containing folder
        is scheduled for deletion instead of the individual file (where it can be derived)
    :return: the set of the passed urls that were covered by a scheduled deletion
    """
    urls_per_storage = bucketize(
        urls,
        key=lambda u: first(
            type_
            for prefix, type_ in supported_storage_types.items()
            if u.startswith(prefix)
        ),
    )
    # urls with unsupported prefixes end up under the None key and are not scheduled
    urls_per_storage.pop(None, None)

    processed_urls = set()
    for storage_type, storage_urls in urls_per_storage.items():
        delete_folders = (storage_type == StorageType.fileserver) and can_delete_folders
        # deletion targets already registered for this storage type in the current call
        scheduled_to_delete = set()
        for url in storage_urls:
            folder = None
            if delete_folders:
                try:
                    parsed = furl(url)
                    if parsed.path and len(parsed.path.segments) > 1:
                        # the folder url is the file url without the query, the fragment
                        # and the last path segment
                        folder = parsed.remove(
                            args=True, fragment=True, path=parsed.path.segments[-1]
                        ).url.rstrip("/")
                except Exception as ex:
                    # best effort: if the folder cannot be derived then the file itself is scheduled
                    log.info(
                        f"Error extracting folder from the url {url}, "
                        f"scheduling the file for deletion: {str(ex)}"
                    )

            to_delete = folder or url
            if to_delete in scheduled_to_delete:
                # several files from the same folder are covered by a single folder deletion
                processed_urls.add(url)
                continue

            try:
                UrlToDelete(
                    id=db_id(),
                    company=company,
                    user=user,
                    url=to_delete,
                    task=task_id,
                    created=datetime.utcnow(),
                    storage_type=storage_type,
                    type=FileType.folder if folder else FileType.file,
                ).save()
            except (DuplicateKeyError, NotUniqueError):
                # the url is already registered: reset the existing record so that it is retried
                existing = UrlToDelete.objects(company=company, url=to_delete).first()
                if existing:
                    existing.update(
                        user=user,
                        task=task_id,
                        created=datetime.utcnow(),
                        retry_count=0,
                        unset__last_failure_time=1,
                        unset__last_failure_reason=1,
                        status=DeletionStatus.created,
                    )
            processed_urls.add(url)
            scheduled_to_delete.add(to_delete)

    return processed_urls
def cleanup_task(
company: str,
user: str,
task: Task,
force: bool = False,
update_children=True,
return_file_urls=False,
delete_output_models=True,
delete_external_artifacts=True,
) -> CleanupResult:
"""
Validate task deletion and delete/modify all its output.
@@ -110,9 +195,11 @@ def cleanup_task(
published_models, draft_models, in_use_model_ids = verify_task_children_and_ouptuts(
task, force
)
delete_external_artifacts = delete_external_artifacts and config.get(
"services.async_urls_delete.enabled", False
)
event_urls, artifact_urls, model_urls = set(), set(), set()
if return_file_urls:
if return_file_urls or delete_external_artifacts:
event_urls = collect_debug_image_urls(task.company, task.id)
event_urls.update(collect_plot_image_urls(task.company, task.id))
if task.execution and task.execution.artifacts:
@@ -128,10 +215,7 @@ def cleanup_task(
deleted_task_id = f"{deleted_prefix}{task.id}"
updated_children = 0
if update_children:
with TimingContext("mongo", "update_task_children"):
updated_children = Task.objects(parent=task.id).update(
parent=deleted_task_id
)
updated_children = Task.objects(parent=task.id).update(parent=deleted_task_id)
deleted_models = 0
updated_models = 0
@@ -139,9 +223,23 @@ def cleanup_task(
if not models:
continue
if delete_output_models and allow_delete:
deleted_models += Model.objects(
id__in=[m.id for m in models if m.id not in in_use_model_ids]
).delete()
model_ids = set(m.id for m in models if m.id not in in_use_model_ids)
for m_id in model_ids:
if return_file_urls or delete_external_artifacts:
event_urls.update(collect_debug_image_urls(task.company, m_id))
event_urls.update(collect_plot_image_urls(task.company, m_id))
try:
event_bll.delete_task_events(
task.company,
m_id,
allow_locked=True,
model=True,
async_delete=async_events_delete,
)
except errors.bad_request.InvalidModelId as ex:
log.info(f"Error deleting events for the model {m_id}: {str(ex)}")
deleted_models += Model.objects(id__in=list(model_ids)).delete()
if in_use_model_ids:
Model.objects(id__in=list(in_use_model_ids)).update(unset__task=1)
continue
@@ -153,7 +251,20 @@ def cleanup_task(
else:
Model.objects(id__in=[m.id for m in models]).update(unset__task=1)
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
event_bll.delete_task_events(
task.company, task.id, allow_locked=force, async_delete=async_events_delete
)
if delete_external_artifacts:
scheduled = _schedule_for_delete(
task_id=task.id,
company=company,
user=user,
urls=event_urls | model_urls | artifact_urls,
can_delete_folders=not in_use_model_ids and not published_models,
)
for urls in (event_urls, model_urls, artifact_urls):
urls.difference_update(scheduled)
return CleanupResult(
deleted_models=deleted_models,
@@ -173,16 +284,15 @@ def verify_task_children_and_ouptuts(
task, force: bool
) -> Tuple[Sequence[Model], Sequence[Model], Set[str]]:
if not force:
with TimingContext("mongo", "count_published_children"):
published_children_count = Task.objects(
parent=task.id, status=TaskStatus.published
).count()
if published_children_count:
raise errors.bad_request.TaskCannotBeDeleted(
"has children, use force=True",
task=task.id,
children=published_children_count,
)
published_children_count = Task.objects(
parent=task.id, status=TaskStatus.published
).count()
if published_children_count:
raise errors.bad_request.TaskCannotBeDeleted(
"has children, use force=True",
task=task.id,
children=published_children_count,
)
model_fields = ["id", "ready", "uri"]
published_models, draft_models = partition(

View File

@@ -30,7 +30,11 @@ queue_bll = QueueBLL()
def archive_task(
task: Union[str, Task], company_id: str, status_message: str, status_reason: str,
task: Union[str, Task],
company_id: str,
user_id: str,
status_message: str,
status_reason: str,
) -> int:
"""
Dequeue and archive task
@@ -52,7 +56,11 @@ def archive_task(
)
try:
TaskBLL.dequeue_and_change_status(
task, company_id, status_message, status_reason,
task,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
)
except APIError:
# dequeue may fail if the task was not enqueued
@@ -63,11 +71,12 @@ def archive_task(
status_reason=status_reason,
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
)
def unarchive_task(
task: str, company_id: str, status_message: str, status_reason: str,
task: str, company_id: str, user_id: str, status_message: str, status_reason: str,
) -> int:
"""
Unarchive task. Return 1 if successful
@@ -80,11 +89,16 @@ def unarchive_task(
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
)
def dequeue_task(
task_id: str, company_id: str, status_message: str, status_reason: str,
task_id: str,
company_id: str,
user_id: str,
status_message: str,
status_reason: str,
) -> Tuple[int, dict]:
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
@@ -92,7 +106,11 @@ def dequeue_task(
raise errors.bad_request.InvalidTaskId(**query)
res = TaskBLL.dequeue_and_change_status(
task, company_id, status_message=status_message, status_reason=status_reason,
task,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
)
return 1, res
@@ -100,6 +118,7 @@ def dequeue_task(
def enqueue_task(
task_id: str,
company_id: str,
user_id: str,
queue_id: str,
status_message: str,
status_reason: str,
@@ -139,6 +158,7 @@ def enqueue_task(
status_message=status_message,
allow_same_state_transition=False,
force=force,
user_id=user_id,
).execute(enqueue_status=task.status)
try:
@@ -151,6 +171,7 @@ def enqueue_task(
new_status=task.status,
force=True,
status_reason="failed enqueueing",
user_id=user_id,
).execute(enqueue_status=None)
raise
@@ -191,12 +212,14 @@ def move_tasks_to_trash(tasks: Sequence[str]) -> int:
def delete_task(
task_id: str,
company_id: str,
user_id: str,
move_to_trash: bool,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
status_message: str,
status_reason: str,
delete_external_artifacts: bool,
) -> Tuple[int, Task, CleanupResult]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
@@ -218,6 +241,7 @@ def delete_task(
TaskBLL.dequeue_and_change_status(
task,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
)
@@ -226,10 +250,13 @@ def delete_task(
pass
cleanup_res = cleanup_task(
task,
company=company_id,
user=user_id,
task=task,
force=force,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
)
if move_to_trash:
@@ -246,10 +273,12 @@ def delete_task(
def reset_task(
task_id: str,
company_id: str,
user_id: str,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
clear_all: bool,
delete_external_artifacts: bool,
) -> Tuple[dict, CleanupResult, dict]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
@@ -268,16 +297,20 @@ def reset_task(
pass
cleaned_up = cleanup_task(
task,
company=company_id,
user=user_id,
task=task,
force=force,
update_children=False,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
)
updates.update(
set__last_iteration=DEFAULT_LAST_ITERATION,
set__last_metrics={},
set__unique_metrics=[],
set__metric_stats={},
set__models__output=[],
set__runtime={},
@@ -308,6 +341,7 @@ def reset_task(
force=force,
status_reason="reset",
status_message="reset",
user_id=user_id,
).execute(
started=None,
completed=None,
@@ -323,8 +357,9 @@ def reset_task(
def publish_task(
task_id: str,
company_id: str,
user_id: str,
force: bool,
publish_model_func: Callable[[str, str], Any] = None,
publish_model_func: Callable[[str, str, str], Any] = None,
status_message: str = "",
status_reason: str = "",
) -> dict:
@@ -352,7 +387,7 @@ def publish_task(
.first()
)
if model and not model.ready:
publish_model_func(model.id, company_id)
publish_model_func(model.id, company_id, user_id)
# set task status to published, and update (or set) its new output (view and models)
return ChangeStatusRequest(
@@ -361,6 +396,7 @@ def publish_task(
force=force,
status_reason=status_reason,
status_message=status_message,
user_id=user_id,
).execute(published=datetime.utcnow(), output=output)
except Exception as ex:
@@ -373,7 +409,12 @@ def publish_task(
def stop_task(
task_id: str, company_id: str, user_name: str, status_reason: str, force: bool,
task_id: str,
company_id: str,
user_id: str,
user_name: str,
status_reason: str,
force: bool,
) -> dict:
"""
Stop a running task. Requires task status 'in_progress' and
@@ -435,4 +476,5 @@ def stop_task(
status_reason=status_reason,
status_message=status_message,
force=force,
user_id=user_id,
).execute()

View File

@@ -9,7 +9,6 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
from apiserver.timing_context import TimingContext
from apiserver.utilities.attrs import typed_attrs
valid_statuses = get_options(TaskStatus)
@@ -27,6 +26,7 @@ class ChangeStatusRequest(object):
force = attr.ib(type=bool, default=False)
allow_same_state_transition = attr.ib(type=bool, default=True)
current_status_override = attr.ib(default=None)
user_id = attr.ib(type=str, default=None)
def execute(self, **kwargs):
current_status = self.current_status_override or self.task.status
@@ -45,6 +45,7 @@ class ChangeStatusRequest(object):
status_changed=now,
last_update=now,
last_change=now,
last_changed_by=self.user_id,
)
if self.new_status == TaskStatus.queued:
@@ -55,7 +56,7 @@ class ChangeStatusRequest(object):
fields.update({safe_mongoengine_key(k): v for k, v in kwargs.items()})
with translate_errors_context(), TimingContext("mongo", "task_status"):
with translate_errors_context():
# atomic change of task status by querying the task with the EXPECTED status before modifying it
params = fields.copy()
params.update(control)
@@ -166,7 +167,7 @@ def update_project_time(project_ids: Union[str, Sequence[str]]):
def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task:
"""
Loads only the task id and returns the task only if it is updatable (status == 'created')
@@ -188,9 +189,9 @@ def get_task_for_update(
return task
def update_task(task: Task, update_cmds: dict, set_last_update: bool = True):
def update_task(task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True):
now = datetime.utcnow()
last_updates = dict(last_change=now)
last_updates = dict(last_change=now, last_changed_by=user_id)
if set_last_update:
last_updates.update(last_update=now)
return task.update(**update_cmds, **last_updates)

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from apiserver.apierrors import errors
from apiserver.apimodels.users import CreateRequest
from apiserver.database.errors import translate_errors_context
@@ -12,7 +14,7 @@ class UserBLL:
if user_id and User.objects(id=user_id).only("id"):
raise errors.bad_request.UserIdExists(id=user_id)
user = User(**request.to_struct())
user = User(**request.to_struct(), created=datetime.utcnow())
user.save(force_insert=True)
@staticmethod

View File

@@ -1,9 +1,11 @@
import itertools
from datetime import datetime, timedelta
from time import time
from typing import Sequence, Set, Optional
import attr
import elasticsearch.helpers
from boltons.iterutils import partition
from apiserver.es_factory import es_factory
from apiserver.apierrors import APIError
@@ -25,7 +27,6 @@ from apiserver.database.model.project import Project
from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
from .stats import WorkerStats
@@ -51,6 +52,7 @@ class WorkerBLL:
queues: Sequence[str] = None,
timeout: int = 0,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> WorkerEntry:
"""
Register a worker
@@ -95,9 +97,10 @@ class WorkerBLL:
register_timeout=timeout,
last_activity_time=now,
tags=tags,
system_tags=system_tags,
)
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
self._save_worker_data(entry)
return entry
@@ -109,15 +112,20 @@ class WorkerBLL:
:param worker: worker ID
:raise bad_request.WorkerNotRegistered: the worker was not previously registered
"""
with TimingContext("redis", "workers_unregister"):
res = self.redis.delete(
company_id, self._get_worker_key(company_id, user_id, worker)
)
res = self.redis.delete(
company_id, self._get_worker_key(company_id, user_id, worker)
)
if not res and not config.get("apiserver.workers.auto_unregister", False):
raise bad_request.WorkerNotRegistered(worker=worker)
def status_report(
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
self,
company_id: str,
user_id: str,
ip: str,
report: StatusReportRequest,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> None:
"""
Write worker status report
@@ -138,6 +146,8 @@ class WorkerBLL:
if tags is not None:
entry.tags = tags
if system_tags is not None:
entry.system_tags = system_tags
if report.machine_stats:
self._log_stats_to_es(
@@ -176,7 +186,9 @@ class WorkerBLL:
if task.project:
project = Project.objects(id=task.project).only("name").first()
if project:
entry.project = IdNameEntry(id=project.id, name=project.name)
entry.project = IdNameEntry(
id=project.id, name=project.name
)
entry.last_report_time = now
except APIError:
@@ -193,6 +205,7 @@ class WorkerBLL:
company_id: str,
last_seen: Optional[int] = None,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerEntry]:
"""
Get all the company workers that were active during the last_seen period
@@ -201,7 +214,7 @@ class WorkerBLL:
:return:
"""
try:
workers = self._get(company_id)
workers = self._get(company_id, user_tags=tags, system_tags=system_tags)
except Exception as e:
raise server_error.DataError("failed loading worker entries", err=e.args[0])
@@ -213,26 +226,25 @@ class WorkerBLL:
if w.last_activity_time.replace(tzinfo=None) >= ref_time
]
if tags:
include = {t for t in tags if not t.startswith("-")}
exclude = {t[1:] for t in tags if t.startswith("-")}
workers = [
w
for w in workers
if (not include or any(t in include for t in w.tags))
and (not exclude or all(t not in exclude for t in w.tags))
]
return workers
def get_all_with_projection(
self, company_id: str, last_seen: int, tags: Sequence[str] = None
self,
company_id: str,
last_seen: int,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerResponseEntry]:
helpers = list(
map(
WorkerConversionHelper.from_worker_entry,
self.get_all(company_id=company_id, last_seen=last_seen, tags=tags),
self.get_all(
company_id=company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
),
)
)
@@ -323,8 +335,7 @@ class WorkerBLL:
"""
key = self._get_worker_key(company_id, user_id, worker)
with TimingContext("redis", "get_worker"):
data = self.redis.get(key)
data = self.redis.get(key)
if data:
try:
@@ -351,27 +362,117 @@ class WorkerBLL:
raise bad_request.InvalidWorkerId(worker=worker)
@staticmethod
def _get_tagged_workers_key(company: str, tags_field: str, tag: str) -> str:
"""Build redis key from company, user and worker_id"""
return f"workers.{tags_field}_{company}_{tag}"
@staticmethod
def _get_all_workers_key(company: str) -> str:
"""Build redis key from company, user and worker_id"""
return f"workers_{company}"
def _save_worker_data(self, entry: WorkerEntry):
    """
    Store the worker entry in redis and register its key in the workers lookup sets

    The serialized entry is written under entry.key and expires after
    entry.register_timeout seconds. The key is then added (zadd), scored by its
    expiration timestamp, to the set of all the company workers and to one set
    per each of the entry tags and system tags.

    :param entry: the worker entry to store
    """
    ttl = timedelta(seconds=entry.register_timeout)
    self.redis.setex(entry.key, ttl, entry.to_json())

    company = entry.company.id
    # each member is scored by the time at which its registration expires
    member = {entry.key: int(time()) + entry.register_timeout}
    self.redis.zadd(self._get_all_workers_key(company), member)

    tagged = (("tags", entry.tags), ("systemtags", entry.system_tags))
    for field_name, field_tags in tagged:
        for tag in field_tags:
            self.redis.zadd(
                self._get_tagged_workers_key(company, field_name, tag), member
            )
def _save_worker(self, entry: WorkerEntry) -> None:
    """
    Save worker entry in Redis

    Best effort: a failure is logged together with its traceback and is not
    propagated to the caller.

    :param entry: the worker entry to store
    """
    try:
        # _save_worker_data starts by writing the entry itself (setex) and then
        # indexes it, so a separate setex call here would only repeat that write
        self._save_worker_data(entry)
    except Exception:
        msg = "Failed saving worker entry"
        log.exception(msg)
def _get(
self, company: str, user: str = "*", worker_id: str = "*"
self,
company: str,
user: str = "*",
worker_id: str = "*",
user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerEntry]:
"""Get worker entries matching the company and user, worker patterns"""
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
if user == "*":
return in_keys
user_bytes = user.encode()
return {k for k in in_keys if user_bytes in k}
if user_tags or system_tags:
worker_keys = set()
for tags, tags_field in (
(user_tags, "tags"),
(system_tags, "systemtags"),
):
if not tags:
continue
timestamp = int(time())
include, exclude = partition(tags, key=lambda x: x[0] != "-")
if include:
tagged_workers = set()
for tag in include:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
tagged_workers.update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
tagged_workers = filter_by_user(tagged_workers)
worker_keys = (
worker_keys.intersection(tagged_workers)
if worker_keys
else tagged_workers
)
if not worker_keys:
return []
if exclude:
if not worker_keys:
all_workers_key = self._get_all_workers_key(company)
self.redis.zremrangebyscore(
all_workers_key, min=0, max=timestamp
)
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
worker_keys = filter_by_user(worker_keys)
if not worker_keys:
return []
for tag in exclude:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag[1:]
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
worker_keys.difference_update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
if not worker_keys:
return []
else:
match = self._get_worker_key(company, user, "*")
worker_keys = self.redis.scan_iter(match)
entries = []
match = self._get_worker_key(company, user, worker_id)
with TimingContext("redis", "workers_get_all"):
for r in self.redis.scan_iter(match):
data = self.redis.get(r)
if data:
entries.append(WorkerEntry.from_json(data))
for key in worker_keys:
data = self.redis.get(key)
if data:
entries.append(WorkerEntry.from_json(data))
return entries

View File

@@ -8,7 +8,6 @@ from apiserver.apimodels.workers import AggregationType, GetStatsRequest, StatIt
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
@@ -126,7 +125,7 @@ class WorkerStats:
query_terms.append(QueryBuilder.terms("worker", request.worker_ids))
es_req["query"] = {"bool": {"must": query_terms}}
with translate_errors_context(), TimingContext("es", "get_worker_stats"):
with translate_errors_context():
data = self._search_company_stats(company_id, es_req)
return self._extract_results(data, request.items, request.split_by_variant)
@@ -223,9 +222,7 @@ class WorkerStats:
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext(
"es", "get_worker_activity_report"
):
with translate_errors_context():
data = self._search_company_stats(company_id, es_req)
if "aggregations" not in data:

View File

@@ -41,10 +41,6 @@
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
# but not declared in a data model
strict: false
aggregate {
allow_disk_use: true
}
}
elastic {
@@ -117,6 +113,10 @@
# Timeout in seconds on task status update. If exceeded
# then task can be stopped without communicating to the worker
task_update_timeout: 600
# Timeout in seconds for worker registration (or status report). If a worker did not report for this long,
# it is discarded from the server's table
default_timeout: 600
}
check_for_updates {

View File

@@ -1,3 +1,5 @@
fileserver = "http://localhost:8081"
elastic {
events {
hosts: [{host: "127.0.0.1", port: 9200}]

View File

@@ -2,3 +2,8 @@ max_page_size: 500
# expiration time in seconds for the redis scroll states in get_many family of apis
scroll_state_expiration_seconds: 600
allow_disk_use {
sort: true
aggregate: true
}

View File

@@ -0,0 +1,12 @@
# if set to True then, on task delete/reset, external file urls of known storage types are scheduled for async delete
# otherwise they are returned to the client for client-side deletion
enabled: true
max_retries: 3
retry_timeout_sec: 60
fileserver {
# fileserver url prefixes. Evaluated in the order of priority
# Can be in the form <scheme>://host:port/path or /path
url_prefixes: ["https://files.community-master.hosted.allegro.ai/"]
timeout_sec: 300
}

View File

@@ -39,4 +39,7 @@ events_retrieval {
validate_plot_str: false
# If not 0 then plots whose size is equal to or greater than this threshold will be stored compressed in the DB
plot_compression_threshold: 100000
plot_compression_threshold: 100000
# async events delete threshold
max_async_deleted_events_per_sec: 1000

View File

@@ -2,4 +2,7 @@
metrics_before_from_date: 3600
# interval in seconds to update queue metrics. Put 0 to disable
metrics_refresh_interval_sec: 300
# the queues with these tags will not be returned from get_all/get_all_ex unless id or name is specified
# or search_hidden is set
hidden_tags: [k8s-glue]
}

View File

@@ -0,0 +1,53 @@
aws {
s3 {
# S3 credentials, used for read/write access by various SDK elements
# default, used for any bucket not specified below
key: ""
secret: ""
region: ""
use_credentials_chain: false
# Additional ExtraArgs passed to boto3 when uploading files. Can also be set per-bucket under "credentials".
extra_args: {}
credentials: [
# specifies key/secret credentials to use when handling s3 urls (read or write)
# {
# bucket: "my-bucket-name"
# key: "my-access-key"
# secret: "my-secret-key"
# },
{
# This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
host: "localhost:9000"
key: "evg_user"
secret: "evg_pass"
multipart: false
secure: false
}
]
}
}
google.storage {
# Default project and credentials file
# Will be used when no bucket configuration is found
// project: "clearml"
// credentials_json: "/path/to/credentials.json"
//
// # Specific credentials per bucket and sub directory
// credentials = [
// {
// bucket: "my-bucket"
// subdir: "path/in/bucket" # Not required
// project: "clearml"
// credentials_json: "/path/to/credentials.json"
// },
// ]
}
azure.storage {
# containers: [
# {
# account_name: "clearml"
# account_key: "secret"
# # container_name:
# }
# ]
}

View File

@@ -19,4 +19,11 @@ hyperparam_values {
# cache ttl sec
cache_ttl_sec: 86400
}
}
# the maximum amount of unique last metrics/variants combinations
# for which the last values are stored in a task
max_last_metrics: 2000
# if set then call to tasks.delete/cleanup does not wait for ES events deletion
async_events_delete: false

View File

@@ -29,6 +29,9 @@ OVERRIDE_PORT_ENV_KEY = (
)
OVERRIDE_CONNECTION_STRING_ENV_KEY = "CLEARML_MONGODB_SERVICE_CONNECTION_STRING"
OVERRIDE_USERNAME_ENV_KEY = "CLEARML_MONGODB_SERVICE_USERNAME"
OVERRIDE_PASSWORD_ENV_KEY = "CLEARML_MONGODB_SERVICE_PASSWORD"
OVERRIDE_QUERY_ENV_KEY = "CLEARML_MONGODB_SERVICE_QUERY"
class DatabaseEntry(models.Base):
@@ -52,14 +55,23 @@ class DatabaseFactory:
override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY)
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
override_username = getenv(OVERRIDE_USERNAME_ENV_KEY)
override_password = getenv(OVERRIDE_PASSWORD_ENV_KEY)
override_query = getenv(OVERRIDE_QUERY_ENV_KEY)
if override_connection_string:
log.info(f"Using override mongodb connection string {override_connection_string}")
log.info(f"Using override mongodb connection string template {override_connection_string}")
else:
if override_hostname:
log.info(f"Using override mongodb host {override_hostname}")
if override_port:
log.info(f"Using override mongodb port {override_port}")
if override_username:
log.info(f"Using override mongodb username {override_username}")
if override_password:
log.info(f"Using override mongodb password ******")
if override_query:
log.info(f"Using override mongodb query {override_query}")
for key, alias in get_items(Database).items():
if key not in db_entries:
@@ -69,12 +81,20 @@ class DatabaseFactory:
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
if override_connection_string:
entry.host = override_connection_string
con_str = f"{override_connection_string.rstrip('/')}/{key}"
log.info(f"Using override mongodb connection string for {alias}: {con_str}")
entry.host = con_str
else:
if override_hostname:
entry.host = furl(entry.host).set(host=override_hostname).url
if override_port:
entry.host = furl(entry.host).set(port=override_port).url
if override_username:
entry.host = furl(entry.host).set(username=override_username).url
if override_password:
entry.host = furl(entry.host).set(password=override_password).url
if override_query:
entry.host = furl(entry.host).set(query=override_query).url
try:
entry.validate()

View File

@@ -17,16 +17,16 @@ from typing import (
from boltons.iterutils import first, partition
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField, IntField
from mongoengine import Q, Document, ListField, StringField, IntField, QuerySet
from pymongo.command_cursor import CommandCursor
from apiserver.apierrors import errors
from apiserver.apierrors import errors, APIError
from apiserver.apierrors.base import BaseError
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database import Database
from apiserver.database.errors import MakeGetAllQueryError
from apiserver.database.projection import project_dict, ProjectionHelper
from apiserver.database.projection import ProjectionHelper
from apiserver.database.props import PropsMixin
from apiserver.database.query import RegexQ, RegexWrapper, RegexQCombination
from apiserver.database.utils import (
@@ -36,9 +36,10 @@ from apiserver.database.utils import (
field_exists,
)
from apiserver.redis_manager import redman
from apiserver.utilities.dicts import project_dict, exclude_fields_from_dict
log = config.logger("dbmodel")
mongo_conf = config.get("services._mongo")
ACCESS_REGEX = re.compile(r"^(?P<prefix>>=|>|<=|<)?(?P<value>.*)$")
ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"}
@@ -55,17 +56,25 @@ class ProperDictMixin(object):
strip_private=True,
only=None,
extra_dict=None,
exclude=None,
) -> dict:
return self.properize_dict(
self.to_mongo(use_db_field=False).to_dict(),
strip_private=strip_private,
only=only,
extra_dict=extra_dict,
exclude=exclude,
)
@classmethod
def properize_dict(
cls, d, strip_private=True, only=None, extra_dict=None, normalize_id=True
cls,
d,
strip_private=True,
only=None,
extra_dict=None,
exclude=None,
normalize_id=True,
):
res = d
if normalize_id and "_id" in res:
@@ -76,6 +85,9 @@ class ProperDictMixin(object):
res = project_dict(res, only)
if extra_dict:
res.update(extra_dict)
if exclude:
exclude_fields_from_dict(res, exclude)
return res
@@ -136,20 +148,30 @@ class GetMixin(PropsMixin):
"or": (default_mongo_op, True),
}
def __init__(self, legacy=False):
def __init__(self, field, legacy=False):
self._field = field
self._current_op = None
self._sticky = False
self._support_legacy = legacy
self.allow_empty = False
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
op = (
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
)
if translate:
tup = self._ops.get(op, None)
return tup[0] if tup else None
return op
try:
op = (
v[len(self.op_prefix) :]
if v and v.startswith(self.op_prefix)
else None
)
if translate:
tup = self._ops.get(op, None)
return tup[0] if tup else None
return op
except AttributeError:
raise errors.bad_request.FieldsValueError(
"invalid value type, string expected",
field=self._field,
value=str(v),
)
def _key(self, v) -> Optional[Union[str, bool]]:
if v is None:
@@ -215,8 +237,8 @@ class GetMixin(PropsMixin):
cls._cache_manager = RedisCacheManager(
state_class=cls.GetManyScrollState,
redis=redman.connection("apiserver"),
expiration_interval=config.get(
"services._mongo.scroll_state_expiration_seconds", 600
expiration_interval=mongo_conf.get(
"scroll_state_expiration_seconds", 600
),
)
@@ -335,84 +357,108 @@ class GetMixin(PropsMixin):
parameters_options = parameters_options or cls.get_all_query_options
dict_query = {}
query = RegexQ()
if parameters:
parameters = {
k: cls._get_fixed_field_value(k, v) for k, v in parameters.items()
}
opts = parameters_options
for field in opts.pattern_fields:
pattern = parameters.pop(field, None)
if pattern:
dict_query[field] = RegexWrapper(pattern)
field = None
# noinspection PyBroadException
try:
if parameters:
parameters = {
k: cls._get_fixed_field_value(k, v) for k, v in parameters.items()
}
opts = parameters_options
for field in opts.pattern_fields:
pattern = parameters.pop(field, None)
if pattern:
dict_query[field] = RegexWrapper(pattern)
for field, data in cls._pop_matching_params(
patterns=opts.list_fields, parameters=parameters
).items():
query &= cls.get_list_field_query(field, data)
for field, data in cls._pop_matching_params(
patterns=opts.list_fields, parameters=parameters
).items():
query &= cls.get_list_field_query(field, data)
for field, data in cls._pop_matching_params(
patterns=opts.range_fields, parameters=parameters
).items():
query &= cls.get_range_field_query(field, data)
for field, data in cls._pop_matching_params(
patterns=opts.range_fields, parameters=parameters
).items():
query &= cls.get_range_field_query(field, data)
for field, data in cls._pop_matching_params(
patterns=opts.fields or [], parameters=parameters
).items():
if "._" in field or "_." in field:
query &= RegexQ(__raw__={field: data})
else:
dict_query[field.replace(".", "__")] = data
for field, data in cls._pop_matching_params(
patterns=opts.fields or [], parameters=parameters
).items():
if "._" in field or "_." in field:
query &= RegexQ(__raw__={field: data})
else:
dict_query[field.replace(".", "__")] = data
for field in opts.datetime_fields or []:
data = parameters.pop(field, None)
if data is not None:
if not isinstance(data, list):
data = [data]
for d in data: # type: str
m = ACCESS_REGEX.match(d)
if not m:
for field in opts.datetime_fields or []:
data = parameters.pop(field, None)
if data is not None:
if not isinstance(data, list):
data = [data]
# date time fields also support simplified range queries. Check if this is the case
if len(data) == 2 and not any(
d.startswith(mod)
for d in data
if d is not None
for mod in ACCESS_MODIFIER
):
query &= cls.get_range_field_query(field, data)
else:
for d in data: # type: str
m = ACCESS_REGEX.match(d)
if not m:
continue
try:
value = parse_datetime(m.group("value"))
prefix = m.group("prefix")
modifier = ACCESS_MODIFIER.get(prefix)
f = (
field
if not modifier
else "__".join((field, modifier))
)
dict_query[f] = value
except (ValueError, OverflowError):
pass
for field, value in parameters.items():
for keys, func in cls._multi_field_param_prefix.items():
if field not in keys:
continue
try:
value = parse_datetime(m.group("value"))
prefix = m.group("prefix")
modifier = ACCESS_MODIFIER.get(prefix)
f = field if not modifier else "__".join((field, modifier))
dict_query[f] = value
except (ValueError, OverflowError):
pass
for field, value in parameters.items():
for keys, func in cls._multi_field_param_prefix.items():
if field not in keys:
continue
try:
data = cls.MultiFieldParameters(**value)
except Exception:
raise MakeGetAllQueryError("incorrect field format", field)
if not data.fields:
break
if any("._" in f for f in data.fields):
q = reduce(
lambda a, x: func(
a,
RegexQ(
__raw__={
x: {"$regex": data.pattern, "$options": "i"}
}
data = cls.MultiFieldParameters(**value)
except Exception:
raise MakeGetAllQueryError("incorrect field format", field)
if not data.fields:
break
if any("._" in f for f in data.fields):
q = reduce(
lambda a, x: func(
a,
RegexQ(
__raw__={
x: {"$regex": data.pattern, "$options": "i"}
}
),
),
),
data.fields,
RegexQ(),
)
else:
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})),
sep_fields,
RegexQ(),
)
query = query & q
data.fields,
RegexQ(),
)
else:
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})),
sep_fields,
RegexQ(),
)
query = query & q
except APIError:
raise
except Exception as ex:
raise errors.bad_request.FieldsValueError(
"failed parsing query field",
error=str(ex),
**({"field": field} if field else {}),
)
return query & RegexQ(**dict_query)
@@ -461,7 +507,7 @@ class GetMixin(PropsMixin):
if not isinstance(data, (list, tuple)):
data = [data]
helper = cls.ListFieldBucketHelper(legacy=True)
helper = cls.ListFieldBucketHelper(field, legacy=True)
global_op = helper.get_global_op(data)
actions = helper.get_actions(data)
@@ -497,6 +543,8 @@ class GetMixin(PropsMixin):
def validate_order_by(cls, parameters, search_text) -> Sequence:
"""
Validate and extract order_by params as a list
If ordering is specified then make sure that id field is part of it
This guarantees the unique order when paging
"""
order_by = parameters.get(cls._ordering_key)
if not order_by:
@@ -509,6 +557,9 @@ class GetMixin(PropsMixin):
"text score cannot be used in order_by when search text is not used"
)
if not any(id_field in order_by for id_field in ("id", "-id")):
order_by.append("id")
return order_by
@classmethod
@@ -525,7 +576,7 @@ class GetMixin(PropsMixin):
if start is not None:
return start, cls.validate_scroll_size(parameters)
max_page_size = config.get("services._mongo.max_page_size", 500)
max_page_size = mongo_conf.get("max_page_size", 500)
page = parameters.get("page", default_page)
if page is not None and page < 0:
raise errors.bad_request.ValidationError("page must be >=0", field="page")
@@ -835,6 +886,13 @@ class GetMixin(PropsMixin):
return cls._get_many_no_company(query=_query, override_projection=projection)
@staticmethod
def _get_qs_with_ordering(qs: QuerySet, order_by: Sequence):
disk_use_setting = mongo_conf.get("allow_disk_use.sort", None)
if disk_use_setting is not None:
qs = qs.allow_disk_use(disk_use_setting)
return qs.order_by(*order_by)
@classmethod
def _get_many_no_company(
cls: Union["GetMixin", Document],
@@ -876,7 +934,7 @@ class GetMixin(PropsMixin):
qs = qs.search_text(search_text)
if order_by:
# add ordering
qs = qs.order_by(*order_by)
qs = cls._get_qs_with_ordering(qs, order_by)
if include:
# add projection
@@ -968,7 +1026,7 @@ class GetMixin(PropsMixin):
res = cls._get_queries_for_order_field(query, order_field)
if res:
query_sets = [cls.objects(q) for q in res]
query_sets = [qs.order_by(*order_by) for qs in query_sets]
query_sets = [cls._get_qs_with_ordering(qs, order_by) for qs in query_sets]
if order_field and not override_collation:
override_collation = cls._get_collation_override(order_field)
@@ -991,7 +1049,11 @@ class GetMixin(PropsMixin):
query_sets = [qs.fields(**projection_fields) for qs in query_sets]
if start is None or not size:
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
return [
obj.to_proper_dict(only=include, exclude=exclude)
for qs in query_sets
for obj in qs
]
# add paging
ret = []
@@ -999,7 +1061,7 @@ class GetMixin(PropsMixin):
for i, qs in enumerate(query_sets):
last_size = len(ret)
ret.extend(
obj.to_proper_dict(only=include)
obj.to_proper_dict(only=include, exclude=exclude)
for obj in (qs.skip(start) if start else qs).limit(size)
)
added = len(ret) - last_size
@@ -1124,7 +1186,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
kwargs.update(
allowDiskUse=allow_disk_use
if allow_disk_use is not None
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
else mongo_conf.get("allow_disk_use.aggregate", True)
)
return cls.objects.aggregate(pipeline, **kwargs)

View File

@@ -36,6 +36,7 @@ class Model(AttributedDocument):
("company", "framework"),
("company", "name"),
("company", "user"),
("company", "uri"),
{
"name": "%s.model.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
@@ -91,3 +92,6 @@ class Model(AttributedDocument):
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)
def get_index_company(self) -> str:
return self.company or self.company_origin or ""

View File

@@ -4,6 +4,7 @@ from mongoengine import (
DynamicField,
LongField,
EmbeddedDocumentField,
IntField,
)
from apiserver.database.fields import SafeMapField
@@ -19,7 +20,9 @@ class MetricEvent(EmbeddedDocument):
variant = StringField(required=True)
value = DynamicField(required=True)
min_value = DynamicField() # for backwards compatibility reasons
min_value_iteration = IntField()
max_value = DynamicField() # for backwards compatibility reasons
max_value_iteration = IntField()
class EventStats(EmbeddedDocument):

View File

@@ -19,6 +19,7 @@ from apiserver.database.fields import (
SafeSortedListField,
EmbeddedDocumentListField,
NullableStringField,
NoneType,
)
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import ProperDictMixin, GetMixin
@@ -89,7 +90,9 @@ class Artifact(EmbeddedDocument):
content_size = LongField()
timestamp = LongField()
type_data = EmbeddedDocumentField(ArtifactTypeData)
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
display_data = SafeSortedListField(
ListField(UnionField((int, float, str, NoneType)))
)
class ParamsItem(EmbeddedDocument, ProperDictMixin):
@@ -149,6 +152,7 @@ class TaskType(object):
application = "application"
monitor = "monitor"
controller = "controller"
report = "report"
optimizer = "optimizer"
service = "service"
qc = "qc"
@@ -195,6 +199,7 @@ class Task(AttributedDocument):
"$name",
"$id",
"$comment",
"$report",
"$models.input.model",
"$models.output.model",
"$script.repository",
@@ -205,6 +210,7 @@ class Task(AttributedDocument):
"name": 10,
"id": 10,
"comment": 10,
"report": 10,
"models.output.model": 2,
"models.input.model": 2,
"script.repository": 1,
@@ -227,7 +233,8 @@ class Task(AttributedDocument):
),
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"),
pattern_fields=("name", "comment"),
pattern_fields=("name", "comment", "report"),
fields=("execution.queue", "runtime.*", "models.input.model"),
)
id = StringField(primary_key=True)
@@ -241,6 +248,8 @@ class Task(AttributedDocument):
status_message = StringField(user_set_allowed=True)
status_changed = DateTimeField()
comment = StringField(user_set_allowed=True)
report = StringField()
report_assets = ListField(StringField())
created = DateTimeField(required=True, user_set_allowed=True)
started = DateTimeField()
completed = DateTimeField()
@@ -259,6 +268,7 @@ class Task(AttributedDocument):
last_change = DateTimeField()
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
company_origin = StringField(exclude_by_default=True)
duration = IntField() # task duration in seconds
@@ -270,6 +280,7 @@ class Task(AttributedDocument):
enqueue_status = StringField(
choices=get_options(TaskStatus), exclude_by_default=True
)
last_changed_by = StringField()
def get_index_company(self) -> str:
"""

View File

@@ -0,0 +1,52 @@
from enum import Enum
from mongoengine import StringField, DateTimeField, IntField, EnumField
from apiserver.database import Database, strict
from apiserver.database.model import AttributedDocument
class StorageType(str, Enum):
    """Values of the ``UrlToDelete.storage_type`` field."""

    fileserver = "fileserver"
    s3 = "s3"
    azure = "azure"
    gs = "gs"
    # Default value of UrlToDelete.storage_type
    unknown = "unknown"
class FileType(str, Enum):
    """Values of the ``UrlToDelete.type`` field (``file`` is the field default)."""

    file = "file"
    folder = "folder"
class DeletionStatus(str, Enum):
    """Values of the ``UrlToDelete.status`` field (``created`` is the field default)."""

    created = "created"
    retrying = "retrying"
    failed = "failed"
class UrlToDelete(AttributedDocument):
    """
    Mongo document describing a url that is queued for deletion.

    Besides the fields declared here the indexes reference ``company`` and ``user``,
    which are presumably inherited from AttributedDocument (TODO confirm).
    """

    # Sort/compare the url field using the numeric locale collation
    _field_collation_overrides = {
        "url": AttributedDocument._numeric_locale,
    }
    meta = {
        "db_alias": Database.backend,
        "strict": strict,
        "indexes": [
            ("company", "user", "task"),
            ("company", "storage_type", "url"),
            ("status", "retry_count", "storage_type"),
        ],
    }
    id = StringField(primary_key=True)
    # A url may be registered only once per company
    url = StringField(required=True, unique_with="company")
    task = StringField(required=True)
    created = DateTimeField(required=True)
    storage_type = EnumField(StorageType, default=StorageType.unknown)
    type = EnumField(FileType, default=FileType.file)
    # Bookkeeping of failed deletion attempts
    retry_count = IntField(default=0)
    last_failure_time = DateTimeField()
    last_failure_reason = StringField()
    status = EnumField(DeletionStatus, default=DeletionStatus.created)

View File

@@ -1,4 +1,4 @@
from mongoengine import Document, StringField, DynamicField
from mongoengine import Document, StringField, DynamicField, DateTimeField
from apiserver.database import Database, strict
from apiserver.database.model import DbModelMixin
@@ -20,3 +20,4 @@ class User(DbModelMixin, Document):
given_name = StringField(user_set_allowed=True)
avatar = StringField()
preferences = DynamicField(default="", exclude_by_default=True)
created = DateTimeField()

View File

@@ -1,7 +1,7 @@
import threading
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby, chain
from typing import Sequence, Dict, Callable, Tuple, Any, Type
from typing import Sequence, Dict, Callable
from apiserver.apierrors import errors
from apiserver.database.props import PropsMixin
@@ -9,65 +9,6 @@ from apiserver.database.props import PropsMixin
SEP = "."
def project_dict(data, projection, separator=SEP):
"""
Project partial data from a dictionary into a new dictionary
:param data: Input dictionary
:param projection: List of dictionary paths (each a string with field names separated using a separator)
:param separator: Separator (default is '.')
:return: A new dictionary containing only the projected parts from the original dictionary
"""
assert isinstance(data, dict)
result = {}
def copy_path(path_parts, source, destination):
src, dst = source, destination
try:
for depth, path_part in enumerate(path_parts[:-1]):
src_part = src[path_part]
if isinstance(src_part, dict):
src = src_part
dst = dst.setdefault(path_part, {})
elif isinstance(src_part, (list, tuple)):
if path_part not in dst:
dst[path_part] = [{} for _ in range(len(src_part))]
elif not isinstance(dst[path_part], (list, tuple)):
raise TypeError(
"Incompatible destination type %s for %s (list expected)"
% (type(dst), separator.join(path_parts[: depth + 1]))
)
elif not len(dst[path_part]) == len(src_part):
raise ValueError(
"Destination list length differs from source length for %s"
% separator.join(path_parts[: depth + 1])
)
dst[path_part] = [
copy_path(path_parts[depth + 1 :], s, d)
for s, d in zip(src_part, dst[path_part])
]
return destination
else:
raise TypeError(
"Unsupported projection type %s for %s"
% (type(src), separator.join(path_parts[: depth + 1]))
)
last_part = path_parts[-1]
dst[last_part] = src[last_part]
except KeyError:
# Projection field not in source, no biggie.
pass
return destination
for projection_path in sorted(projection):
copy_path(
path_parts=projection_path.split(separator), source=data, destination=result
)
return result
class _ReferenceProxy(dict):
def __init__(self, id):
super(_ReferenceProxy, self).__init__(**({"id": id} if id else {}))
@@ -108,9 +49,6 @@ class ProjectionHelper(object):
self._ref_projection = None
self._proxy_manager = _ProxyManager()
# Cached dpath paths for each of the result documents
self._cached_results_paths: Dict[int, Sequence[Tuple[Any, Type]]] = {}
self._parse_projection(projection)
def _collect_projection_fields(self, doc_cls, projection):

View File

@@ -1,6 +1,6 @@
import hashlib
from inspect import ismethod, getmembers
from typing import Sequence, Tuple, Set, Optional, Callable, Any
from typing import Sequence, Tuple, Set, Optional, Callable, Any, Mapping
from uuid import uuid4
from mongoengine import EmbeddedDocumentField, ListField, Document, Q
@@ -203,18 +203,22 @@ def _names_set(*names: str) -> Set[str]:
return set(names) | set(f"-{name}" for name in names)
system_tag_names = {
_system_tag_names = {
"model": _names_set("active", "archived"),
"project": _names_set("archived", "public", "default"),
"task": _names_set("active", "archived", "development"),
"queue": _names_set("default"),
}
system_tag_prefixes = {"task": _names_set("annotat")}
_system_tag_prefixes = {"task": _names_set("annotat")}
def partition_tags(
entity: str, tags: Sequence[str], system_tags: Optional[Sequence[str]] = ()
entity: str,
tags: Sequence[str],
system_tags: Optional[Sequence[str]] = (),
system_tag_names: Mapping = _system_tag_names,
system_tag_prefixes: Mapping = _system_tag_prefixes,
) -> Tuple[Sequence[str], Sequence[str]]:
"""
Partition the given tags sequence into system and user-defined tags

View File

@@ -35,6 +35,12 @@
},
"value": {
"type": "float"
},
"company_id": {
"type": "keyword"
},
"model_event": {
"type": "boolean"
}
}
}

View File

@@ -67,7 +67,7 @@ class MissingPasswordForElasticUser(Exception):
class ESFactory:
@classmethod
def connect(cls, cluster_name):
def connect(cls, cluster_name) -> Elasticsearch:
"""
Returns the es client for the cluster.
Connects to the cluster if did not connect previously

View File

@@ -0,0 +1,611 @@
import os
from abc import ABC, ABCMeta, abstractmethod
from argparse import ArgumentParser
from collections import defaultdict
from datetime import datetime, timedelta
from functools import partial
from itertools import chain
from pathlib import Path
from time import sleep
from typing import Sequence, Optional, Tuple, Mapping, TypeVar, Hashable, Generic
from urllib.parse import urlparse
import boto3
import requests
from azure.storage.blob import ContainerClient, PartialBatchErrorException
from boltons.iterutils import bucketize, chunked_iter
from furl import furl
from google.cloud import storage as google_storage
from mongoengine import Q
from mypy_boto3_s3.service_resource import Bucket as AWSBucket
from apiserver.bll.storage import StorageBLL
from apiserver.config_repo import config
from apiserver.database import db
from apiserver.database.model.url_to_delete import UrlToDelete, StorageType, DeletionStatus
log = config.logger(f"JOB-{Path(__file__).name}")
conf = config.get("services.async_urls_delete")
max_retries = conf.get("max_retries", 3)
retry_timeout = timedelta(seconds=conf.get("retry_timeout_sec", 60))
storage_bll = StorageBLL()
def mark_retry_failed(ids: Sequence[str], reason: str):
    """
    Register one more failed deletion attempt for the given url ids.

    The failure time and reason are stored and the retry counter is incremented.
    Afterwards every one of these urls whose retry counter reached ``max_retries``
    is switched to the ``failed`` status.
    """
    attempted = UrlToDelete.objects(id__in=ids)
    attempted.update(
        last_failure_time=datetime.utcnow(),
        last_failure_reason=reason,
        inc__retry_count=1,
    )

    exhausted = UrlToDelete.objects(id__in=ids, retry_count__gte=max_retries)
    exhausted.update(status=DeletionStatus.failed)
def mark_failed(query: Q, reason: str):
    """
    Switch all the urls matching the query to the ``failed`` status,
    storing the current time and the passed reason as the failure details.
    """
    failure_details = {
        "status": DeletionStatus.failed,
        "last_failure_time": datetime.utcnow(),
        "last_failure_reason": reason,
    }
    UrlToDelete.objects(query).update(**failure_details)
def scheme_prefix(scheme: str) -> str:
    """Return the string form of a url that has only the given scheme and an empty netloc."""
    bare_url = furl(scheme=scheme, netloc="")
    return str(bare_url)
T = TypeVar("T", bound=Hashable)


class Storage(Generic[T], metaclass=ABCMeta):
    """
    Interface of a storage backend used by the urls deletion job.

    ``T`` is the hashable "base" key by which the urls are grouped; one client
    is created per base. All the members are abstract so that an incomplete
    implementation fails on instantiation instead of silently returning None.
    """

    class Client(ABC):
        """Interface of a client that deletes paths from one storage base."""

        @property
        @abstractmethod
        def chunk_size(self) -> int:
            """Number of urls that are passed to a single delete_many call"""

        @abstractmethod
        def get_path(self, url: UrlToDelete) -> str:
            """Return the path to pass to delete_many for the url. May raise for invalid urls"""

        @abstractmethod
        def delete_many(
            self, paths: Sequence[str]
        ) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
            """
            Delete the passed paths
            :return: tuple of (deleted paths, mapping of failure reason to the failed paths)
            """

    @property
    @abstractmethod
    def name(self) -> str:
        """Storage name as it appears in the log messages"""

    @abstractmethod
    def group_urls(
        self, urls: Sequence[UrlToDelete]
    ) -> Mapping[T, Sequence[UrlToDelete]]:
        """
        Group the urls by their storage base.
        Urls grouped under a falsy base are treated by the caller as invalid
        """

    @abstractmethod
    def get_client(self, base: T, urls: Sequence[UrlToDelete]) -> Client:
        """Create the client for deleting the urls that belong to the passed base"""
def delete_urls(urls_query: Q, storage: Storage):
    """
    Delete from the storage up to 10000 urls (ordered by url) that match the query.

    The urls are grouped by their storage base and deleted in chunks of the
    client chunk size. The outcome is persisted per url:
    - successfully deleted urls are removed from the collection
    - urls with an invalid base or an unresolvable path are marked as failed
    - all other failures are registered as a retry failure

    :param urls_query: query selecting the UrlToDelete documents to process
    :param storage: storage implementation that groups the urls and creates the clients
    """
    to_delete = list(UrlToDelete.objects(urls_query).order_by("url").limit(10000))
    if not to_delete:
        return

    grouped_urls = storage.group_urls(to_delete)
    for base, urls in grouped_urls.items():
        if not base:
            msg = f"Invalid {storage.name} url or missing {storage.name} configuration for account"
            mark_failed(
                Q(id__in=[url.id for url in urls]), msg,
            )
            log.warning(
                f"Failed to delete {len(urls)} files from {storage.name} due to: {msg}"
            )
            continue

        try:
            client = storage.get_client(base, urls)
        except Exception as ex:
            failed = [url.id for url in urls]
            mark_retry_failed(failed, reason=str(ex))
            log.warning(
                f"Failed to delete {len(failed)} files from {storage.name} due to: {str(ex)}"
            )
            continue

        for chunk in chunked_iter(urls, client.chunk_size):
            paths = []
            # several urls may resolve to the same path
            path_to_id_mapping = defaultdict(list)
            ids_to_delete = set()
            for url in chunk:
                try:
                    path = client.get_path(url)
                except Exception as ex:
                    err = str(ex)
                    mark_failed(Q(id=url.id), err)
                    log.warning(f"Error getting path for {url.url}: {err}")
                    continue

                paths.append(path)
                path_to_id_mapping[path].append(url.id)
                ids_to_delete.add(url.id)

            if not paths:
                continue

            try:
                deleted_paths, errors = client.delete_many(paths)
            except Exception as ex:
                # Only the urls of this chunk were attempted. Registering the failure
                # for the whole group would also bump the retry counter of the urls
                # from the other chunks and of those already marked as failed above
                mark_retry_failed(list(ids_to_delete), str(ex))
                log.warning(
                    f"Error deleting {len(paths)} files from {storage.name}: {str(ex)}"
                )
                continue

            failed_ids = set()
            for reason, err_paths in errors.items():
                error_ids = set(
                    chain.from_iterable(
                        path_to_id_mapping.get(p, []) for p in err_paths
                    )
                )
                mark_retry_failed(list(error_ids), reason)
                log.warning(
                    f"Failed to delete {len(error_ids)} files from {storage.name} storage due to: {reason}"
                )
                failed_ids.update(error_ids)

            deleted_ids = set(
                chain.from_iterable(
                    path_to_id_mapping.get(p, []) for p in deleted_paths
                )
            )
            if deleted_ids:
                UrlToDelete.objects(id__in=list(deleted_ids)).delete()
                log.info(
                    f"{len(deleted_ids)} files deleted from {storage.name} storage"
                )

            # urls that the client reported neither as deleted nor as failed
            missing_ids = ids_to_delete - deleted_ids - failed_ids
            if missing_ids:
                mark_retry_failed(list(missing_ids), "Not succeeded")
class FileserverStorage(Storage):
    """Deletes files through the ``delete_many`` endpoint of the fileserver."""

    class Client(Storage.Client):
        """Posts the paths to delete to the fileserver using an existing session."""

        timeout = conf.get("fileserver.timeout_sec", 300)

        def __init__(self, session: requests.Session, host: str):
            self.session = session
            self.delete_url = furl(host).add(path="delete_many").url

        @property
        def chunk_size(self) -> int:
            return 10000

        def get_path(self, url: UrlToDelete) -> str:
            """
            Return the url stripped from the leading and trailing slashes
            :raise ValueError: if nothing is left after the stripping
            """
            path = url.url.strip("/")
            if not path:
                raise ValueError("Empty path")
            return path

        def delete_many(
            self, paths: Sequence[str]
        ) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
            """
            Post the paths to the fileserver
            :return: the "deleted" and "errors" parts of the fileserver response
            :raise requests.HTTPError: on an error response status
            """
            res = self.session.post(
                url=self.delete_url, json={"files": list(paths)}, timeout=self.timeout
            )
            res.raise_for_status()
            res_data = res.json()
            return list(res_data.get("deleted", {})), res_data.get("errors", {})

    def __init__(self, company: str, fileserver_host: Optional[str] = None):
        """
        :param company: company id
        :param fileserver_host: fileserver address. If not passed then taken
            from the "hosts.fileserver" configuration
        """
        fileserver_host = fileserver_host or config.get("hosts.fileserver", None)
        # The host may be missing both from the parameter and from the configuration.
        # In that case keep an empty host so that the urls get grouped under a falsy base
        self.host = (fileserver_host or "").rstrip("/")
        if not self.host:
            log.warning("Fileserver host not configured")

        def _parse_url_prefix(prefix) -> Tuple[str, str]:
            url = furl(prefix)
            host = f"{url.scheme}://{url.netloc}" if url.scheme else None
            return host, str(url.path).rstrip("/")

        url_prefixes = [
            _parse_url_prefix(p) for p in conf.get("fileserver.url_prefixes", [])
        ]
        # the fileserver host itself is always an accepted prefix
        if not any(self.host == host for host, _ in url_prefixes):
            url_prefixes.append((self.host, ""))
        self.url_prefixes = url_prefixes

        self.company = company

    @property
    def name(self) -> str:
        return "Fileserver"

    def _resolve_base_url(self, url: UrlToDelete) -> Optional[str]:
        """
        Return the fileserver host if the url matches one of the configured
        (host, path prefix) pairs, otherwise None.
        Side effect: on a match url.url is replaced by its path with the
        matched path prefix removed
        """
        if not url.url:
            return None
        try:
            parsed = furl(url.url)
            url_host = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme else None
            url_path = str(parsed.path)
        except Exception:
            return None

        for host, path_prefix in self.url_prefixes:
            if host and url_host != host:
                continue
            if path_prefix and not url_path.startswith(path_prefix + "/"):
                continue
            url.url = url_path[len(path_prefix or "") :]
            return self.host

        return None

    def group_urls(
        self, urls: Sequence[UrlToDelete]
    ) -> Mapping[str, Sequence[UrlToDelete]]:
        return bucketize(urls, key=self._resolve_base_url)

    def get_client(self, base: str, urls: Sequence[UrlToDelete]) -> Client:
        """
        Create a client after verifying that the fileserver answers a GET request
        :raise requests.RequestException: if the fileserver cannot be reached
            or returns an error status
        """
        host = base
        session = requests.session()
        res = session.get(url=host, timeout=self.Client.timeout)
        res.raise_for_status()

        return self.Client(session, host)
class AzureStorage(Storage):
    """Deletes blobs from Azure Blob Storage containers."""

    class Client(Storage.Client):
        """Deletes blobs from a single container."""

        def __init__(self, container: ContainerClient):
            self.container = container

        @property
        def chunk_size(self) -> int:
            return 256

        def get_path(self, url: UrlToDelete) -> str:
            """
            Return the blob name: the url path without its first segment (the container name)
            :raise ValueError: if the url path has nothing after the container name
            """
            parsed = furl(url.url)
            if (
                not parsed.path
                or not parsed.path.segments
                or len(parsed.path.segments) <= 1
            ):
                raise ValueError("No path found following container name")
            # blob names always use "/" regardless of the OS path separator
            return "/".join(parsed.path.segments[1:])

        @staticmethod
        def _path_from_request_url(request_url: str) -> str:
            """Return the last path segment of the request url (the url itself if unparsable)"""
            try:
                return furl(request_url).path.segments[-1]
            except Exception:
                return request_url

        def _resolve_request_path(self, request_url: str, requested: set) -> str:
            """
            Map the url of a batch sub-request back to the path that was requested.

            Blob names may contain several segments, so the last url segment is
            not enough to identify them. The suffixes of the url path that follow
            its first segment are tried from the longest to the shortest and the
            first one found among the requested paths is returned. If none matches
            then the last segment is returned as before
            """
            try:
                segments = furl(request_url).path.segments
            except Exception:
                return request_url
            for start in range(1, len(segments)):
                candidate = "/".join(segments[start:])
                if candidate in requested:
                    return candidate
            return self._path_from_request_url(request_url)

        def delete_many(
            self, paths: Sequence[str]
        ) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
            """
            Delete the blobs in one batch call
            :return: tuple of (deleted paths, mapping of failure reason to the failed paths)
            """
            requested = set(paths)
            try:
                res = self.container.delete_blobs(*paths)
            except PartialBatchErrorException as pex:
                deleted = []
                errors = defaultdict(list)
                for part in pex.parts:
                    path = self._resolve_request_path(part.request.url, requested)
                    # only the 2xx statuses mean success
                    if 200 <= part.status_code < 300:
                        deleted.append(path)
                    else:
                        errors[part.reason].append(path)
                return deleted, errors

            return (
                [self._resolve_request_path(part.request.url, requested) for part in res],
                {},
            )

    def __init__(self, company: str):
        self.configs = storage_bll.get_azure_settings_for_company(company)
        self.scheme = "azure"

    @property
    def name(self) -> str:
        return "Azure"

    def _resolve_base_url(self, url: UrlToDelete) -> Optional[Tuple]:
        """
        Return the (url netloc, configured container name) pair for the url.
        None is returned for a non azure scheme, when no configuration matches
        the url or on any error
        """
        try:
            parsed = urlparse(url.url)
            if parsed.scheme != self.scheme:
                return None

            azure_conf = self.configs.get_config_by_uri(url.url)
            if azure_conf is None:
                return None

            account_url = parsed.netloc
            return account_url, azure_conf.container_name
        except Exception as ex:
            log.warning(f"Error resolving base url for {url.url}: " + str(ex))
            return None

    def group_urls(
        self, urls: Sequence[UrlToDelete]
    ) -> Mapping[Tuple, Sequence[UrlToDelete]]:
        return bucketize(urls, key=self._resolve_base_url)

    def get_client(self, base: Tuple, urls: Sequence[UrlToDelete]) -> Client:
        """
        Create a container client using the configuration that matches the first url
        :raise ValueError: if the configuration has no account name or key
        """
        account_url, container_name = base
        sample_url = urls[0].url
        cfg = self.configs.get_config_by_uri(sample_url)
        if not cfg or not cfg.account_name or not cfg.account_key:
            raise ValueError(
                f"Missing account name or key for Azure Blob Storage "
                f"account: {account_url}, container: {container_name}"
            )

        return self.Client(
            ContainerClient(
                account_url=account_url,
                container_name=cfg.container_name,
                credential={
                    "account_name": cfg.account_name,
                    "account_key": cfg.account_key,
                },
            )
        )
class AWSStorage(Storage):
    """Deletes objects from S3 buckets."""

    class Client(Storage.Client):
        """Deletes objects from a single bucket."""

        def __init__(self, base_url: str, container: AWSBucket):
            self.container = container
            self.base_url = base_url

        @property
        def chunk_size(self) -> int:
            return 1000

        def get_path(self, url: UrlToDelete) -> str:
            """ Return the object key: the url without the base url prefix and leading slashes """
            full_url = url.url
            prefix = self.base_url
            key = full_url[len(prefix) :] if full_url.startswith(prefix) else full_url
            return key.lstrip("/")

        @staticmethod
        def _path_from_request_url(request_url: str) -> str:
            """Return the last path segment of the request url (the url itself if unparsable)"""
            try:
                segments = furl(request_url).path.segments
                return segments[-1]
            except Exception:
                return request_url

        def delete_many(
            self, paths: Sequence[str]
        ) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
            """
            Delete the keys in one call
            :return: tuple of (deleted keys, mapping of error message to the failed keys)
            """
            objects = [{"Key": p} for p in paths]
            res = self.container.delete_objects(Delete={"Objects": objects})

            errors = defaultdict(list)
            for err in res.get("Errors", []):
                errors[err.get("Message", "")].append(err.get("Key"))

            deleted = [d.get("Key") for d in res.get("Deleted", [])]
            return deleted, errors

    def __init__(self, company: str):
        self.configs = storage_bll.get_aws_settings_for_company(company)
        self.scheme = "s3"

    @property
    def name(self) -> str:
        return "AWS"

    def _resolve_base_url(self, url: UrlToDelete) -> Optional[str]:
        """
        Return "s3://[host/]bucket" for the url. The bucket is taken from the
        matching configuration or, if not set there, from the first segment of the url path.
        None is returned for a non s3 scheme, when no configuration matches
        the url or on any error
        """
        try:
            parsed = urlparse(url.url)
            if parsed.scheme != self.scheme:
                return None

            s3_conf = self.configs.get_config_by_uri(url.url)
            if s3_conf is None:
                return None

            bucket = s3_conf.bucket
            if not bucket:
                path_parts = Path(parsed.path.strip("/")).parts
                if path_parts:
                    bucket = path_parts[0]

            components = [c for c in ("s3:/", s3_conf.host, bucket) if c]
            return "/".join(components)
        except Exception as ex:
            log.warning(f"Error resolving base url for {url.url}: " + str(ex))
            return None

    def group_urls(
        self, urls: Sequence[UrlToDelete]
    ) -> Mapping[str, Sequence[UrlToDelete]]:
        return bucketize(urls, key=self._resolve_base_url)

    def get_client(self, base: str, urls: Sequence[UrlToDelete]) -> Client:
        """
        Create a bucket client using the configuration that matches the first url
        :raise ValueError: if explicit credentials are required but key or secret is missing
        """
        cfg = self.configs.get_config_by_uri(urls[0].url)

        endpoint_url = None
        if cfg.host:
            protocol = "https://" if cfg.secure else "http://"
            endpoint_url = protocol + cfg.host
        boto_kwargs = {
            "endpoint_url": endpoint_url,
            "use_ssl": cfg.secure,
            "verify": cfg.verify,
        }

        # the base is "<scheme prefix>[host/]bucket"
        name = base[len(scheme_prefix(self.scheme)) :]
        bucket_name = name[len(cfg.host) + 1 :] if cfg.host else name

        if not cfg.use_credentials_chain:
            if not (cfg.key and cfg.secret):
                raise ValueError(
                    f"Missing key or secret for AWS S3 host: {cfg.host}, bucket: {str(bucket_name)}"
                )
            boto_kwargs.update(
                aws_access_key_id=cfg.key, aws_secret_access_key=cfg.secret
            )
            if cfg.token:
                boto_kwargs["aws_session_token"] = cfg.token

        bucket = boto3.resource("s3", **boto_kwargs).Bucket(bucket_name)
        return self.Client(base, bucket)
class GoogleCloudStorage(Storage):
    """Deletes blobs from Google Cloud Storage buckets."""

    class Client(Storage.Client):
        """Deletes blobs from a single bucket."""

        def __init__(self, base_url: str, container: google_storage.Bucket):
            self.container = container
            self.base_url = base_url

        @property
        def chunk_size(self) -> int:
            return 100

        def get_path(self, url: UrlToDelete) -> str:
            """ Return the blob name: the url without the base url prefix and leading slashes """
            full_url = url.url
            prefix = self.base_url
            blob_name = (
                full_url[len(prefix) :] if full_url.startswith(prefix) else full_url
            )
            return blob_name.lstrip("/")

        def delete_many(
            self, paths: Sequence[str]
        ) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
            """
            Delete the blobs. The blobs passed to the error callback are
            reported under the "Not found" reason, all the others as deleted
            """
            missing = set()

            def collect_missing(blob: google_storage.Blob):
                missing.add(blob.name)

            blobs = [self.container.blob(p) for p in paths]
            self.container.delete_blobs(blobs, on_error=collect_missing)

            errors = {}
            if missing:
                errors = {"Not found": list(missing)}
            return list(set(paths) - missing), errors

    def __init__(self, company: str):
        self.configs = storage_bll.get_gs_settings_for_company(company)
        self.scheme = "gs"

    @property
    def name(self) -> str:
        return "Google Storage"

    def _resolve_base_url(self, url: UrlToDelete) -> Optional[str]:
        """
        Return the url built from the url scheme and the configured bucket.
        None is returned for a non gs scheme, when no configuration matches
        the url or on any error
        """
        try:
            parsed = urlparse(url.url)
            if parsed.scheme != self.scheme:
                return None

            gs_conf = self.configs.get_config_by_uri(url.url)
            if gs_conf is None:
                return None

            base_url = furl(scheme=parsed.scheme, netloc=gs_conf.bucket)
            return str(base_url)
        except Exception as ex:
            log.warning(f"Error resolving base url for {url.url}: " + str(ex))
            return None

    def group_urls(
        self, urls: Sequence[UrlToDelete]
    ) -> Mapping[str, Sequence[UrlToDelete]]:
        return bucketize(urls, key=self._resolve_base_url)

    def get_client(self, base: str, urls: Sequence[UrlToDelete]) -> Client:
        """Create a bucket client using the configuration that matches the first url"""
        cfg = self.configs.get_config_by_uri(urls[0].url)

        credentials = None
        if cfg.credentials_json:
            from google.oauth2 import service_account

            credentials = service_account.Credentials.from_service_account_file(
                cfg.credentials_json
            )

        # the base is "<scheme prefix>bucket"
        bucket_name = base[len(scheme_prefix(self.scheme)) :]
        gs_client = google_storage.Client(project=cfg.project, credentials=credentials)
        return self.Client(base, gs_client.bucket(bucket_name))
def run_delete_loop(fileserver_host: str):
    """
    Endlessly process the urls that are pending deletion.

    On each iteration the pending url with the lowest retry count is picked and
    all the pending urls of its company and storage type are passed to delete_urls.
    A url is pending if it is not failed, has retries left and either never failed
    or failed longer than the retry timeout ago.
    When nothing is pending the loop sleeps for 10 seconds
    """
    storage_helpers = {
        StorageType.fileserver: partial(
            FileserverStorage, fileserver_host=fileserver_host
        ),
        StorageType.s3: AWSStorage,
        StorageType.azure: AzureStorage,
        StorageType.gs: GoogleCloudStorage,
    }
    while True:
        now = datetime.utcnow()
        retries_left = Q(status__ne=DeletionStatus.failed) & Q(
            retry_count__lt=max_retries
        )
        retry_due = Q(last_failure_time__exists=False) | Q(
            last_failure_time__lt=now - retry_timeout
        )
        urls_query = retries_left & retry_due

        supported_query = urls_query & Q(storage_type__in=list(storage_helpers))
        candidates = UrlToDelete.objects(supported_query).order_by("retry_count")
        url_to_delete: UrlToDelete = candidates.limit(1).first()
        if not url_to_delete:
            sleep(10)
            continue

        company = url_to_delete.company
        user = url_to_delete.user
        storage_type = url_to_delete.storage_type
        log.info(
            f"Deleting {storage_type} objects for company: {company}, user: {user}"
        )

        company_storage_urls_query = urls_query & Q(
            company=company, storage_type=storage_type,
        )
        create_storage = storage_helpers[storage_type]
        delete_urls(
            urls_query=company_storage_urls_query,
            storage=create_storage(company=company),
        )
def main():
    """Parse the command line, initialize the database and start the delete loop."""
    arg_parser = ArgumentParser(description=__doc__)
    arg_parser.add_argument(
        "--fileserver-host",
        "-fh",
        help="Fileserver host address",
        type=str,
    )
    fileserver_host = arg_parser.parse_args().fileserver_host

    db.initialize()

    run_delete_loop(fileserver_host)
if __name__ == "__main__":
main()

View File

@@ -1,8 +1,10 @@
import importlib.util
from datetime import datetime
from inspect import signature
from logging import Logger
from pathlib import Path
import pymongo.database
from mongoengine.connection import get_db
from packaging.version import Version, parse
@@ -80,8 +82,15 @@ def _apply_migrations(log: Logger):
if not func:
continue
try:
sig = signature(func)
kwargs = {}
if len(sig.parameters) == 2:
name, param = list(sig.parameters.items())[-1]
key = name.replace("_", "-")
if issubclass(param.annotation, pymongo.database.Database) and key in dbs:
kwargs[name] = get_db(key)
log.info(f"Applying {script.stem}/{func_name}()")
func(get_db(alias))
func(get_db(alias), **kwargs)
except Exception:
log.exception(f"Failed applying {script}:{func_name}()")
raise ValueError(

View File

@@ -24,7 +24,7 @@ from typing import (
Callable,
)
from urllib.parse import unquote, urlparse
from uuid import uuid4
from uuid import uuid4, UUID, uuid5
from zipfile import ZipFile, ZIP_BZIP2
import mongoengine
@@ -72,6 +72,8 @@ class PrePopulate:
r"['\"]source['\"]:\s?['\"](https?://(?:localhost:8081|files.*?)/.*?)['\"]",
flags=re.IGNORECASE,
)
_name_guid_ns = UUID("bda3acc1-e612-506c-bade-80071b6cf039")
_example_id_prefix = "e-"
task_cls: Type[Task]
project_cls: Type[Project]
model_cls: Type[Model]
@@ -691,17 +693,56 @@ class PrePopulate:
continue
yield clean
@staticmethod
def _new_id(_):
return str(uuid4()).replace("-", "")
@classmethod
def _hash_id(cls, name: str):
return str(uuid5(cls._name_guid_ns, name)).replace("-", "")
@classmethod
def _example_id(cls, orig_id: str):
if not orig_id or orig_id.startswith(cls._example_id_prefix):
return orig_id
return cls._example_id_prefix + orig_id
@classmethod
def _private_id(cls, orig_id: str):
if not orig_id or not orig_id.startswith(cls._example_id_prefix):
return orig_id
return orig_id[len(cls._example_id_prefix) :]
@classmethod
def _generate_new_ids(
cls, reader: ZipFile, entity_files: Sequence
cls, reader: ZipFile, entity_files: Sequence, metadata: Mapping[str, Any],
) -> Mapping[str, str]:
if not metadata or not any(
metadata.get(key) for key in ("new_ids", "example_ids", "private_ids")
):
return {}
ids = {}
for entity_file in entity_files:
with reader.open(entity_file) as f:
is_project = splitext(entity_file.orig_filename)[0].endswith(".Project")
if metadata.get("new_ids"):
id_func = cls._new_id
elif metadata.get("example_ids"):
id_func = cls._example_id if not is_project else cls._hash_id
elif metadata.get("private_ids"):
id_func = cls._private_id if not is_project else cls._new_id
for item in cls.json_lines(f):
orig_id = json.loads(item).get("_id")
doc = json.loads(item)
orig_id = doc.get("_id")
if orig_id:
ids[orig_id] = str(uuid4()).replace("-", "")
ids[orig_id] = (
id_func(orig_id)
if id_func != cls._hash_id
else id_func(doc.get("name"))
)
return ids
@classmethod
@@ -725,11 +766,7 @@ class PrePopulate:
and fi.orig_filename != cls.metadata_filename
]
metadata = metadata or {}
old_to_new_ids = (
cls._generate_new_ids(reader, entity_files)
if metadata.get("new_ids")
else {}
)
old_to_new_ids = cls._generate_new_ids(reader, entity_files, metadata)
tasks = []
for entity_file in entity_files:
with reader.open(entity_file) as f:
@@ -923,9 +960,7 @@ class PrePopulate:
return tasks
@classmethod
def _import_events(
cls, f: IO[bytes], company_id: str, _, task_id: str
):
def _import_events(cls, f: IO[bytes], company_id: str, _, task_id: str):
print(f"Writing events for task {task_id} into database")
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
events = [json.loads(item) for item in events_chunk]
@@ -933,5 +968,5 @@ class PrePopulate:
ev["task"] = task_id
ev["company_id"] = company_id
cls.event_bll.add_events(
company_id, events=events, worker="", allow_locked_tasks=True
company_id, events=events, worker="", allow_locked=True
)

View File

@@ -53,6 +53,7 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
name=user_name,
given_name=given_name,
family_name=family_name,
created=datetime.utcnow(),
).save()
return user_id

View File

@@ -0,0 +1,15 @@
from pymongo.collection import Collection
from pymongo.database import Database
def migrate_backend(db: Database, auth_db: Database):
    """
    Backfill the "created" field of backend users from the auth database.

    Every document in the backend "user" collection that lacks the field is
    looked up by _id in the auth "user" collection. When the auth document
    exists and carries the field, its value is copied over; otherwise the
    backend document is left as is.
    """
    field = "created"
    backend_users: Collection = db["user"]
    auth_collection: Collection = auth_db["user"]

    for user in backend_users.find({field: {"$exists": False}}):
        user_id = user["_id"]
        source = auth_collection.find_one({"_id": user_id}, projection=[field])
        if source and field in source:
            backend_users.update_one(
                {"_id": user_id}, {"$set": {field: source[field]}}
            )

View File

@@ -0,0 +1,17 @@
import logging as log
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.errors import OperationFailure
def migrate_backend(db: Database):
    """
    Drop task text index so that the new one including reports field is created

    The drop is best-effort: an OperationFailure raised by the drop is logged
    as a warning and the migration returns normally.
    """
    tasks: Collection = db["task"]
    try:
        tasks.drop_index("backend-db.task.main_text_index")
    except OperationFailure as ex:
        # Deliberately not re-raised so that a failed drop does not abort the migration
        log.warning("Could not delete task text index due to: %s", ex)

View File

@@ -1,7 +1,10 @@
attrs>=19.1.0
attrs>=22.1.0
azure-storage-blob>=12.13.1
bcrypt>=3.1.4
boltons>=19.1.0
boto3==1.14.13
boto3-stubs[s3]>=1.24.35
clearml>=1.6.0,<1.7.0
dpath>=1.4.2,<2.0
elasticsearch==7.13.3
fastjsonschema>=2.8
@@ -10,13 +13,15 @@ flask-cors>=3.0.5
flask>=0.12.2
funcsigs==1.0.2
furl>=2.0.0
google-cloud-storage==2.0.0
protobuf==3.19.5
gunicorn>=19.7.1
humanfriendly==4.18
jinja2==2.11.3
jsonmodels>=2.3
jsonschema>=2.6.0
luqum>=0.10.0
mongoengine==0.23.1
mongoengine==0.24.2
nested_dict>=1.61
packaging==20.3
psutil>=5.6.5
@@ -24,9 +29,8 @@ pyhocon>=0.3.35
pyjwt>=2.4.0
pymongo[srv]==3.12.0
python-rapidjson>=0.6.3
redis==3.5.3
redis==4.4.4
redis-py-cluster>=2.1.3
related>=0.7.2
requests>=2.13.0
semantic_version>=2.8.3,<3
six

View File

@@ -15,6 +15,35 @@ metadata_item {
}
}
}
task_status_enum {
type: string
enum: [
created
queued
in_progress
stopped
published
publishing
closed
failed
completed
unknown
]
}
multi_field_pattern_data {
type: object
properties {
pattern {
description: "Pattern string (regex)"
type: string
}
fields {
description: "List of field names"
type: array
items { type: string }
}
}
}
credentials {
type: object
properties {

View File

@@ -0,0 +1,106 @@
scalar_key_enum {
type: string
enum: [
iter
timestamp
iso_time
]
}
metric_variants {
type: object
properties {
metric {
description: The metric name
type: string
}
variants {
type: array
description: The names of the metric variants
items {type: string}
}
}
}
debug_images_response_task_metrics {
type: object
properties {
task {
type: string
description: Task ID
}
iterations {
type: array
items {
type: object
properties {
iter {
type: integer
description: Iteration number
}
events {
type: array
items {
type: object
description: Debug image event
}
}
}
}
}
}
}
debug_images_response {
type: object
properties {
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
metrics {
type: array
description: "Debug image events grouped by tasks and iterations"
items {"$ref": "#/definitions/debug_images_response_task_metrics"}
}
}
}
plots_response_task_metrics {
type: object
properties {
task {
type: string
description: Task ID
}
iterations {
type: array
items {
type: object
properties {
iter {
type: integer
description: Iteration number
}
events {
type: array
items {
type: object
description: Plot event
}
}
}
}
}
}
}
plots_response {
type: object
properties {
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
metrics {
type: array
description: "Plot events grouped by tasks and iterations"
items {"$ref": "#/definitions/plots_response_task_metrics"}
}
}
}

View File

@@ -0,0 +1,506 @@
include "_common.conf"
task_type_enum {
type: string
enum: [
training
testing
inference
data_processing
application
monitor
controller
optimizer
service
qc
custom
]
}
script {
type: object
properties {
binary {
description: "Binary to use when running the script"
type: string
default: python
}
repository {
description: "Name of the repository where the script is located"
type: string
}
tag {
description: "Repository tag"
type: string
}
branch {
description: "Repository branch ID. If neither branch nor tag is provided, the default repository branch is used."
type: string
}
version_num {
description: "Version (changeset) number. Optional (default is head version). Unused if tag is provided."
type: string
}
entry_point {
description: "Path to execute within the repository"
type: string
}
working_dir {
description: "Path to the folder from which to run the script. Default - root folder of repository"
type: string
}
requirements {
description: "A JSON object containing requirements strings by key"
type: object
}
diff {
description: "Uncommitted changes found in the repository when task was run"
type: string
}
}
}
model_type_enum {
type: string
enum: ["input", "output"]
}
task_model_item {
type: object
required: [ name, model]
properties {
name {
description: "The task model name"
type: string
}
model {
description: "The model ID"
type: string
}
}
}
output {
type: object
properties {
destination {
description: "Storage id. This is where output files will be stored."
type: string
}
model {
description: "Model id."
type: string
}
result {
description: "Task result. Values: 'success', 'failure'"
type: string
}
error {
description: "Last error text"
type: string
}
}
}
task_execution_progress_enum {
type: string
enum: [
unknown
running
stopping
stopped
]
}
artifact_type_data {
type: object
properties {
preview {
description: "Description or textual data"
type: string
}
content_type {
description: "System defined raw data content type"
type: string
}
data_hash {
description: "Hash of raw data, without any headers or descriptive parts"
type: string
}
}
}
artifact_mode_enum {
type: string
enum: [
input
output
]
default: output
}
artifact {
type: object
required: [key, type]
properties {
key {
description: "Entry key"
type: string
}
type {
description: "System defined type"
type: string
}
mode {
description: "System defined input/output indication"
"$ref": "#/definitions/artifact_mode_enum"
}
uri {
description: "Raw data location"
type: string
}
content_size {
description: "Raw data length in bytes"
type: integer
}
hash {
description: "Hash of entire raw data"
type: string
}
timestamp {
description: "Epoch time when artifact was created"
type: integer
}
type_data {
description: "Additional fields defined by the system"
"$ref": "#/definitions/artifact_type_data"
}
display_data {
description: "User-defined list of key/value pairs, sorted"
type: array
items {
type: array
items {
type: string # can also be a number... TODO: upgrade the generator
}
}
}
}
}
artifact_id {
type: object
required: [key]
properties {
key {
description: "Entry key"
type: string
}
mode {
description: "System defined input/output indication"
"$ref": "#/definitions/artifact_mode_enum"
}
}
}
task_models {
type: object
properties {
input {
description: "The list of task input models"
type: array
items {"$ref": "#/definitions/task_model_item"}
}
output {
description: "The list of task output models"
type: array
items {"$ref": "#/definitions/task_model_item"}
}
}
}
execution {
type: object
properties {
queue {
description: "Queue ID where task was queued."
type: string
}
parameters {
description: "Json object containing the Task parameters"
type: object
additionalProperties: true
}
model {
description: "Execution input model ID. Not applicable for Register (Import) tasks"
type: string
}
model_desc {
description: "Json object representing the Model descriptors"
type: object
additionalProperties: true
}
model_labels {
description: """Json object representing the ids of the labels in the model.
The keys are the layers' names and the values are the IDs.
Not applicable for Register (Import) tasks.
Mandatory for Training tasks"""
type: object
additionalProperties: { type: integer }
}
framework {
description: """Framework related to the task. Case insensitive. Mandatory for Training tasks. """
type: string
}
docker_cmd {
description: "Command for running docker script for the execution of the task"
type: string
}
artifacts {
description: "Task artifacts"
type: array
items { "$ref": "#/definitions/artifact" }
}
}
}
last_metrics_event {
type: object
properties {
metric {
description: "Metric name"
type: string
}
variant {
description: "Variant name"
type: string
}
value {
description: "Last value reported"
type: number
}
min_value {
description: "Minimum value reported"
type: number
}
min_value_iteration {
description: "The iteration at which the minimum value was reported"
type: integer
}
max_value {
description: "Maximum value reported"
type: number
}
max_value_iteration {
description: "The iteration at which the maximum value was reported"
type: integer
}
}
}
last_metrics_variants {
type: object
description: "Last metric events, one for each variant hash"
additionalProperties {
"$ref": "#/definitions/last_metrics_event"
}
}
params_item {
type: object
properties {
section {
description: "Section that the parameter belongs to"
type: string
}
name {
description: "Name of the parameter. The combination of section and name should be unique"
type: string
}
value {
description: "Value of the parameter"
type: string
}
type {
description: "Type of the parameter. Optional"
type: string
}
description {
description: "The parameter description. Optional"
type: string
}
}
}
configuration_item {
type: object
properties {
name {
description: "Name of the parameter. Should be unique"
type: string
}
value {
description: "Value of the parameter"
type: string
}
type {
description: "Type of the parameter. Optional"
type: string
}
description {
description: "The parameter description. Optional"
type: string
}
}
}
section_params {
description: "Task section params"
type: object
additionalProperties {
"$ref": "#/definitions/params_item"
}
}
task {
type: object
properties {
id {
description: "Task id"
type: string
}
name {
description: "Task Name"
type: string
}
user {
description: "Associated user id"
type: string
}
company {
description: "Company ID"
type: string
}
type {
description: "Type of task. Values: 'training', 'testing'"
"$ref": "#/definitions/task_type_enum"
}
status {
description: ""
"$ref": "#/definitions/task_status_enum"
}
comment {
description: "Free text comment"
type: string
}
created {
description: "Task creation time (UTC) "
type: string
format: "date-time"
}
started {
description: "Task start time (UTC)"
type: string
format: "date-time"
}
completed {
description: "Task end time (UTC)"
type: string
format: "date-time"
}
active_duration {
description: "Task duration time (seconds)"
type: integer
}
parent {
description: "Parent task id"
type: string
}
project {
description: "Project ID of the project to which this task is assigned"
type: string
}
output {
description: "Task output params"
"$ref": "#/definitions/output"
}
execution {
description: "Task execution params"
"$ref": "#/definitions/execution"
}
container {
description: "Docker container parameters"
type: object
additionalProperties { type: [string, null] }
}
models {
description: "Task models"
"$ref": "#/definitions/task_models"
}
// TODO: will be removed
script {
description: "Script info"
"$ref": "#/definitions/script"
}
tags {
description: "User-defined tags list"
type: array
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items { type: string }
}
status_changed {
description: "Last status change time"
type: string
format: "date-time"
}
status_message {
description: "free text string representing info about the status"
type: string
}
status_reason {
description: "Reason for last status change"
type: string
}
published {
description: "Task publish time"
type: string
format: "date-time"
}
last_worker {
description: "ID of last worker that handled the task"
type: string
}
last_worker_report {
description: "Last time a worker reported while working on this task"
type: string
format: "date-time"
}
last_update {
description: "Last time this task was created, edited, changed or events for this task were reported"
type: string
format: "date-time"
}
last_change {
description: "Last time any update was done to the task"
type: string
format: "date-time"
}
last_iteration {
description: "Last iteration reported for this task"
type: integer
}
last_metrics {
description: "Last metric variants (hash to events), one for each metric hash"
type: object
additionalProperties {
"$ref": "#/definitions/last_metrics_variants"
}
}
hyperparams {
description: "Task hyper params per section"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
configuration {
description: "Task configuration params"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
runtime {
description: "Task runtime mapping"
type: object
additionalProperties: true
}
}
}

View File

@@ -1,17 +1,6 @@
_description : "Provides an API for running tasks to report events collected by the system."
_definitions {
metric_variants {
type: object
metric {
description: The metric name
type: string
}
variants {
type: array
description: The names of the metric variants
items {type: string}
}
}
include "_events_common.conf"
metrics_scalar_event {
description: "Used for reporting scalar metrics during training task"
type: object
@@ -164,14 +153,6 @@ _definitions {
}
}
}
scalar_key_enum {
type: string
enum: [
iter
timestamp
iso_time
]
}
log_level_enum {
type: string
enum: [
@@ -260,90 +241,6 @@ _definitions {
}
}
}
debug_images_response_task_metrics {
type: object
properties {
task {
type: string
description: Task ID
}
iterations {
type: array
items {
type: object
properties {
iter {
type: integer
description: Iteration number
}
events {
type: array
items {
type: object
description: Debug image event
}
}
}
}
}
}
}
debug_images_response {
type: object
properties {
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
metrics {
type: array
description: "Debug image events grouped by tasks and iterations"
items {"$ref": "#/definitions/debug_images_response_task_metrics"}
}
}
}
plots_response_task_metrics {
type: object
properties {
task {
type: string
description: Task ID
}
iterations {
type: array
items {
type: object
properties {
iter {
type: integer
description: Iteration number
}
events {
type: array
items {
type: object
description: Plot event
}
}
}
}
}
}
}
plots_response {
type: object
properties {
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
metrics {
type: array
description: "Plot events grouped by tasks and iterations"
items {"$ref": "#/definitions/plots_response_task_metrics"}
}
}
}
debug_image_sample_response {
type: object
properties {
@@ -372,17 +269,18 @@ _definitions {
type: string
description: "Scroll ID to pass to the next calls to get_plot_sample or next_plot_sample"
}
event {
type: object
description: "Plot event"
events {
description: "Plot events"
type: array
items { type: object}
}
min_iteration {
type: integer
description: "minimal valid iteration for the variant"
description: "minimal valid iteration for the metric"
}
max_iteration {
type: integer
description: "maximal valid iteration for the variant"
description: "maximal valid iteration for the metric"
}
}
}
@@ -405,13 +303,27 @@ add {
additionalProperties: true
}
}
"2.22": ${add."2.1"} {
request.properties {
model_event {
type: boolean
description: If set then the event is for a model. Otherwise for a task. Cannot be used with task log events. If used in batch then all the events should be marked the same
default: false
}
allow_locked {
type: boolean
description: Allow adding events to published tasks or models
default: false
}
}
}
}
add_batch {
"2.1" {
description: "Adds a batch of events in a single call (json-lines format, stream-friendly)"
batch_request: {
action: add
version: 1.5
version: 2.1
}
response {
type: object
@@ -422,10 +334,16 @@ add_batch {
}
}
}
"2.22": ${add_batch."2.1"} {
batch_request: {
action: add
version: 2.22
}
}
}
delete_for_task {
"2.1" {
description: "Delete all task event. *This cannot be undone!*"
description: "Delete all task events. *This cannot be undone!*"
request {
type: object
required: [
@@ -454,6 +372,37 @@ delete_for_task {
}
}
}
delete_for_model {
"2.22" {
description: "Delete all model events. *This cannot be undone!*"
request {
type: object
required: [
model
]
properties {
model {
type: string
description: "Model ID"
}
allow_locked {
type: boolean
description: "Allow deleting events even if the model is locked"
default: false
}
}
}
response {
type: object
properties {
deleted {
type: boolean
description: "Number of deleted events"
}
}
}
}
}
debug_images {
"2.1" {
description: "Get all debug images of a task"
@@ -495,7 +444,7 @@ debug_images {
}
total {
type: number
description: "Total number of results available for this query"
description: "Total number of results available for this query. In case there are more than 10000 results it is set to 10000"
}
scroll_id {
type: string
@@ -548,6 +497,13 @@ debug_images {
}
}
}
"2.22": ${debug_images."2.14"} {
request.properties.model_events {
type: boolean
description: If set then model events are retrieved. Otherwise task events
default: false
}
}
}
plots {
"2.20" {
@@ -565,7 +521,7 @@ plots {
}
iters {
type: integer
description: "Max number of latest iterations for which to return debug images"
description: "Max number of latest iterations for which to return plots"
}
navigate_earlier {
type: boolean
@@ -583,6 +539,13 @@ plots {
}
response {"$ref": "#/definitions/plots_response"}
}
"2.22": ${plots."2.20"} {
request.properties.model_events {
type: boolean
description: If set then model plots are retrieved. Otherwise task plots
default: false
}
}
}
get_debug_image_sample {
"2.12": {
@@ -626,6 +589,13 @@ get_debug_image_sample {
default: true
}
}
"2.22": ${get_debug_image_sample."2.20"} {
request.properties.model_events {
type: boolean
description: If set then model debug images are retrieved. Otherwise task debug images
default: false
}
}
}
next_debug_image_sample {
"2.12": {
@@ -651,13 +621,25 @@ next_debug_image_sample {
}
response {"$ref": "#/definitions/debug_image_sample_response"}
}
"2.22": ${next_debug_image_sample."2.12"} {
request.properties.next_iteration {
type: boolean
default: false
description: If set then navigate to the next/previous iteration
}
request.properties.model_events {
type: boolean
description: If set then model debug images are retrieved. Otherwise task debug images
default: false
}
}
}
get_plot_sample {
"2.20": {
description: "Return the plot per metric and variant for the provided iteration"
description: "Return plots for the provided iteration"
request {
type: object
required: [task, metric, variant]
required: [task, metric]
properties {
task {
description: "Task ID"
@@ -667,10 +649,6 @@ get_plot_sample {
description: "Metric name"
type: string
}
variant {
description: "Metric variant"
type: string
}
iteration {
description: "The iteration to bring plot from. If not specified then the latest reported plot is retrieved"
type: integer
@@ -692,10 +670,17 @@ get_plot_sample {
}
response {"$ref": "#/definitions/plot_sample_response"}
}
"2.22": ${get_plot_sample."2.20"} {
request.properties.model_events {
type: boolean
description: If set then model plots are retrieved. Otherwise task plots
default: false
}
}
}
next_plot_sample {
"2.20": {
description: "Get the plot for the next variant for the same iteration or for the next iteration"
description: "Get the plot for the next metric for the same iteration or for the next iteration"
request {
type: object
required: [task, scroll_id]
@@ -710,13 +695,25 @@ next_plot_sample {
}
navigate_earlier {
type: boolean
description: """If set then get the either previous variant event from the current iteration or (if does not exist) the last variant event from the previous iteration.
Otherwise next variant event from the current iteration or first variant event from the next iteration"""
description: """If set then get either the previous metric events from the current iteration or (if they do not exist) the last metric events from the previous iteration.
Otherwise next metric events from the current iteration or first metric events from the next iteration"""
}
}
}
response {"$ref": "#/definitions/plot_sample_response"}
}
"2.22": ${next_plot_sample."2.20"} {
request.properties.next_iteration {
type: boolean
default: false
description: If set then navigate to the next/previous iteration
}
request.properties.model_events {
type: boolean
description: If set then model plots are retrieved. Otherwise task plots
default: false
}
}
}
get_task_metrics{
"2.7": {
@@ -749,6 +746,13 @@ get_task_metrics{
}
}
}
"2.22": ${get_task_metrics."2.7"} {
request.properties.model_events {
type: boolean
description: If set then get metrics from model events. Otherwise from task events
default: false
}
}
}
get_task_log {
"1.5" {
@@ -848,7 +852,7 @@ get_task_log {
}
total {
type: number
description: "Total number of results available for this query"
description: "Total number of results available for this query. In case there are more than 10000 results it is set to 10000"
}
scroll_id {
type: string
@@ -902,7 +906,7 @@ get_task_log {
}
total {
type: number
description: "Total number of log events available for this query"
description: "Total number of log events available for this query. In case there are more than 10000 events it is set to 10000"
}
}
}
@@ -957,7 +961,7 @@ get_task_events {
}
total {
type: number
description: "Total number of results available for this query"
description: "Total number of results available for this query. In case there are more than 10000 results it is set to 10000"
}
scroll_id {
type: string
@@ -966,6 +970,13 @@ get_task_events {
}
}
}
"2.22": ${get_task_events."2.1"} {
request.properties.model_events {
type: boolean
description: If set then model events are retrieved. Otherwise task events
default: false
}
}
}
download_task_log {
@@ -1016,7 +1027,7 @@ get_task_plots {
}
iters {
type: integer
description: "Max number of latest iterations for which to return debug images"
description: "Max number of latest iterations for which to return plots"
}
scroll_id {
type: string
@@ -1040,7 +1051,7 @@ get_task_plots {
}
total {
type: number
description: "Total number of results available for this query"
description: "Total number of results available for this query. In case there are more than 10000 results it is set to 10000"
}
scroll_id {
type: string
@@ -1067,6 +1078,13 @@ get_task_plots {
default: false
}
}
"2.22": ${get_task_plots."2.16"} {
request.properties.model_events {
type: boolean
description: If set then model events are retrieved. Otherwise task events
default: false
}
}
}
get_multi_task_plots {
"2.1" {
@@ -1087,7 +1105,7 @@ get_multi_task_plots {
}
iters {
type: integer
description: "Max number of latest iterations for which to return debug images"
description: "Max number of latest iterations for which to return plots"
}
scroll_id {
type: string
@@ -1108,7 +1126,7 @@ get_multi_task_plots {
}
total {
type: number
description: "Total number of results available for this query"
description: "Total number of results available for this query. In case there are more than 10000 results it is set to 10000"
}
scroll_id {
type: string
@@ -1124,6 +1142,13 @@ get_multi_task_plots {
default: false
}
}
"2.22": ${get_multi_task_plots."2.16"} {
request.properties.model_events {
type: boolean
description: If set then model events are retrieved. Otherwise task events
default: false
}
}
}
get_vector_metrics_and_variants {
"2.1" {
@@ -1152,6 +1177,13 @@ get_vector_metrics_and_variants {
}
}
}
"2.22": ${get_vector_metrics_and_variants."2.1"} {
request.properties.model_events {
type: boolean
description: If set then model events are retrieved. Otherwise task events
default: false
}
}
}
vector_metrics_iter_histogram {
"2.1" {
@@ -1190,6 +1222,13 @@ vector_metrics_iter_histogram {
}
}
}
"2.22": ${vector_metrics_iter_histogram."2.1"} {
request.properties.model_events {
type: boolean
description: If set then model events are retrieved. Otherwise task events
default: false
}
}
}
scalar_metrics_iter_histogram {
"2.1" {
@@ -1243,6 +1282,13 @@ scalar_metrics_iter_histogram {
}
}
}
"2.22": ${scalar_metrics_iter_histogram."2.14"} {
request.properties.model_events {
type: boolean
description: If set then model events are retrieved. Otherwise task events
default: false
}
}
}
multi_task_scalar_metrics_iter_histogram {
"2.1" {
@@ -1254,7 +1300,7 @@ multi_task_scalar_metrics_iter_histogram {
]
properties {
tasks {
description: "List of task Task IDs. Maximum amount of tasks is 10"
description: "List of task IDs. Maximum amount of tasks is 100"
type: array
items {
type: string
@@ -1282,6 +1328,13 @@ multi_task_scalar_metrics_iter_histogram {
additionalProperties: true
}
}
"2.22": ${multi_task_scalar_metrics_iter_histogram."2.1"} {
request.properties.model_events {
type: boolean
description: If set then model events are retrieved. Otherwise task events
default: false
}
}
}
get_task_single_value_metrics {
"2.20" {
@@ -1331,6 +1384,13 @@ get_task_single_value_metrics {
}
}
}
"2.22": ${get_task_single_value_metrics."2.20"} {
request.properties.model_events {
type: boolean
description: If set then model events are retrieved. Otherwise task events
default: false
}
}
}
get_task_latest_scalar_values {
"2.1" {
@@ -1410,6 +1470,13 @@ get_scalar_metrics_and_variants {
}
}
}
"2.22": ${get_scalar_metrics_and_variants."2.1"} {
request.properties.model_events {
type: boolean
description: If set then model events are retrieved. Otherwise task events
default: false
}
}
}
get_scalar_metric_data {
"2.1" {
@@ -1459,6 +1526,13 @@ get_scalar_metric_data {
default: false
}
}
"2.22": ${get_scalar_metric_data."2.16"} {
request.properties.model_events {
type: boolean
description: If set then model events are retrieved. Otherwise task events
default: false
}
}
}
scalar_metrics_iter_raw {
"2.16" {
@@ -1523,6 +1597,13 @@ scalar_metrics_iter_raw {
}
}
}
"2.22": ${scalar_metrics_iter_raw."2.16"} {
request.properties.model_events {
type: boolean
description: If set then model events are retrieved. Otherwise task events
default: false
}
}
}
clear_scroll {
"2.18" {
@@ -1578,4 +1659,4 @@ clear_task_log {
}
}
}
}
}

View File

@@ -241,6 +241,15 @@ get_all_ex {
default: false
}
}
"2.23": ${get_all_ex."2.20"} {
request.properties {
allow_public {
description: "Allow public models to be returned in the results"
type: boolean
default: true
}
}
}
}
get_all {
"2.1" {
@@ -320,6 +329,14 @@ get_all {
type: array
items { type: string }
}
last_update {
description: "List of last_update constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
type: array
items {
type: string
pattern: "^(>=|>|<=|<)?.*$"
}
}
_all_ {
description: "Multi-field pattern condition (all fields match pattern)"
"$ref": "#/definitions/multi_field_pattern_data"
@@ -968,7 +985,7 @@ move {
items { type: string }
}
project {
description: "Target project ID. If not provided, `project_name` must be provided."
description: "Target project ID. If not provided, `project_name` must be provided. Use null for the root project"
type: string
}
project_name {

View File

@@ -162,4 +162,38 @@ get_entities_count {
}
}
}
"2.22": ${get_entities_count."2.20"} {
request.properties {
search_hidden {
description: "If set to 'true' then hidden projects and tasks are included in the search results"
type: boolean
default: false
}
active_users {
description: "The list of users that were active in the project. If passed then the resulting projects are filtered to the ones that have tasks created by these users"
type: array
items: {type: string}
}
}
}
"2.23": ${get_entities_count."2.22"} {
request.properties {
reports {
type: object
additionalProperties: true
description: Search criteria for reports
}
allow_public {
description: "Allow public entities to be counted in the results"
type: boolean
default: true
}
}
response.properties {
reports {
type: integer
description: The number of reports matching the criteria
}
}
}
}

View File

@@ -46,11 +46,6 @@ _definitions {
type: string
format: "date-time"
}
last_update {
description: "Last update time"
type: string
format: "date-time"
}
tags {
description: "User-defined tags"
type: array
@@ -66,12 +61,26 @@ _definitions {
type: string
}
last_update {
description: """Last project update time. Reflects the last time the project metadata was changed or a task in this project has changed status"""
description: "Last project update time. Reflects the last time the project metadata was changed or a task in this project has changed status"
type: string
format: "date-time"
}
}
}
stats_datasets {
type: object
properties {
count {
description: Number of datasets
type: integer
}
tags {
description: Dataset tags
type: array
items {type: string}
}
}
}
stats_status_count {
type: object
properties {
@@ -146,6 +155,10 @@ _definitions {
description: "Stats for archived tasks"
"$ref": "#/definitions/stats_status_count"
}
datasets {
description: "Stats for children datasets"
"$ref": "#/definitions/stats_datasets"
}
}
}
projects_get_all_response_single {
@@ -181,11 +194,6 @@ _definitions {
type: string
format: "date-time"
}
last_update {
description: "Last update time"
type: string
format: "date-time"
}
tags {
description: "User-defined tags"
type: array
@@ -200,7 +208,16 @@ _definitions {
description: "The default output destination URL for new tasks under this project"
type: string
}
last_update {
description: "Last project update time. Reflects the last time the project metadata was changed or a task in this project has changed status"
type: string
format: "date-time"
}
// extra properties
hidden {
description: "Returned if the search_hidden flag was specified in the get_all_ex call and the project is hidden"
type: boolean
}
stats {
description: "Additional project stats"
"$ref": "#/definitions/stats"
@@ -222,6 +239,10 @@ _definitions {
}
}
}
own_datasets {
description: "The amount of datasets/hyperdatasets under this project (without children projects). Returned if 'check_own_contents' flag is set in the request and children_type is set to 'dataset' or 'hyperdataset'"
type: integer
}
own_tasks {
description: "The amount of tasks under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
type: integer
@@ -616,6 +637,22 @@ get_all_ex {
default: false
}
}
"2.23": ${get_all_ex."2.20"} {
request.properties {
allow_public {
description: "Allow public projects to be returned in the results"
type: boolean
default: true
}
}
}
"2.24": ${get_all_ex."2.23"} {
request.properties.children_type {
description: If specified then only the projects under which the entities of this type can be found will be returned
type: string
enum: [pipeline, report, dataset]
}
}
}
update {
"2.1" {

View File

@@ -152,6 +152,13 @@ get_all_ex {
type: integer
}
}
"2.21": ${get_all_ex."2.20"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden queues are included in the search results"
type: boolean
default: false
}
}
}
get_all {
"2.4" {
@@ -244,6 +251,13 @@ get_all {
type: integer
}
}
"2.21": ${get_all."2.20"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden queues are included in the search results"
type: boolean
default: false
}
}
}
get_default {
"2.4" {
@@ -449,6 +463,33 @@ get_next_task {
}
}
}
"2.14": ${get_next_task."2.4"} {
request.properties.get_task_info {
description: "If set then additional task info is returned"
type: boolean
default: false
}
response.properties.task_info {
description: "Info about the returned task. Returned only if get_task_info is set to True"
type: object
properties {
company {
description: Task company ID
type: string
}
user {
description: ID of the user who created the task
type: string
}
}
}
}
"2.21": ${get_next_task."2.14"} {
request.properties.task {
description: Task company ID
type: string
}
}
}
remove_task {
"2.4" {

View File

@@ -0,0 +1,709 @@
_description: "Provides a management API for reports in the system."
_definitions {
include "_tasks_common.conf"
include "_events_common.conf"
update_response {
type: object
properties {
updated {
description: "Number of reports updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
report_status_enum {
type: string
enum: [
created
published
]
}
report {
type: object
properties {
id {
description: "Report id"
type: string
}
name {
description: "Report Name"
type: string
}
user {
description: "Associated user id"
type: string
}
company {
description: "Company ID"
type: string
}
status {
description: ""
"$ref": "#/definitions/report_status_enum"
}
comment {
description: "Free text comment"
type: string
}
report {
description: "Report template"
type: string
}
report_assets {
description: "List of the external report assets"
type: array
items { type: string }
}
created {
description: "Report creation time (UTC) "
type: string
format: "date-time"
}
project {
description: "Project ID of the project to which this report is assigned"
type: string
}
tags {
description: "User-defined tags list"
type: array
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items { type: string }
}
status_changed {
description: "Last status change time"
type: string
format: "date-time"
}
status_message {
description: "free text string representing info about the status"
type: string
}
status_reason {
description: "Reason for last status change"
type: string
}
published {
description: "Report publish time"
type: string
format: "date-time"
}
last_update {
description: "Last time this report was created, edited, changed"
type: string
format: "date-time"
}
}
}
}
// reports.create: creates a new report.
// Request: only `name` is required; tags, comment, report (template) and project are optional.
// Response: the ID of the created report and the ID of the project it belongs to.
// Version 2.24 extends the 2.23 request with the optional `report_assets` list.
create {
    "2.23" {
        description: "Create a new report"
        request {
            type: object
            required: [
                name
            ]
            properties {
                name {
                    description: "Report name. Unique within the company."
                    type: string
                }
                tags {
                    description: "User-defined tags list"
                    type: array
                    items { type: string }
                }
                comment {
                    description: "Free text comment "
                    type: string
                }
                report {
                    description: "Report template"
                    type: string
                }
                project {
                    // Fixed: the description previously ended with the garbled residue "assigned Must exist[ab]"
                    description: "Project ID of the project to which this report is assigned. Must exist"
                    type: string
                }
            }
        }
        response {
            type: object
            properties {
                id {
                    description: "ID of the report"
                    type: string
                }
                project_id {
                    description: "ID of the project that the report belongs to"
                    type: string
                }
            }
        }
    }
    // 2.24 inherits everything from 2.23 and adds one request property
    "2.24": ${create."2.23"} {
        request.properties.report_assets {
            description: "List of the external report assets"
            type: array
            items { type: string }
        }
    }
}
// reports.update: updates the fields of an existing report.
// Request: `task` (the ID of the report task) is required; name, tags, comment and
// report (template) are optional.
// Response: the shared `update_response` definition (updated count + updated fields).
// Version 2.24 extends the 2.23 request with the optional `report_assets` list.
update {
    "2.23" {
        // Fixed: the description was copy-pasted from `create` ("Create a new report")
        description: "Update an existing report"
        request {
            type: object
            required: [
                task
            ]
            properties {
                task {
                    description: "The ID of the report task to update"
                    type: string
                }
                name {
                    description: "Report name. Unique within the company."
                    type: string
                }
                tags {
                    description: "User-defined tags list"
                    type: array
                    items { type: string }
                }
                comment {
                    description: "Free text comment "
                    type: string
                }
                report {
                    description: "Report template"
                    type: string
                }
            }
        }
        response: ${_definitions.update_response}
    }
    // 2.24 inherits everything from 2.23 and adds one request property
    "2.24": ${update."2.23"} {
        request.properties.report_assets {
            description: "List of the external report assets"
            type: array
            items { type: string }
        }
    }
}
// reports.move: moves a single report (`task` is the report ID, required) to a target project.
// The target is selected by `project` (ID) or by `project_name`; per the field descriptions,
// one of the two must be provided and an unknown `project_name` results in a new project.
// Response: the ID of the target project.
// NOTE(review): `project` is typed as string while its description says null selects the
// root project - confirm that the request validation accepts null here.
move {
"2.23" {
description: "Move reports to a project"
request {
type: object
required: [task]
properties {
task {
description: "ID of the report to move"
type: string
}
project {
description: "Target project ID. If not provided, `project_name` must be provided. Use null for the root project"
type: string
}
project_name {
description: "Target project name. If provided and a project with this name does not exist, a new project will be created. If not provided, `project` must be provided."
type: string
}
}
}
response {
type: object
properties {
project_id: {
description: The ID of the target project
type: string
}
}
}
}
}
// reports.publish: publishes a report.
// Request: `task` (the ID of the report task) is required; `comment` is an optional client message.
// Response: the shared `update_response` definition (updated count + updated fields).
publish {
"2.23" {
description: "Publish report"
request {
type: object
required: [
task
]
properties {
task {
description: "The ID of the report task to publish"
type: string
}
comment {
description: "The client message"
type: string
}
}
}
response: ${_definitions.update_response}
}
}
// reports.archive: archives a report.
// Request: `task` (the ID of the report task) is required; `comment` is an optional client message.
// Response: `archived` - the number of reports archived, restricted to 0 or 1.
archive {
"2.23" {
description: "Archive report"
request {
type: object
required: [
task
]
properties {
task {
description: "The ID of the report task to archive"
type: string
}
comment {
description: "The client message"
type: string
}
}
}
response {
type: object
properties {
archived {
description: "Number of reports archived (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}
// reports.unarchive: the counterpart of reports.archive.
// Request: `task` (the ID of the report task) is required; `comment` is an optional client message.
// Response: `unarchived` - the number of reports unarchived, restricted to 0 or 1.
unarchive {
"2.23" {
description: "Unarchive report"
request {
type: object
required: [
task
]
properties {
task {
description: "The ID of the report task to unarchive"
type: string
}
comment {
description: "The client message"
type: string
}
}
}
response {
type: object
properties {
unarchived {
description: "Number of reports unarchived (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}
//share {
// "999.0" {
// description: "Share or unshare report"
// request {
// type: object
// required: [
// task
// ]
// properties {
// task {
// description: "The ID of the report task to share/unshare"
// type: string
// }
// share {
// description: "If set to 'true' then the report will be shared. Otherwise unshared."
// type: boolean
// default: true
// }
// }
// }
// response {
// type: object
// properties {
// changed {
// description: "Number of changed reports (0 or 1)"
// type: integer
// enum: [0, 1]
// }
// }
// }
// }
//}
// reports.delete: deletes a report.
// Request: `task` (the ID of the report task) is required. `force` defaults to false; per its
// description, without it published or unarchived reports cannot be deleted.
// Response: `deleted` - the number of deleted reports, restricted to 0 or 1.
delete {
"2.23" {
description: "Delete report"
request {
type: object
required: [
task
]
properties {
task {
description: "The ID of the report task to delete"
type: string
}
force {
description: "If not set then published or unarchived reports cannot be deleted"
type: boolean
default: false
}
}
}
response {
type: object
properties {
deleted {
description: "Number of deleted reports (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}
get_task_data {
"2.23" {
description: "Get the tasks data according to the passed search criteria + requested events"
request {
type: object
properties {
id {
description: "List of IDs to filter by"
type: array
items { type: string }
}
name {
description: "Get only tasks whose name matches this pattern (python regular expression syntax)"
type: string
}
user {
description: "List of user IDs used to filter results by the task's creating user"
type: array
items { type: string }
}
size {
type: integer
minimum: 1
description: "The number of tasks to retrieve"
}
order_by {
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
type: array
items { type: string }
}
type {
description: "List of task types. One or more of: 'import', 'annotation', 'training' or 'testing' (case insensitive)"
type: array
items { type: string }
}
tags {
description: "List of task user-defined tags. Use '-' prefix to exclude tags"
type: array
items { type: string }
}
system_tags {
description: "List of task system tags. Use '-' prefix to exclude system tags"
type: array
items { type: string }
}
status {
description: "List of task status."
type: array
items { "$ref": "#/definitions/task_status_enum" }
}
project {
description: "List of project IDs"
type: array
items { type: string }
}
only_fields {
description: "List of task field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
type: array
items { type: string }
}
parent {
description: "Parent ID"
type: string
}
status_changed {
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
type: array
items {
type: string
pattern: "^(>=|>|<=|<)?.*$"
}
}
search_text {
description: "Free text search query"
type: string
}
allow_public {
description: "Allow public tasks to be returned in the results"
type: boolean
default: true
}
_all_ {
description: "Multi-field pattern condition (all fields match pattern)"
"$ref": "#/definitions/multi_field_pattern_data"
}
_any_ {
description: "Multi-field pattern condition (any field matches pattern)"
"$ref": "#/definitions/multi_field_pattern_data"
}
"input.view.entries.dataset" {
description: "List of input dataset IDs"
type: array
items { type: string }
}
"input.view.entries.version" {
description: "List of input dataset version IDs"
type: array
items { type: string }
}
search_hidden {
description: "If set to 'true' then hidden tasks are included in the search results"
type: boolean
default: false
}
include_subprojects {
description: "If set to 'true' and project field is set then tasks from the subprojects are searched too"
type: boolean
default: false
}
plots {
type: object
properties {
iters {
type: integer
description: "Max number of latest iterations for which to return plots"
}
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
debug_images {
type: object
properties {
iters {
type: integer
description: "Max number of latest iterations for which to return debug images"
}
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
scalar_metrics_iter_histogram {
type: object
properties {
samples {
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 6000."
type: integer
}
key {
description: """
Histogram x axis to use:
iter - iteration number
iso_time - event time as ISO formatted string
timestamp - event timestamp as milliseconds since epoch
"""
"$ref": "#/definitions/scalar_key_enum"
}
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
}
}
response {
type: object
properties {
tasks {
description: "List of tasks"
type: array
items { "$ref": "#/definitions/task" }
}
plots {
type: object
description: "Plots mapped by metric, variant, task and iteration"
additionalProperties: true
}
debug_images {
type: array
description: "Debug image events grouped by tasks and iterations"
items {"$ref": "#/definitions/debug_images_response_task_metrics"}
}
scalar_metrics_iter_histogram {
type: object
additionalProperties: true
}
}
}
}
}
get_all_ex {
"2.23" {
description: "Get all the company's and public report tasks"
request {
type: object
properties {
id {
description: "List of IDs to filter by"
type: array
items { type: string }
}
name {
description: "Get only reports whose name matches this pattern (python regular expression syntax)"
type: string
}
user {
description: "List of user IDs used to filter results by the report's creating user"
type: array
items { type: string }
}
page {
description: "Page number, returns a specific page out of the resulting list of reports"
type: integer
minimum: 0
}
page_size {
description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)"
type: integer
minimum: 1
}
order_by {
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
type: array
items { type: string }
}
tags {
description: "List of report user-defined tags. Use '-' prefix to exclude tags"
type: array
items { type: string }
}
system_tags {
description: "List of report system tags. Use '-' prefix to exclude system tags"
type: array
items { type: string }
}
status {
description: "List of report status."
type: array
items { "$ref": "#/definitions/report_status_enum" }
}
project {
description: "List of project IDs"
type: array
items { type: string }
}
only_fields {
description: "List of report field names (nesting is supported using '.'). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
type: array
items { type: string }
}
status_changed {
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
type: array
items {
type: string
pattern: "^(>=|>|<=|<)?.*$"
}
}
search_text {
description: "Free text search query"
type: string
}
scroll_id {
type: string
description: "Scroll ID returned from the previous calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of tasks to retrieve"
}
allow_public {
description: "Allow public reports to be returned in the results"
type: boolean
default: true
}
_all_ {
description: "Multi-field pattern condition (all fields match pattern)"
"$ref": "#/definitions/multi_field_pattern_data"
}
_any_ {
description: "Multi-field pattern condition (any field matches pattern)"
"$ref": "#/definitions/multi_field_pattern_data"
}
}
dependencies {
page: [ page_size ]
}
}
response {
type: object
properties {
tasks {
description: "List of report tasks"
type: array
items { "$ref": "#/definitions/report" }
}
scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
}
}
}
// reports.get_tags: returns the user tags used for the company reports.
// Request: an object that accepts no properties (additionalProperties is false and none are declared).
// Response: `tags` - the list of unique tag values.
get_tags {
"2.23" {
description: "Get all the user tags used for the company reports"
request {
type: object
additionalProperties: false
}
response {
type: object
properties {
tags {
description: "The list of unique tag values"
type: array
items {type: string}
}
}
}
}
}

View File

@@ -25,7 +25,7 @@ _references {
}
}
_definitions {
include "_common.conf"
include "_tasks_common.conf"
change_many_request: ${_definitions.batch_operation} {
request {
properties {
@@ -69,374 +69,6 @@ _definitions {
}
}
}
multi_field_pattern_data {
type: object
properties {
pattern {
description: "Pattern string (regex)"
type: string
}
fields {
description: "List of field names"
type: array
items { type: string }
}
}
}
model_type_enum {
type: string
enum: ["input", "output"]
}
task_model_item {
type: object
required: [ name, model]
properties {
name {
description: "The task model name"
type: string
}
model {
description: "The model ID"
type: string
}
}
}
script {
type: object
properties {
binary {
description: "Binary to use when running the script"
type: string
default: python
}
repository {
description: "Name of the repository where the script is located"
type: string
}
tag {
description: "Repository tag"
type: string
}
branch {
description: "Repository branch id If not provided and tag not provided, default repository branch is used."
type: string
}
version_num {
description: "Version (changeset) number. Optional (default is head version) Unused if tag is provided."
type: string
}
entry_point {
description: "Path to execute within the repository"
type: string
}
working_dir {
description: "Path to the folder from which to run the script Default - root folder of repository"
type: string
}
requirements {
description: "A JSON object containing requirements strings by key"
type: object
}
diff {
description: "Uncommitted changes found in the repository when task was run"
type: string
}
}
}
output {
type: object
properties {
destination {
description: "Storage id. This is where output files will be stored."
type: string
}
model {
description: "Model id."
type: string
}
result {
description: "Task result. Values: 'success', 'failure'"
type: string
}
error {
description: "Last error text"
type: string
}
}
}
task_execution_progress_enum {
type: string
enum: [
unknown
running
stopping
stopped
]
}
output_rois_enum {
type: string
enum: [
all_in_frame
only_filtered
frame_per_roi
]
}
artifact_type_data {
type: object
properties {
preview {
description: "Description or textual data"
type: string
}
content_type {
description: "System defined raw data content type"
type: string
}
data_hash {
description: "Hash of raw data, without any headers or descriptive parts"
type: string
}
}
}
artifact_mode_enum {
type: string
enum: [
input
output
]
default: output
}
artifact {
type: object
required: [key, type]
properties {
key {
description: "Entry key"
type: string
}
type {
description: "System defined type"
type: string
}
mode {
description: "System defined input/output indication"
"$ref": "#/definitions/artifact_mode_enum"
}
uri {
description: "Raw data location"
type: string
}
content_size {
description: "Raw data length in bytes"
type: integer
}
hash {
description: "Hash of entire raw data"
type: string
}
timestamp {
description: "Epoch time when artifact was created"
type: integer
}
type_data {
description: "Additional fields defined by the system"
"$ref": "#/definitions/artifact_type_data"
}
display_data {
description: "User-defined list of key/value pairs, sorted"
type: array
items {
type: array
items {
type: string # can also be a number... TODO: upgrade the generator
}
}
}
}
}
artifact_id {
type: object
required: [key]
properties {
key {
description: "Entry key"
type: string
}
mode {
description: "System defined input/output indication"
"$ref": "#/definitions/artifact_mode_enum"
}
}
}
task_models {
type: object
properties {
input {
description: "The list of task input models"
type: array
items {"$ref": "#/definitions/task_model_item"}
}
output {
description: "The list of task output models"
type: array
items {"$ref": "#/definitions/task_model_item"}
}
}
}
execution {
type: object
properties {
queue {
description: "Queue ID where task was queued."
type: string
}
parameters {
description: "Json object containing the Task parameters"
type: object
additionalProperties: true
}
model {
description: "Execution input model ID Not applicable for Register (Import) tasks"
type: string
}
model_desc {
description: "Json object representing the Model descriptors"
type: object
additionalProperties: true
}
model_labels {
description: """Json object representing the ids of the labels in the model.
The keys are the layers' names and the values are the IDs.
Not applicable for Register (Import) tasks.
Mandatory for Training tasks"""
type: object
additionalProperties: { type: integer }
}
framework {
description: """Framework related to the task. Case insensitive. Mandatory for Training tasks. """
type: string
}
docker_cmd {
description: "Command for running docker script for the execution of the task"
type: string
}
artifacts {
description: "Task artifacts"
type: array
items { "$ref": "#/definitions/artifact" }
}
}
}
task_status_enum {
type: string
enum: [
created
queued
in_progress
stopped
published
publishing
closed
failed
completed
unknown
]
}
task_type_enum {
type: string
enum: [
training
testing
inference
data_processing
application
monitor
controller
optimizer
service
qc
custom
]
}
last_metrics_event {
type: object
properties {
metric {
description: "Metric name"
type: string
}
variant {
description: "Variant name"
type: string
}
value {
description: "Last value reported"
type: number
}
min_value {
description: "Minimum value reported"
type: number
}
max_value {
description: "Maximum value reported"
type: number
}
}
}
last_metrics_variants {
type: object
description: "Last metric events, one for each variant hash"
additionalProperties {
"$ref": "#/definitions/last_metrics_event"
}
}
params_item {
type: object
properties {
section {
description: "Section that the parameter belongs to"
type: string
}
name {
description: "Name of the parameter. The combination of section and name should be unique"
type: string
}
value {
description: "Value of the parameter"
type: string
}
type {
description: "Type of the parameter. Optional"
type: string
}
description {
description: "The parameter description. Optional"
type: string
}
}
}
configuration_item {
type: object
properties {
name {
description: "Name of the parameter. Should be unique"
type: string
}
value {
description: "Value of the parameter"
type: string
}
type {
description: "Type of the parameter. Optional"
type: string
}
description {
description: "The parameter description. Optional"
type: string
}
}
}
param_key {
type: object
properties {
@@ -450,13 +82,6 @@ _definitions {
}
}
}
section_params {
description: "Task section params"
type: object
additionalProperties {
"$ref": "#/definitions/params_item"
}
}
replace_hyperparams_enum {
type: string
enum: [
@@ -465,165 +90,6 @@ _definitions {
all
]
}
task {
type: object
properties {
id {
description: "Task id"
type: string
}
name {
description: "Task Name"
type: string
}
user {
description: "Associated user id"
type: string
}
company {
description: "Company ID"
type: string
}
type {
description: "Type of task. Values: 'training', 'testing'"
"$ref": "#/definitions/task_type_enum"
}
status {
description: ""
"$ref": "#/definitions/task_status_enum"
}
comment {
description: "Free text comment"
type: string
}
created {
description: "Task creation time (UTC) "
type: string
format: "date-time"
}
started {
description: "Task start time (UTC)"
type: string
format: "date-time"
}
completed {
description: "Task end time (UTC)"
type: string
format: "date-time"
}
active_duration {
description: "Task duration time (seconds)"
type: integer
}
parent {
description: "Parent task id"
type: string
}
project {
description: "Project ID of the project to which this task is assigned"
type: string
}
output {
description: "Task output params"
"$ref": "#/definitions/output"
}
execution {
description: "Task execution params"
"$ref": "#/definitions/execution"
}
container {
description: "Docker container parameters"
type: object
additionalProperties { type: [string, null] }
}
models {
description: "Task models"
"$ref": "#/definitions/task_models"
}
// TODO: will be removed
script {
description: "Script info"
"$ref": "#/definitions/script"
}
tags {
description: "User-defined tags list"
type: array
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items { type: string }
}
status_changed {
description: "Last status change time"
type: string
format: "date-time"
}
status_message {
description: "free text string representing info about the status"
type: string
}
status_reason {
description: "Reason for last status change"
type: string
}
published {
description: "Task publish time"
type: string
format: "date-time"
}
last_worker {
description: "ID of last worker that handled the task"
type: string
}
last_worker_report {
description: "Last time a worker reported while working on this task"
type: string
format: "date-time"
}
last_update {
description: "Last time this task was created, updated, changed or events for this task were reported"
type: string
format: "date-time"
}
last_change {
description: "Last time any update was done to the task"
type: string
format: "date-time"
}
last_iteration {
description: "Last iteration reported for this task"
type: integer
}
last_metrics {
description: "Last metric variants (hash to events), one for each metric hash"
type: object
additionalProperties {
"$ref": "#/definitions/last_metrics_variants"
}
}
hyperparams {
description: "Task hyper params per section"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
configuration {
description: "Task configuration params"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
runtime {
description: "Task runtime mapping"
type: object
additionalProperties: true
}
}
}
task_urls {
type: object
properties {
@@ -715,6 +181,15 @@ get_all_ex {
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
"2.23": ${get_all_ex."2.15"} {
request.properties {
allow_public {
description: "Allow public tasks to be returned in the results"
type: boolean
default: true
}
}
}
}
get_all {
"2.1" {
@@ -1227,6 +702,10 @@ validate {
description: "Task execution params"
"$ref": "#/definitions/execution"
}
script {
description: "Script info"
"$ref": "#/definitions/script"
}
hyperparams {
description: "Task hyper params per section"
type: object
@@ -1241,10 +720,6 @@ validate {
"$ref": "#/definitions/configuration_item"
}
}
script {
description: "Script info"
"$ref": "#/definitions/script"
}
}
}
response {
@@ -1392,6 +867,10 @@ edit {
description: "Task execution params"
"$ref": "#/definitions/execution"
}
script {
description: "Script info"
"$ref": "#/definitions/script"
}
hyperparams {
description: "Task hyper params per section"
type: object
@@ -1406,10 +885,6 @@ edit {
"$ref": "#/definitions/configuration_item"
}
}
script {
description: "Script info"
"$ref": "#/definitions/script"
}
}
}
response: ${_definitions.update_response}
@@ -1478,6 +953,10 @@ reset {
description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'"
type: boolean
}
delete_output_models {
description: "If set to 'true' then delete output models of this task that are not referenced by other tasks. Default value is 'true'"
type: boolean
}
}
}
response {
@@ -1489,6 +968,13 @@ reset {
}
}
}
"2.21": ${reset."2.13"} {
request.properties.delete_external_artifacts {
description: "If set to 'true' then BE will try to delete the external artifacts associated with the task from the fileserver (if configured to do so)"
type: boolean
default: true
}
}
}
reset_many {
"2.13": ${_definitions.batch_operation} {
@@ -1541,6 +1027,13 @@ reset_many {
}
}
}
"2.21": ${reset_many."2.13"} {
request.properties.delete_external_artifacts {
description: "If set to 'true' then BE will try to delete the external artifacts associated with the tasks from the fileserver (if configured to do so)"
type: boolean
default: true
}
}
}
delete_many {
"2.13": ${_definitions.batch_operation} {
@@ -1591,6 +1084,13 @@ delete_many {
}
}
}
"2.21": ${delete_many."2.13"} {
request.properties.delete_external_artifacts {
description: "If set to 'true' then BE will try to delete the external artifacts associated with the tasks from the fileserver (if configured to do so)"
type: boolean
default: true
}
}
}
delete {
"2.1" {
@@ -1644,6 +1144,10 @@ delete {
description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'"
type: boolean
}
delete_output_models {
description: "If set to 'true' then delete output models of this task that are not referenced by other tasks. Default value is 'true'"
type: boolean
}
}
}
response {
@@ -1655,6 +1159,13 @@ delete {
}
}
}
"2.21": ${delete."2.13"} {
request.properties.delete_external_artifacts {
description: "If set to 'true' then BE will try to delete the external artifacts associated with the task from the fileserver (if configured to do so)"
type: boolean
default: true
}
}
}
archive {
"2.12" {
@@ -1708,12 +1219,12 @@ archive_many {
type: string
}
}
response {
properties {
succeeded.items.properties.archived {
description: "Indicates whether the task was archived"
type: boolean
}
}
response {
properties {
succeeded.items.properties.archived {
description: "Indicates whether the task was archived"
type: boolean
}
}
}
@@ -1920,6 +1431,17 @@ Fails if the following parameters in the task were not filled:
type: string
}
}
"2.22": ${enqueue."2.19"} {
request.properties.verify_watched_queue {
description: If passed then check whether there are any workers watching the queue
type: boolean
default: false
}
response.properties.queue_watched {
description: Returns true if there are workers or autoscalers working with the queue
type: boolean
}
}
}
enqueue_many {
"2.13": ${_definitions.change_many_request} {
@@ -1953,6 +1475,17 @@ enqueue_many {
type: string
}
}
"2.22": ${enqueue_many."2.19"} {
request.properties.verify_watched_queue {
description: If passed then check whether there are any workers watching the queue
type: boolean
default: false
}
response.properties.queue_watched {
description: Returns true if there are workers or autoscalers working with the queue
type: boolean
}
}
}
dequeue {
"1.5" {
@@ -2465,7 +1998,7 @@ move {
items { type: string }
}
project {
description: "Target project ID. If not provided, `project_name` must be provided."
description: "Target project ID. If not provided, `project_name` must be provided. Use null for the root project"
type: string
}
project_name {

View File

@@ -152,6 +152,15 @@ _definitions {
type: array
items: { type: string }
}
system_tags {
description: "System tags for the worker"
type: array
items: { type: string }
}
key {
description: "Worker entry key"
type: string
}
}
}
@@ -159,11 +168,11 @@ _definitions {
type: object
properties {
id {
description: "ID"
description: "Worker ID"
type: string
}
name {
description: "Name"
description: "Worker name"
type: string
}
}
@@ -294,6 +303,13 @@ get_all {
items { type: string }
}
}
"2.22": ${get_all."2.20"} {
request.properties.system_tags {
description: The list of allowed worker system tags. Prepend tag value with '-' in order to exclude
type: array
items { type: string }
}
}
}
register {
"2.4" {
@@ -328,6 +344,13 @@ register {
properties {}
}
}
"2.22": ${register."2.4"} {
request.properties.system_tags {
description: "System tags for the worker"
type: array
items: { type: string }
}
}
}
unregister {
"2.4" {
@@ -395,6 +418,13 @@ status_report {
properties {}
}
}
"2.22": ${status_report."2.4"} {
request.properties.system_tags {
description: "New system tags for the worker"
type: array
items: { type: string }
}
}
}
get_metric_keys {
"2.4" {

View File

@@ -6,6 +6,8 @@ from flask_compress import Compress
from flask_cors import CORS
from packaging.version import Version
from apiserver.bll.queue.queue_metrics import MetricsRefresher
from apiserver.bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
from apiserver.database import db
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
from apiserver.config import info
@@ -26,7 +28,6 @@ from apiserver.service_repo import ServiceRepo
from apiserver.sync import distributed_lock
from apiserver.updates import check_updates_thread
from apiserver.utilities.env import get_bool
from apiserver.utilities.threads_manager import ThreadsManager
log = config.logger(__file__)
@@ -119,6 +120,8 @@ class AppSequence:
def _start_worker(self):
check_updates_thread.start()
StatisticsReporter.start()
MetricsRefresher.start()
NonResponsiveTasksWatchdog.start()
def _on_worker_stop(self):
ThreadsManager.terminating = True
pass

View File

@@ -11,7 +11,6 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import User, Entities, Credentials
from apiserver.database.model.company import Company
from apiserver.database.utils import get_options
from apiserver.timing_context import TimingContext
from .fixed_user import FixedUser
from .identity import Identity
from .payload import Payload, Token, Basic, AuthType
@@ -88,9 +87,7 @@ def authorize_credentials(auth_data, service, action, call):
query = Q(id=fixed_user.user_id)
with TimingContext("mongo", "user_by_cred"), translate_errors_context(
"authorizing request"
):
with translate_errors_context("authorizing request"):
user = User.objects(query).first()
if not user:
raise errors.unauthorized.InvalidCredentials(
@@ -108,8 +105,7 @@ def authorize_credentials(auth_data, service, action, call):
}
)
with TimingContext("mongo", "company_by_id"):
company = Company.objects(id=user.company).only("id", "name").first()
company = Company.objects(id=user.company).only("id", "name").first()
if not company:
raise errors.unauthorized.InvalidCredentials("invalid user company")

View File

@@ -39,7 +39,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is greater than the current
maximum """
_max_version = PartialVersion("2.20")
_max_version = PartialVersion("2.24")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (

View File

@@ -2,7 +2,7 @@ import itertools
import math
from collections import defaultdict
from operator import itemgetter
from typing import Sequence, Optional
from typing import Sequence, Optional, Union, Tuple, Mapping
import attr
import jsonmodels.fields
@@ -19,7 +19,6 @@ from apiserver.apimodels.events import (
TaskMetricsRequest,
LogEventsRequest,
LogOrderEnum,
GetHistorySampleRequest,
NextHistorySampleRequest,
MetricVariants as ApiMetrics,
TaskPlotsRequest,
@@ -28,18 +27,44 @@ from apiserver.apimodels.events import (
ClearScrollRequest,
ClearTaskLogRequest,
SingleValueMetricsRequest,
GetVariantSampleRequest,
GetMetricSamplesRequest,
TaskMetric,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
from apiserver.bll.event.events_iterator import Scroll
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.bll.model import ModelBLL
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities import json, extract_properties_to_lists
task_bll = TaskBLL()
event_bll = EventBLL()
model_bll = ModelBLL()
def _assert_task_or_model_exists(
    company_id: str, task_ids: Union[str, Sequence[str]], model_events: bool
) -> Union[Sequence[Model], Sequence[Task]]:
    """
    Check that the given ids exist, as models or as tasks.

    :param company_id: company on whose behalf the lookup is made
    :param task_ids: a single id or a sequence of ids; these are model ids
        when ``model_events`` is truthy and task ids otherwise
    :param model_events: selects ``model_bll`` (truthy) or ``task_bll`` (falsy)
        as the object whose ``assert_exists`` is called
    :return: whatever the selected ``assert_exists`` call returns; only the
        ``id``, ``name``, ``company`` and ``company_origin`` fields are requested
    """
    # Both BLL objects are called with exactly the same arguments, so only the
    # receiver needs to be chosen here.
    entity_bll = model_bll if model_events else task_bll
    return entity_bll.assert_exists(
        company_id,
        task_ids,
        allow_public=True,
        only=("id", "name", "company", "company_origin"),
    )
@endpoint("events.add")
@@ -47,7 +72,7 @@ def add(call: APICall, company_id, _):
data = call.data.copy()
allow_locked = data.pop("allow_locked", False)
added, err_count, err_info = event_bll.add_events(
company_id, [data], call.worker, allow_locked_tasks=allow_locked
company_id, [data], call.worker, allow_locked=allow_locked
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@@ -58,7 +83,12 @@ def add_batch(call: APICall, company_id, _):
if events is None or len(events) == 0:
raise errors.bad_request.BatchContainsNoItems()
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
added, err_count, err_info = event_bll.add_events(
company_id,
events,
call.worker,
allow_locked=events[0].get("allow_locked", False),
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@@ -225,12 +255,13 @@ def download_task_log(call, company_id, _):
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
def get_vector_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
model_events = call.data["model_events"]
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
task.get_index_company(), task_id, EventType.metrics_vector
task_or_model.get_index_company(), task_id, EventType.metrics_vector
)
)
@@ -238,12 +269,13 @@ def get_vector_metrics_and_variants(call, company_id, _):
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
def get_scalar_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
model_events = call.data["model_events"]
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
task.get_index_company(), task_id, EventType.metrics_scalar
task_or_model.get_index_company(), task_id, EventType.metrics_scalar
)
)
@@ -255,13 +287,14 @@ def get_scalar_metrics_and_variants(call, company_id, _):
)
def vector_metrics_iter_histogram(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
model_events = call.data["model_events"]
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
)[0]
metric = call.data["metric"]
variant = call.data["variant"]
iterations, vectors = event_bll.get_vector_metrics_per_iter(
task.get_index_company(), task_id, metric, variant
task_or_model.get_index_company(), task_id, metric, variant
)
call.result.data = dict(
metric=metric, variant=variant, vectors=vectors, iterations=iterations
@@ -286,11 +319,10 @@ def make_response(
@endpoint("events.get_task_events", request_data_model=TaskEventsRequest)
def get_task_events(call, company_id, request: TaskEventsRequest):
def get_task_events(_, company_id, request: TaskEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",),
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=request.model_events,
)[0]
key = ScalarKeyEnum.iter
@@ -322,7 +354,7 @@ def get_task_events(call, company_id, request: TaskEventsRequest):
if request.count_total and total is None:
total = event_bll.events_iterator.count_task_events(
event_type=request.event_type,
company_id=task.company,
company_id=task_or_model.get_index_company(),
task_id=task_id,
metric_variants=metric_variants,
)
@@ -336,7 +368,7 @@ def get_task_events(call, company_id, request: TaskEventsRequest):
res = event_bll.events_iterator.get_task_events(
event_type=request.event_type,
company_id=task.company,
company_id=task_or_model.get_index_company(),
task_id=task_id,
batch_size=batch_size,
key=ScalarKeyEnum.iter,
@@ -365,16 +397,17 @@ def get_scalar_metric_data(call, company_id, _):
metric = call.data["metric"]
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
model_events = call.data.get("model_events", False)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
)[0]
result = event_bll.get_task_events(
task.get_index_company(),
task_or_model.get_index_company(),
task_id,
event_type=EventType.metrics_scalar,
sort=[{"iter": {"order": "desc"}}],
metric=metric,
metrics={metric: []},
scroll_id=scroll_id,
no_scroll=no_scroll,
)
@@ -398,7 +431,7 @@ def get_task_latest_scalar_values(call, company_id, _):
index_company, task_id
)
last_iters = event_bll.get_last_iters(
company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
company_id=index_company, event_type=EventType.all, task_id=task_id, iters=1
).get(task_id)
call.result.data = dict(
metrics=metrics,
@@ -417,36 +450,60 @@ def get_task_latest_scalar_values(call, company_id, _):
def scalar_metrics_iter_histogram(
call, company_id, request: ScalarMetricsIterHistogramRequest
):
task = task_bll.assert_exists(
company_id, request.task, allow_public=True, only=("company", "company_origin")
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events
)[0]
metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
task.get_index_company(),
company_id=task_or_model.get_index_company(),
task_id=request.task,
samples=request.samples,
key=request.key,
metric_variants=_get_metric_variants_from_request(request.metrics),
)
call.result.data = metrics
def _get_task_or_model_index_companies(
    company_id: str, task_ids: Sequence[str], model_events=False,
) -> TaskCompanies:
    """
    Validate the requested ids and group the matching tasks (or models) by
    the value of their ``get_index_company()``.

    :param company_id: company on whose behalf the lookup is made
    :param task_ids: ids to validate; duplicates are counted once
    :param model_events: when truthy the ids are validated as models
    :return: the result of ``bucketize`` over the found entities, keyed by
        ``get_index_company()``
    :raise errors.bad_request.InvalidModelId: ``model_events`` is truthy and
        fewer entities were found than unique ids requested
    :raise errors.bad_request.InvalidTaskId: same condition for tasks
    """
    found = _assert_task_or_model_exists(
        company_id, task_ids, model_events=model_events,
    )
    requested = set(task_ids)
    if len(found) >= len(requested):
        return bucketize(found, key=lambda entity: entity.get_index_company())

    # Report exactly the ids that did not come back from the lookup
    missing = tuple(requested - {entity.id for entity in found})
    if model_events:
        error_cls = errors.bad_request.InvalidModelId
    else:
        error_cls = errors.bad_request.InvalidTaskId
    raise error_cls(company=company_id, ids=missing)
@endpoint(
"events.multi_task_scalar_metrics_iter_histogram",
request_data_model=MultiTaskScalarMetricsIterHistogramRequest,
)
def multi_task_scalar_metrics_iter_histogram(
call, company_id, req_model: MultiTaskScalarMetricsIterHistogramRequest
call, company_id, request: MultiTaskScalarMetricsIterHistogramRequest
):
task_ids = req_model.tasks
task_ids = request.tasks
if isinstance(task_ids, str):
task_ids = [s.strip() for s in task_ids.split(",")]
# Note, bll already validates task ids as it needs their names
call.result.data = dict(
metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter(
company_id,
task_ids=task_ids,
samples=req_model.samples,
allow_public=True,
key=req_model.key,
companies=_get_task_or_model_index_companies(
company_id, task_ids, request.model_events
),
samples=request.samples,
key=request.key,
)
)
@@ -455,21 +512,11 @@ def multi_task_scalar_metrics_iter_histogram(
def get_task_single_value_metrics(
call, company_id: str, request: SingleValueMetricsRequest
):
task_ids = call.data["tasks"]
tasks = task_bll.assert_exists(
company_id=call.identity.company,
only=("id", "name", "company", "company_origin"),
task_ids=task_ids,
allow_public=True,
res = event_bll.metrics.get_task_single_value_metrics(
companies=_get_task_or_model_index_companies(
company_id, request.tasks, request.model_events
),
)
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
res = event_bll.metrics.get_task_single_value_metrics(company_id, task_ids)
call.result.data = dict(
tasks=[{"task": task, "values": values} for task, values in res.items()]
)
@@ -481,22 +528,11 @@ def get_multi_task_plots_v1_7(call, company_id, _):
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
tasks = task_bll.assert_exists(
company_id=company_id,
only=("id", "name", "company", "company_origin"),
task_ids=task_ids,
allow_public=True,
)
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
companies = _get_task_or_model_index_companies(company_id, task_ids)
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
next(iter(companies)),
list(companies),
task_ids,
event_type=EventType.metrics_plot,
sort=[{"iter": {"order": "desc"}}],
@@ -504,10 +540,11 @@ def get_multi_task_plots_v1_7(call, company_id, _):
scroll_id=scroll_id,
)
tasks = {t.id: t.name for t in tasks}
task_names = {
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
}
return_events = _get_top_iter_unique_events_per_task(
result.events, max_iters=iters, tasks=tasks
result.events, max_iters=iters, task_names=task_names
)
call.result.data = dict(
@@ -518,47 +555,56 @@ def get_multi_task_plots_v1_7(call, company_id, _):
)
def _get_multitask_plots(
    companies: TaskCompanies,
    last_iters: int,
    metrics: MetricVariants = None,
    scroll_id=None,
    no_scroll=True,
    model_events=False,
) -> Tuple[dict, int, str]:
    """
    Fetch plot events for several tasks and keep the top iterations per task.

    :param companies: mapping whose keys are passed on as ``company_id`` and
        whose values are iterables of objects exposing ``id`` and ``name``
    :param last_iters: forwarded as ``last_iter_count`` to the events query and
        as ``max_iters`` to the per-task grouping
    :param metrics: optional metric/variant filter forwarded to the query
    :param scroll_id: scroll id forwarded to the query
    :param no_scroll: forwarded to the query as is
    :param model_events: NOTE(review): accepted but not used anywhere in this
        function body - confirm whether it should reach the events query
    :return: tuple of (grouped plot events, total events count, next scroll id)
    """
    task_names = {
        entity.id: entity.name
        for company_entities in companies.values()
        for entity in company_entities
    }

    plot_events = event_bll.get_task_events(
        company_id=list(companies),
        task_id=list(task_names),
        event_type=EventType.metrics_plot,
        metrics=metrics,
        last_iter_count=last_iters,
        sort=[{"iter": {"order": "desc"}}],
        scroll_id=scroll_id,
        no_scroll=no_scroll,
    )
    grouped_plots = _get_top_iter_unique_events_per_task(
        plot_events.events, max_iters=last_iters, task_names=task_names
    )
    return grouped_plots, plot_events.total_events, plot_events.next_scroll_id
@endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"])
def get_multi_task_plots(call, company_id, req_model):
def get_multi_task_plots(call, company_id, _):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
model_events = call.data.get("model_events", False)
tasks = task_bll.assert_exists(
company_id=call.identity.company,
only=("id", "name", "company", "company_origin"),
task_ids=task_ids,
allow_public=True,
companies = _get_task_or_model_index_companies(
company_id, task_ids, model_events=model_events
)
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
result = event_bll.get_task_events(
next(iter(companies)),
task_ids,
event_type=EventType.metrics_plot,
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
return_events, total_events, next_scroll_id = _get_multitask_plots(
companies=companies,
last_iters=iters,
scroll_id=scroll_id,
no_scroll=no_scroll,
model_events=model_events,
)
tasks = {t.id: t.name for t in tasks}
return_events = _get_top_iter_unique_events_per_task(
result.events, max_iters=iters, tasks=tasks
)
call.result.data = dict(
plots=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
total=total_events,
scroll_id=next_scroll_id,
)
@@ -613,18 +659,14 @@ def _get_metric_variants_from_request(
def get_task_plots(call, company_id, request: TaskPlotsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=request.model_events
)[0]
result = event_bll.get_task_plots(
task.get_index_company(),
tasks=[task_id],
sort=[{"iter": {"order": "desc"}}],
task_or_model.get_index_company(),
task_id=task_id,
last_iterations_per_plot=iters,
scroll_id=scroll_id,
no_scroll=request.no_scroll,
metric_variants=_get_metric_variants_from_request(request.metrics),
)
@@ -638,34 +680,43 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
)
def _task_metrics_dict_from_request(req_metrics: Sequence[TaskMetric]) -> dict:
task_metrics = defaultdict(dict)
for tm in req_metrics:
task_metrics[tm.task][tm.metric] = tm.variants
for metrics in task_metrics.values():
if None in metrics:
metrics.clear()
return task_metrics
def _get_metrics_response(metric_events: Sequence[tuple]) -> Sequence[MetricEvents]:
    """
    Convert ``(task, iterations)`` pairs into ``MetricEvents`` models.

    :param metric_events: pairs of a task and an iterable of mappings, each
        holding an ``"iter"`` and an ``"events"`` key
    :return: list with one ``MetricEvents`` per input pair, in input order,
        each carrying one ``IterationEvents`` per iteration mapping
    """
    response = []
    for task_id, task_iterations in metric_events:
        iteration_models = [
            IterationEvents(iter=item["iter"], events=item["events"])
            for item in task_iterations
        ]
        response.append(MetricEvents(task=task_id, iterations=iteration_models))
    return response
@endpoint(
"events.plots",
request_data_model=MetricEventsRequest,
response_data_model=MetricEventsResponse,
)
def task_plots(call, company_id, request: MetricEventsRequest):
task_metrics = defaultdict(dict)
for tm in request.metrics:
task_metrics[tm.task][tm.metric] = tm.variants
for metrics in task_metrics.values():
if None in metrics:
metrics.clear()
tasks = task_bll.assert_exists(
company_id,
task_ids=list(task_metrics),
allow_public=True,
only=("company", "company_origin"),
task_metrics = _task_metrics_dict_from_request(request.metrics)
task_ids = list(task_metrics)
task_or_models = _assert_task_or_model_exists(
company_id, task_ids=task_ids, model_events=request.model_events
)
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
result = event_bll.plots_iterator.get_task_events(
company_id=next(iter(companies)),
companies={t.id: t.get_index_company() for t in task_or_models},
task_metrics=task_metrics,
iter_count=request.iters,
navigate_earlier=request.navigate_earlier,
@@ -675,16 +726,7 @@ def task_plots(call, company_id, request: MetricEventsRequest):
call.result.data_model = MetricEventsResponse(
scroll_id=result.next_scroll_id,
metrics=[
MetricEvents(
task=task,
iterations=[
IterationEvents(iter=iteration["iter"], events=iteration["events"])
for iteration in iterations
],
)
for (task, iterations) in result.metric_events
],
metrics=_get_metrics_response(result.metric_events),
)
@@ -730,12 +772,13 @@ def get_debug_images_v1_8(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
model_events = call.data.get("model_events", False)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
tasks_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events,
)[0]
result = event_bll.get_task_events(
task.get_index_company(),
tasks_or_model.get_index_company(),
task_id,
event_type=EventType.metrics_image,
sort=[{"iter": {"order": "desc"}}],
@@ -761,28 +804,13 @@ def get_debug_images_v1_8(call, company_id, _):
response_data_model=MetricEventsResponse,
)
def get_debug_images(call, company_id, request: MetricEventsRequest):
task_metrics = defaultdict(dict)
for tm in request.metrics:
task_metrics[tm.task][tm.metric] = tm.variants
for metrics in task_metrics.values():
if None in metrics:
metrics.clear()
tasks = task_bll.assert_exists(
company_id,
task_ids=list(task_metrics),
allow_public=True,
only=("company", "company_origin"),
task_metrics = _task_metrics_dict_from_request(request.metrics)
task_ids = list(task_metrics)
task_or_models = _assert_task_or_model_exists(
company_id, task_ids=task_ids, model_events=request.model_events
)
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
result = event_bll.debug_images_iterator.get_task_events(
company_id=next(iter(companies)),
companies={t.id: t.get_index_company() for t in task_or_models},
task_metrics=task_metrics,
iter_count=request.iters,
navigate_earlier=request.navigate_earlier,
@@ -792,30 +820,21 @@ def get_debug_images(call, company_id, request: MetricEventsRequest):
call.result.data_model = MetricEventsResponse(
scroll_id=result.next_scroll_id,
metrics=[
MetricEvents(
task=task,
iterations=[
IterationEvents(iter=iteration["iter"], events=iteration["events"])
for iteration in iterations
],
)
for (task, iterations) in result.metric_events
],
metrics=_get_metrics_response(result.metric_events),
)
@endpoint(
"events.get_debug_image_sample",
min_version="2.12",
request_data_model=GetHistorySampleRequest,
request_data_model=GetVariantSampleRequest,
)
def get_debug_image_sample(call, company_id, request: GetHistorySampleRequest):
task = task_bll.assert_exists(
company_id, task_ids=[request.task], allow_public=True, only=("company",)
def get_debug_image_sample(call, company_id, request: GetVariantSampleRequest):
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events,
)[0]
res = event_bll.debug_image_sample_history.get_sample_for_variant(
company_id=task.company,
company_id=task_or_model.get_index_company(),
task=request.task,
metric=request.metric,
variant=request.variant,
@@ -833,30 +852,30 @@ def get_debug_image_sample(call, company_id, request: GetHistorySampleRequest):
request_data_model=NextHistorySampleRequest,
)
def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest):
task = task_bll.assert_exists(
company_id, task_ids=[request.task], allow_public=True, only=("company",)
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events,
)[0]
res = event_bll.debug_image_sample_history.get_next_sample(
company_id=task.company,
company_id=task_or_model.get_index_company(),
task=request.task,
state_id=request.scroll_id,
navigate_earlier=request.navigate_earlier,
next_iteration=request.next_iteration,
)
call.result.data = attr.asdict(res, recurse=False)
@endpoint(
"events.get_plot_sample", request_data_model=GetHistorySampleRequest,
"events.get_plot_sample", request_data_model=GetMetricSamplesRequest,
)
def get_plot_sample(call, company_id, request: GetHistorySampleRequest):
task = task_bll.assert_exists(
company_id, task_ids=[request.task], allow_public=True, only=("company",)
def get_plot_sample(call, company_id, request: GetMetricSamplesRequest):
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events,
)[0]
res = event_bll.plot_sample_history.get_sample_for_variant(
company_id=task.company,
res = event_bll.plot_sample_history.get_samples_for_metric(
company_id=task_or_model.get_index_company(),
task=request.task,
metric=request.metric,
variant=request.variant,
iteration=request.iteration,
refresh=request.refresh,
state_id=request.scroll_id,
@@ -869,28 +888,28 @@ def get_plot_sample(call, company_id, request: GetHistorySampleRequest):
"events.next_plot_sample", request_data_model=NextHistorySampleRequest,
)
def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
task = task_bll.assert_exists(
company_id, task_ids=[request.task], allow_public=True, only=("company",)
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events,
)[0]
res = event_bll.plot_sample_history.get_next_sample(
company_id=task.company,
company_id=task_or_model.get_index_company(),
task=request.task,
state_id=request.scroll_id,
navigate_earlier=request.navigate_earlier,
next_iteration=request.next_iteration,
)
call.result.data = attr.asdict(res, recurse=False)
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
task = task_bll.assert_exists(
company_id,
task_ids=request.tasks,
allow_public=True,
only=("company", "company_origin"),
)[0]
task_or_models = _assert_task_or_model_exists(
company_id, request.tasks, model_events=request.model_events,
)
res = event_bll.metrics.get_task_metrics(
task.get_index_company(), task_ids=request.tasks, event_type=request.event_type
task_or_models[0].get_index_company(),
task_ids=request.tasks,
event_type=request.event_type,
)
call.result.data = {
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
@@ -898,7 +917,7 @@ def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, req_model):
def delete_for_task(call, company_id, _):
task_id = call.data["task"]
allow_locked = call.data.get("allow_locked", False)
@@ -910,6 +929,19 @@ def delete_for_task(call, company_id, req_model):
)
@endpoint("events.delete_for_model", required_fields=["model"])
def delete_for_model(call: APICall, company_id: str, _):
    """
    Delete the events stored for a model.

    Reads ``model`` (required) and ``allow_locked`` (defaults to ``False``) from
    the call data, asserts that the model exists and stores the value returned
    by ``event_bll.delete_task_events`` under ``deleted`` in the call result.
    """
    model_id = call.data["model"]
    allow_locked = call.data.get("allow_locked", False)

    # Existence check only: the models themselves are not requested back
    model_bll.assert_exists(company_id, model_id, return_models=False)

    deleted = event_bll.delete_task_events(
        company_id, model_id, allow_locked=allow_locked, model=True
    )
    call.result.data = {"deleted": deleted}
@endpoint("events.clear_task_log")
def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
task_id = request.task
@@ -925,7 +957,9 @@ def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest)
)
def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
def _get_top_iter_unique_events_per_task(
events, max_iters: int, task_names: Mapping[str, str]
):
key = itemgetter("metric", "variant", "task", "iter")
unique_events = itertools.chain.from_iterable(
@@ -938,7 +972,7 @@ def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
def collect(evs, fields):
if not fields:
evs = list(evs)
return {"name": tasks.get(evs[0].get("task")), "plots": evs}
return {"name": task_names.get(evs[0].get("task")), "plots": evs}
return {
str(k): collect(group, fields[1:])
for k, group in itertools.groupby(evs, key=itemgetter(fields[0]))
@@ -1004,17 +1038,15 @@ def scalar_metrics_iter_raw(
request.batch_size = request.batch_size or scroll.request.batch_size
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",),
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=request.model_events
)[0]
metric_variants = _get_metric_variants_from_request([request.metric])
if request.count_total and total is None:
total = event_bll.events_iterator.count_task_events(
event_type=EventType.metrics_scalar,
company_id=task.company,
company_id=task_or_model.get_index_company(),
task_id=task_id,
metric_variants=metric_variants,
)
@@ -1030,7 +1062,7 @@ def scalar_metrics_iter_raw(
for iteration in range(0, math.ceil(batch_size / 10_000)):
res = event_bll.events_iterator.get_task_events(
event_type=EventType.metrics_scalar,
company_id=task.company,
company_id=task_or_model.get_index_company(),
task_id=task_id,
batch_size=min(batch_size, 10_000),
navigate_earlier=False,

View File

@@ -24,7 +24,7 @@ from apiserver.apimodels.models import (
)
from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import publish_task
from apiserver.bll.util import run_batch_operation
@@ -51,8 +51,8 @@ from apiserver.services.utils import (
ModelsBackwardsCompatibility,
unescape_metadata,
escape_metadata,
process_include_subprojects,
)
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
org_bll = OrgBLL()
@@ -107,30 +107,18 @@ def get_by_task_id(call: APICall, company_id, _):
call.result.data = {"model": model_dict}
def _process_include_subprojects(call_data: dict):
include_subprojects = call_data.pop("include_subprojects", False)
project_ids = call_data.get("project")
if not project_ids or not include_subprojects:
return
if not isinstance(project_ids, list):
project_ids = [project_ids]
call_data["project"] = project_ids_with_children(project_ids)
@endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
conform_tag_fields(call, call.data)
_process_include_subprojects(call.data)
process_include_subprojects(call.data)
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_all_ex"):
ret_params = {}
models = Model.get_many_with_join(
company=company_id,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
ret_params = {}
models = Model.get_many_with_join(
company=company_id,
query_dict=call.data,
allow_public=request.allow_public,
ret_params=ret_params,
)
conform_output_tags(call, models)
unescape_metadata(call, models)
@@ -139,10 +127,7 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
return
model_ids = {model["id"] for model in models}
stats = ModelBLL.get_model_stats(
company=company_id,
model_ids=list(model_ids),
)
stats = ModelBLL.get_model_stats(company=company_id, model_ids=list(model_ids),)
for model in models:
model["stats"] = stats.get(model["id"])
@@ -154,10 +139,9 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_by_id_ex"):
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
)
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
)
conform_output_tags(call, models)
unescape_metadata(call, models)
call.result.data = {"models": models}
@@ -167,15 +151,14 @@ def get_by_id_ex(call: APICall, company_id, _):
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_all"):
ret_params = {}
models = Model.get_many(
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
ret_params = {}
models = Model.get_many(
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
unescape_metadata(call, models)
call.result.data = {"models": models, **ret_params}
@@ -414,9 +397,7 @@ def validate_task(company_id, fields: dict):
def edit(call: APICall, company_id, _):
model_id = call.data["model"]
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
)
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
fields = parse_model_fields(call, create_fields)
fields = prepare_update_fields(call, company_id, fields)
@@ -424,11 +405,7 @@ def edit(call: APICall, company_id, _):
for key in fields:
field = getattr(model, key, None)
value = fields[key]
if (
field
and isinstance(value, dict)
and isinstance(field, EmbeddedDocument)
):
if field and isinstance(value, dict) and isinstance(field, EmbeddedDocument):
d = field.to_mongo(use_db_field=False).to_dict()
d.update(value)
fields[key] = d
@@ -448,13 +425,9 @@ def edit(call: APICall, company_id, _):
if updated:
new_project = fields.get("project", model.project)
if new_project != model.project:
_reset_cached_tags(
company_id, projects=[new_project, model.project]
)
_reset_cached_tags(company_id, projects=[new_project, model.project])
else:
_update_cached_tags(
company_id, project=model.project, fields=fields
)
_update_cached_tags(company_id, project=model.project, fields=fields)
conform_output_tags(call, fields)
unescape_metadata(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@@ -465,9 +438,7 @@ def edit(call: APICall, company_id, _):
def _update_model(call: APICall, company_id, model_id=None):
model_id = model_id or call.data["model"]
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
)
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
data = prepare_update_fields(call, company_id, call.data)
@@ -511,6 +482,7 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
updated, published_task = ModelBLL.publish_model(
model_id=request.model,
company_id=company_id,
user_id=call.identity.user,
force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None,
)
@@ -529,6 +501,7 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
func=partial(
ModelBLL.publish_model,
company_id=company_id,
user_id=call.identity.user,
force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None,
),
@@ -635,7 +608,7 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
@endpoint("models.move", request_data_model=MoveRequest)
def move(call: APICall, company_id: str, request: MoveRequest):
if not (request.project or request.project_name):
if not ("project" in call.data or request.project_name):
raise errors.bad_request.MissingRequiredFields(
"project or project_name is required"
)

View File

@@ -6,14 +6,16 @@ from mongoengine import Q
from apiserver.apimodels.organization import TagsRequest, EntitiesCountRequest
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.database.model import User, AttributedDocument, EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.database.model.task.task import Task, TaskType
from apiserver.service_repo import endpoint, APICall
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
org_bll = OrgBLL()
project_bll = ProjectBLL()
@endpoint("organization.get_tags", request_data_model=TagsRequest)
@@ -49,14 +51,15 @@ def get_user_companies(call: APICall, company_id: str, _):
}
@endpoint("organization.get_entities_count", request_data_model=EntitiesCountRequest)
def get_entities_count(call: APICall, company, _):
@endpoint("organization.get_entities_count")
def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
entity_classes: Mapping[str, Type[AttributedDocument]] = {
"projects": Project,
"tasks": Task,
"models": Model,
"pipelines": Project,
"datasets": Project,
"reports": Task,
}
ret = {}
for field, entity_cls in entity_classes.items():
@@ -64,12 +67,41 @@ def get_entities_count(call: APICall, company, _):
if data is None:
continue
if field == "reports":
data["type"] = TaskType.report
data["include_subprojects"] = True
if request.active_users:
if entity_cls is Project:
requested_ids = data.get("id")
if isinstance(requested_ids, str):
requested_ids = [requested_ids]
ids, _ = project_bll.get_projects_with_selected_children(
company=company,
users=request.active_users,
project_ids=requested_ids,
allow_public=request.allow_public,
)
if not ids:
ret[field] = 0
continue
data["id"] = ids
elif not data.get("user"):
data["user"] = request.active_users
query = Q()
if entity_cls in (Project, Task) and not data.get("search_hidden"):
if (
entity_cls in (Project, Task)
and field not in ("reports", "pipelines", "datasets")
and not request.search_hidden
):
query &= Q(system_tags__ne=EntityVisibility.hidden.value)
ret[field] = entity_cls.get_count(
company=company, query_dict=data, query=query, allow_public=True,
company=company,
query_dict=data,
query=query,
allow_public=request.allow_public,
)
call.result.data = ret

View File

@@ -60,6 +60,7 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
queued, res = enqueue_task(
task_id=task.id,
company_id=company_id,
user_id=call.identity.user,
queue_id=request.queue,
status_message="Starting pipeline",
status_reason="",

View File

@@ -1,4 +1,4 @@
from typing import Sequence
from typing import Sequence, Optional, Tuple
import attr
from mongoengine import Q
@@ -18,9 +18,11 @@ from apiserver.apimodels.projects import (
ProjectOrNoneRequest,
ProjectRequest,
ProjectModelMetadataValuesRequest,
ProjectChildrenType,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, ProjectQueries
from apiserver.bll.project.project_bll import pipeline_tag, reports_tag
from apiserver.bll.project.project_cleanup import (
delete_project,
validate_project_delete,
@@ -28,6 +30,7 @@ from apiserver.bll.project.project_cleanup import (
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import TaskType
from apiserver.database.utils import (
parse_from_call,
get_company_or_none_constraint,
@@ -39,7 +42,6 @@ from apiserver.services.utils import (
get_tags_filter_dictionary,
sort_tags_response,
)
from apiserver.timing_context import TimingContext
org_bll = OrgBLL()
project_bll = ProjectBLL()
@@ -60,11 +62,8 @@ def get_by_id(call):
project_id = call.data["project"]
with translate_errors_context():
with TimingContext("mongo", "projects_by_id"):
query = Q(id=project_id) & get_company_or_none_constraint(
call.identity.company
)
project = Project.objects(query).first()
query = Q(id=project_id) & get_company_or_none_constraint(call.identity.company)
project = Project.objects(query).first()
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
@@ -100,80 +99,135 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
data["parent"] = [None]
def _get_project_stats_filter(request: ProjectsGetRequest) -> Tuple[Optional[dict], bool]:
    """
    Resolve the task filter and the search_hidden flag to use for project statistics.

    An explicit ``include_stats_filter`` in the request always wins. Otherwise, when
    ``children_type`` is pipeline or report, a fixed filter on the matching system tag
    and task type is returned together with ``True`` for the hidden flag. In every
    other case the request's own values are returned unchanged.
    """
    children_type = request.children_type
    if children_type and not request.include_stats_filter:
        if children_type == ProjectChildrenType.pipeline:
            return {"system_tags": [pipeline_tag], "type": [TaskType.controller]}, True
        if children_type == ProjectChildrenType.report:
            return {"system_tags": [reports_tag], "type": [TaskType.report]}, True

    return request.include_stats_filter, request.search_hidden
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
data = call.data
conform_tag_fields(call, data)
allow_public = not request.non_public
allow_public = (
data["allow_public"]
if "allow_public" in data
else not data["non_public"]
if "non_public" in data
else request.allow_public
)
requested_ids = data.get("id")
if isinstance(requested_ids, str):
requested_ids = [requested_ids]
_adjust_search_parameters(
data, shallow_search=request.shallow_search,
)
with TimingContext("mongo", "projects_get_all"):
user_active_project_ids = None
if request.active_users:
ids, user_active_project_ids = project_bll.get_projects_with_active_user(
company=company_id,
users=request.active_users,
project_ids=requested_ids,
allow_public=allow_public,
)
if not ids:
return {"projects": []}
data["id"] = ids
ret_params = {}
projects: Sequence[dict] = Project.get_many_with_join(
selected_project_ids = None
if request.active_users or request.children_type:
ids, selected_project_ids = project_bll.get_projects_with_selected_children(
company=company_id,
query_dict=data,
query=_hidden_query(search_hidden=request.search_hidden, ids=requested_ids),
users=request.active_users,
project_ids=requested_ids,
allow_public=allow_public,
ret_params=ret_params,
children_type=request.children_type,
)
if not ids:
return {"projects": []}
data["id"] = ids
if request.check_own_contents and requested_ids:
existing_requested_ids = {
project["id"] for project in projects if project["id"] in requested_ids
}
if existing_requested_ids:
contents = project_bll.calc_own_contents(
company=company_id,
project_ids=list(existing_requested_ids),
filter_=request.include_stats_filter,
users=request.active_users,
)
for project in projects:
project.update(**contents.get(project["id"], {}))
ret_params = {}
conform_output_tags(call, projects)
if request.include_stats:
project_ids = {project["id"] for project in projects}
stats, children = project_bll.get_project_stats(
remove_system_tags = False
if request.search_hidden:
only_fields = data.get("only_fields")
if isinstance(only_fields, list) and "system_tags" not in only_fields:
only_fields.append("system_tags")
remove_system_tags = True
projects: Sequence[dict] = Project.get_many_with_join(
company=company_id,
query_dict=data,
query=_hidden_query(search_hidden=request.search_hidden, ids=requested_ids),
allow_public=allow_public,
ret_params=ret_params,
)
if not projects:
return {"projects": projects, **ret_params}
if request.search_hidden:
for p in projects:
system_tags = (
p.pop("system_tags", [])
if remove_system_tags
else p.get("system_tags", [])
)
if EntityVisibility.hidden.value in system_tags:
p["hidden"] = True
conform_output_tags(call, projects)
project_ids = list({project["id"] for project in projects})
if request.check_own_contents:
if request.children_type == ProjectChildrenType.dataset:
contents = project_bll.calc_own_datasets(
company=company_id,
project_ids=list(project_ids),
specific_state=request.stats_for_state,
include_children=request.stats_with_children,
search_hidden=request.search_hidden,
project_ids=project_ids,
filter_=request.include_stats_filter,
users=request.active_users,
user_active_project_ids=user_active_project_ids,
)
for project in projects:
project["stats"] = stats[project["id"]]
project["sub_projects"] = children[project["id"]]
if request.include_dataset_stats:
project_ids = {project["id"] for project in projects}
dataset_stats = project_bll.get_dataset_stats(
else:
contents = project_bll.calc_own_contents(
company=company_id,
project_ids=list(project_ids),
project_ids=project_ids,
filter_=_get_project_stats_filter(request)[0],
users=request.active_users,
)
for project in projects:
project["dataset_stats"] = dataset_stats.get(project["id"])
call.result.data = {"projects": projects, **ret_params}
for project in projects:
project.update(**contents.get(project["id"], {}))
if request.include_stats:
if request.children_type == ProjectChildrenType.dataset:
stats, children = project_bll.get_project_dataset_stats(
company=company_id,
project_ids=project_ids,
include_children=request.stats_with_children,
filter_=request.include_stats_filter,
users=request.active_users,
selected_project_ids=selected_project_ids,
)
else:
filter_, search_hidden = _get_project_stats_filter(request)
stats, children = project_bll.get_project_stats(
company=company_id,
project_ids=project_ids,
specific_state=request.stats_for_state,
include_children=request.stats_with_children,
search_hidden=search_hidden,
filter_=filter_,
users=request.active_users,
selected_project_ids=selected_project_ids,
)
for project in projects:
project["stats"] = stats[project["id"]]
project["sub_projects"] = children[project["id"]]
if request.include_dataset_stats:
dataset_stats = project_bll.get_dataset_stats(
company=company_id, project_ids=project_ids, users=request.active_users,
)
for project in projects:
project["dataset_stats"] = dataset_stats.get(project["id"])
call.result.data = {"projects": projects, **ret_params}
@endpoint("projects.get_all")
@@ -183,20 +237,19 @@ def get_all(call: APICall):
_adjust_search_parameters(
data, shallow_search=data.get("shallow_search", False),
)
with TimingContext("mongo", "projects_get_all"):
ret_params = {}
projects = Project.get_many(
company=call.identity.company,
query_dict=data,
query=_hidden_query(
search_hidden=data.get("search_hidden"), ids=data.get("id")
),
parameters=data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, projects)
call.result.data = {"projects": projects, **ret_params}
ret_params = {}
projects = Project.get_many(
company=call.identity.company,
query_dict=data,
query=_hidden_query(
search_hidden=data.get("search_hidden"), ids=data.get("id")
),
parameters=data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, projects)
call.result.data = {"projects": projects, **ret_params}
@endpoint(
@@ -282,6 +335,7 @@ def validate_delete(call: APICall, company_id: str, request: ProjectRequest):
def delete(call: APICall, company_id: str, request: DeleteRequest):
res, affected_projects = delete_project(
company=company_id,
user=call.identity.user,
project_id=request.project,
force=request.force,
delete_contents=request.delete_contents,

View File

@@ -1,3 +1,5 @@
from mongoengine import Q
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.queues import (
GetDefaultResp,
@@ -15,10 +17,12 @@ from apiserver.apimodels.queues import (
DeleteMetadataRequest,
GetNextTaskRequest,
GetByIdRequest,
GetAllRequest,
)
from apiserver.bll.model import Metadata
from apiserver.bll.queue import QueueBLL
from apiserver.bll.workers import WorkerBLL
from apiserver.config_repo import config
from apiserver.database.model.task.task import Task
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
@@ -51,16 +55,33 @@ def get_by_id(call: APICall):
call.result.data_model = GetDefaultResp(id=queue.id, name=queue.name)
def _hidden_query(data: dict) -> Q:
    """
    Build the query condition that excludes hidden queues.

    No condition is added (an empty ``Q()`` is returned) when no hidden tags are
    configured under ``services.queues.hidden_tags``, or when the request sets
    ``search_hidden`` or addresses queues by ``id`` or ``name``.
    """
    hidden_tags = config.get("services.queues.hidden_tags", [])
    explicit_lookup = any(data.get(key) for key in ("search_hidden", "id", "name"))
    if hidden_tags and not explicit_lookup:
        return Q(system_tags__nin=hidden_tags)

    return Q()
@endpoint("queues.get_all_ex", min_version="2.4")
def get_all_ex(call: APICall):
def get_all_ex(call: APICall, company: str, request: GetAllRequest):
conform_tag_fields(call, call.data)
ret_params = {}
Metadata.escape_query_parameters(call)
queues = queue_bll.get_queue_infos(
company_id=call.identity.company,
company_id=company,
query_dict=call.data,
max_task_entries=call.data.pop("max_task_entries", None),
query=_hidden_query(call.data),
max_task_entries=request.max_task_entries,
ret_params=ret_params,
)
conform_output_tags(call, queues)
@@ -69,14 +90,15 @@ def get_all_ex(call: APICall):
@endpoint("queues.get_all", min_version="2.4")
def get_all(call: APICall):
def get_all(call: APICall, company: str, request: GetAllRequest):
conform_tag_fields(call, call.data)
ret_params = {}
Metadata.escape_query_parameters(call)
queues = queue_bll.get_all(
company_id=call.identity.company,
company_id=company,
query_dict=call.data,
max_task_entries=call.data.pop("max_task_entries", None),
query=_hidden_query(call.data),
max_task_entries=request.max_task_entries,
ret_params=ret_params,
)
conform_output_tags(call, queues)
@@ -120,7 +142,10 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
@endpoint("queues.delete", min_version="2.4", request_data_model=DeleteRequest)
def delete(call: APICall, company_id, req_model: DeleteRequest):
queue_bll.delete(
company_id=company_id, queue_id=req_model.queue, force=req_model.force
company_id=company_id,
user_id=call.identity.user,
queue_id=req_model.queue,
force=req_model.force,
)
call.result.data = {"deleted": 1}
@@ -135,11 +160,13 @@ def add_task(call: APICall, company_id, req_model: TaskRequest):
@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest)
def get_next_task(call: APICall, company_id, req_model: GetNextTaskRequest):
entry = queue_bll.get_next_task(company_id=company_id, queue_id=req_model.queue)
def get_next_task(call: APICall, company_id, request: GetNextTaskRequest):
entry = queue_bll.get_next_task(
company_id=company_id, queue_id=request.queue, task_id=request.task
)
if entry:
data = {"entry": entry.to_proper_dict()}
if req_model.get_task_info:
if request.get_task_info:
task = Task.objects(id=entry.task).first()
if task:
data["task_info"] = {"company": task.company, "user": task.user}

View File

@@ -0,0 +1,375 @@
import textwrap
from datetime import datetime
from itertools import chain
from typing import Sequence
from apiserver.apimodels.reports import (
CreateReportRequest,
UpdateReportRequest,
PublishReportRequest,
ArchiveReportRequest,
DeleteReportRequest,
MoveReportRequest,
GetTasksDataRequest,
EventsRequest,
GetAllRequest,
)
from apiserver.apierrors import errors
from apiserver.apimodels.base import UpdateResponse
from apiserver.bll.project.project_bll import reports_project_name, reports_tag
from apiserver.services.utils import process_include_subprojects, sort_tags_response
from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL, ChangeStatusRequest
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskType, TaskStatus
from apiserver.service_repo import APICall, endpoint
from apiserver.services.events import (
_get_task_or_model_index_companies,
event_bll,
_get_metrics_response,
_get_metric_variants_from_request,
_get_multitask_plots,
)
from apiserver.services.tasks import (
escape_execution_parameters,
_hidden_query,
unprepare_from_saved,
)
# Module-level business-logic helper instances
org_bll = OrgBLL()
project_bll = ProjectBLL()
task_bll = TaskBLL()
# Request keys that reports.update copies into the stored report;
# any other key in the call payload is ignored by that endpoint
update_fields = {
    "name",
    "tags",
    "comment",
    "report",
    "report_assets",
}
def _assert_report(company_id, task_id, only_fields=None, requires_write_access=True):
    """
    Fetch a task through TaskBLL.get_task_with_access and verify that it is a report.

    :param company_id: company the task is looked up for
    :param task_id: id of the task to fetch
    :param only_fields: optional field names forwarded as ``only``; "type" is appended
        when missing because the report check below reads it
    :param requires_write_access: forwarded to the access check
    :return: the fetched task
    :raise errors.bad_request.OperationSupportedOnReportsOnly: the task type is not report
    """
    type_missing = bool(only_fields) and "type" not in only_fields
    if type_missing:
        only_fields += ("type",)

    report = TaskBLL.get_task_with_access(
        task_id=task_id,
        company_id=company_id,
        only=only_fields,
        requires_write_access=requires_write_access,
    )
    if report.type == TaskType.report:
        return report

    raise errors.bad_request.OperationSupportedOnReportsOnly(id=task_id)
@endpoint("reports.update", response_data_model=UpdateResponse)
def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
    """
    Update the editable fields of a report.

    Only keys listed in ``update_fields`` are taken from the call payload. Tags, name
    and comment may be changed in any status; any other field requires the report to
    be in the "created" status, otherwise InvalidTaskStatus is raised.
    The response echoes the applied fields, with tags sorted and the report body
    shortened to 100 characters.
    """
    report = _assert_report(
        task_id=request.task, company_id=company_id, only_fields=("status",),
    )

    changes = {key: val for key, val in call.data.items() if key in update_fields}
    if not changes:
        return UpdateResponse(updated=0)

    # These fields may be modified regardless of the report status
    metadata_only = set(changes) <= {"tags", "name", "comment"}
    if not metadata_only and report.status != TaskStatus.created:
        raise errors.bad_request.InvalidTaskStatus(
            expected=TaskStatus.created, status=report.status
        )

    timestamp = datetime.utcnow()
    audit_fields = {"last_change": timestamp, "last_changed_by": call.identity.user}
    if not metadata_only:
        # last_update is only bumped when something other than tags/name/comment changes
        audit_fields["last_update"] = timestamp

    updated = report.update(upsert=False, **changes, **audit_fields)
    if not updated:
        return UpdateResponse(updated=0)

    # From here on "changes" is only used for the response payload
    tags = changes.get("tags")
    if tags:
        changes["tags"] = sorted(tags)

    report_text = changes.get("report")
    if report_text:
        changes["report"] = textwrap.shorten(report_text, width=100)

    return UpdateResponse(updated=updated, fields=changes)
def _ensure_reports_project(company: str, user: str, name: str):
    """
    Find or create the reports project that belongs under the project path ``name``.

    Leading and trailing slashes are stripped from ``name``. Unless its last path
    segment already equals ``reports_project_name``, that segment is appended.
    Returns whatever ``project_bll.find_or_create`` returns for the resulting path.
    """
    path = name.strip("/")
    last_segment = path.rpartition("/")[-1]
    if last_segment != reports_project_name:
        path = f"{path}/{reports_project_name}"

    return project_bll.find_or_create(
        user=user,
        company=company,
        project_name=path,
        description="Reports project",
        system_tags=[reports_tag, EntityVisibility.hidden.value],
    )
@endpoint("reports.create")
def create_report(call: APICall, company_id: str, request: CreateReportRequest):
    """
    Create a new report task.

    The report is placed in the reports project under the requested project, or under
    the root when no project is given. The result holds the new task id and the id of
    the reports project it was stored in.
    """
    creator = call.identity.user

    parent_name = ""
    if request.project:
        parent = Project.get_for_writing(
            company=company_id, id=request.project, _only=("name",)
        )
        parent_name = parent.name

    reports_project_id = _ensure_reports_project(
        company=company_id, user=creator, name=parent_name
    )

    report_fields = dict(
        project=reports_project_id,
        name=request.name,
        tags=request.tags,
        comment=request.comment,
        type=TaskType.report,
        system_tags=[reports_tag, EntityVisibility.hidden.value],
    )
    report = task_bll.create(company=company_id, user=creator, fields=report_fields)
    report.save()

    call.result.data = {"id": report.id, "project_id": reports_project_id}
def _delete_reports_project_if_empty(project_id):
    """
    Delete the given project when it is a reports project with no tasks left in it.

    Nothing happens if the project does not exist or if its basename is not
    ``reports_project_name``.
    """
    project = Project.objects(id=project_id).only("basename").first()
    if not project or project.basename != reports_project_name:
        return

    if Task.objects(project=project_id).count() == 0:
        project.delete()
@endpoint("reports.get_all_ex")
def get_all_ex(call: APICall, company_id, request: GetAllRequest):
    """
    Return the reports matching the query in the call payload.

    The payload is modified in place: the task type is forced to report and
    ``include_subprojects`` is forced on before the query is executed.
    """
    filters = call.data
    # Whatever the caller sent, only report tasks are searched, subprojects included
    filters["type"] = TaskType.report
    filters["include_subprojects"] = True
    process_include_subprojects(filters)

    extra_params = {}
    reports = Task.get_many_with_join(
        company=company_id,
        query_dict=filters,
        allow_public=request.allow_public,
        ret_params=extra_params,
    )
    unprepare_from_saved(call, reports)

    call.result.data = {"tasks": reports, **extra_params}
def _get_task_metrics_from_request(
    task_ids: Sequence[str], request: EventsRequest
) -> dict:
    """
    Map every task id to its own ``{metric: variants}`` dict built from ``request.metrics``.

    All tasks get the same content, but each one receives a separate dict object.
    """
    return {
        task_id: {entry.metric: entry.variants for entry in request.metrics}
        for task_id in task_ids
    }
@endpoint("reports.get_task_data")
def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
    """
    Return the tasks matching the query, optionally with their event data.

    The task list is always returned. When the request asks for debug images, plots
    or scalar histograms, the corresponding section is added to the result under
    "debug_images", "plots" or "scalar_metrics_iter_histogram".
    """
    query = escape_execution_parameters(call)
    process_include_subprojects(query)

    extra_params = {}
    tasks = Task.get_many_with_join(
        company=company_id,
        query_dict=query,
        query=_hidden_query(query),
        allow_public=request.allow_public,
        ret_params=extra_params,
    )
    unprepare_from_saved(call, tasks)
    res = {"tasks": tasks, **extra_params}

    images_req = request.debug_images
    plots_req = request.plots
    scalars_req = request.scalar_metrics_iter_histogram
    if not (images_req or plots_req or scalars_req):
        # No event data requested: the task list is returned as is
        return res

    task_ids = [task["id"] for task in tasks]
    companies = _get_task_or_model_index_companies(company_id, task_ids=task_ids)

    if images_req:
        # Flatten the per-company task lists into a task id -> company lookup
        task_companies = {
            t.id: t.company for t in chain.from_iterable(companies.values())
        }
        images = event_bll.debug_images_iterator.get_task_events(
            companies=task_companies,
            task_metrics=_get_task_metrics_from_request(task_ids, images_req),
            iter_count=images_req.iters,
        )
        res["debug_images"] = [
            item.to_struct() for item in _get_metrics_response(images.metric_events)
        ]

    if plots_req:
        plots_result = _get_multitask_plots(
            companies=companies,
            last_iters=plots_req.iters,
            metrics=_get_metric_variants_from_request(plots_req.metrics),
        )
        # Only the first element of the helper's result is exposed
        res["plots"] = plots_result[0]

    if scalars_req:
        histogram = event_bll.metrics.compare_scalar_metrics_average_per_iter(
            companies=companies,
            samples=scalars_req.samples,
            key=scalars_req.key,
            metric_variants=_get_metric_variants_from_request(scalars_req.metrics),
        )
        res["scalar_metrics_iter_histogram"] = histogram

    call.result.data = res
@endpoint("reports.move")
def move(call: APICall, company_id: str, request: MoveReportRequest):
    """
    Move a report to the reports project of another project.

    The target is given either by ``project_name`` or by ``project``. The "project" key
    only has to be present in the payload, so an empty value is accepted and leads to
    an empty parent name. After the move, the report's previous project is deleted if
    it is an empty reports project.

    :raise errors.bad_request.MissingRequiredFields: neither project nor project_name passed
    """
    if "project" not in call.data and not request.project_name:
        raise errors.bad_request.MissingRequiredFields(
            "project or project_name is required"
        )

    report = _assert_report(
        company_id=company_id, task_id=request.task, only_fields=("project",),
    )
    user_id = call.identity.user

    target_name = request.project_name
    if not target_name:
        target_name = ""
        if request.project:
            parent = Project.get_for_writing(
                company=company_id, id=request.project, _only=("name",)
            )
            target_name = parent.name

    target_project_id = _ensure_reports_project(
        company=company_id, user=user_id, name=target_name
    )
    project_bll.move_under_project(
        entity_cls=Task,
        user=call.identity.user,
        company=company_id,
        ids=[request.task],
        project=target_project_id,
    )

    # report.project still holds the project the report was in before the move
    _delete_reports_project_if_empty(report.project)

    return {"project_id": target_project_id}
@endpoint(
    "reports.publish", response_data_model=UpdateResponse,
)
def publish(call: APICall, company_id, request: PublishReportRequest):
    """
    Set the report status to published, using a forced status change.

    The request message is stored as the status message and the publish time is set
    to the current UTC time.
    """
    report = _assert_report(company_id=company_id, task_id=request.task)

    status_change = ChangeStatusRequest(
        task=report,
        new_status=TaskStatus.published,
        force=True,
        status_reason="",
        status_message=request.message,
        user_id=call.identity.user,
    )
    result = status_change.execute(published=datetime.utcnow())

    call.result.data_model = UpdateResponse(**result)
@endpoint("reports.archive")
def archive(call: APICall, company_id, request: ArchiveReportRequest):
    """
    Archive a report by adding the archived system tag to it.

    The status message, last change time and last changing user are updated as well.
    Returns the update result under the "archived" key.
    """
    report = _assert_report(company_id=company_id, task_id=request.task)

    changes = dict(
        status_message=request.message,
        status_reason="",
        add_to_set__system_tags=EntityVisibility.archived.value,
        last_change=datetime.utcnow(),
        last_changed_by=call.identity.user,
    )
    return {"archived": report.update(**changes)}
@endpoint("reports.unarchive")
def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
    """
    Restore an archived report by pulling the archived system tag from it.

    The status message, last change time and last changing user are updated as well.
    Returns the update result under the "unarchived" key.
    """
    report = _assert_report(company_id=company_id, task_id=request.task)

    changes = dict(
        status_message=request.message,
        status_reason="",
        pull__system_tags=EntityVisibility.archived.value,
        last_change=datetime.utcnow(),
        last_changed_by=call.identity.user,
    )
    return {"unarchived": report.update(**changes)}
# @endpoint("reports.share")
# def share(call: APICall, company_id, request: ShareReportRequest):
# _assert_report(
# company_id=company_id, user_id=call.identity.user, task_id=request.task
# )
# call.result.data = {
# "changed": task_bll.share_task(
# company_id=company_id, task_ids=[request.task], share=request.share
# )
# }
@endpoint("reports.delete")
def delete(call: APICall, company_id, request: DeleteReportRequest):
    """
    Delete a report.

    A report can be deleted when it is still in the "created" status or when it is
    archived; in any other state ``force`` must be passed. After the deletion, the
    report's project is removed as well if it is an empty reports project.

    :raise errors.bad_request.TaskCannotBeDeleted: the report is neither in the
        "created" status nor archived, and force was not requested
    """
    task = _assert_report(
        company_id=company_id,
        task_id=request.task,
        # The deletion guard below reads status and system_tags, so they have to be
        # requested together with project. Assumes only_fields limits the loaded
        # document fields, as the other callers in this module rely on - with a
        # projection, fields that were not requested read as their defaults.
        only_fields=("project", "status", "system_tags"),
    )
    if (
        task.status != TaskStatus.created
        and EntityVisibility.archived.value not in (task.system_tags or [])
        and not request.force
    ):
        raise errors.bad_request.TaskCannotBeDeleted(
            "due to status, use force=True",
            task=task.id,
            expected=TaskStatus.created,
            current=task.status,
        )

    task.delete()
    _delete_reports_project_if_empty(task.project)

    call.result.data = {"deleted": 1}
@endpoint("reports.get_tags")
def get_tags(call: APICall, company_id: str, _):
    """Return the distinct tags used by the company's report tasks, passed through sort_tags_response."""
    company_reports = Task.objects(company=company_id, type=TaskType.report)
    report_tags = company_reports.distinct(field="tags")
    call.result.data = sort_tags_response({"tags": report_tags})

View File

@@ -64,11 +64,12 @@ from apiserver.apimodels.tasks import (
ResetBatchItem,
CompletedRequest,
CompletedResponse,
GetAllReq,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.project import ProjectBLL
from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import (
TaskBLL,
@@ -81,7 +82,6 @@ from apiserver.bll.task.artifacts import (
Artifacts,
)
from apiserver.bll.task.hyperparams import HyperParams
from apiserver.bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
from apiserver.bll.task.param_utils import (
params_prepare_for_save,
params_unprepare_from_saved,
@@ -119,8 +119,8 @@ from apiserver.services.utils import (
DockerCmdBackwardsCompatibility,
escape_dict_field,
unescape_dict_field,
process_include_subprojects,
)
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.partial_version import PartialVersion
@@ -135,11 +135,9 @@ queue_bll = QueueBLL()
org_bll = OrgBLL()
project_bll = ProjectBLL()
NonResponsiveTasksWatchdog.start()
def set_task_status_from_call(
request: UpdateRequest, company_id, new_status=None, **set_fields
request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields
) -> dict:
fields_resolver = SetFieldsResolver(set_fields)
task = TaskBLL.get_task_with_access(
@@ -174,6 +172,7 @@ def set_task_status_from_call(
status_reason=status_reason,
status_message=status_message,
force=force,
user_id=user_id,
).execute(**fields_resolver.get_fields(task))
@@ -207,17 +206,6 @@ def escape_execution_parameters(call: APICall) -> dict:
return call_data
def _process_include_subprojects(call_data: dict):
    """
    Expand the "project" filter in ``call_data`` to also cover child projects.

    The "include_subprojects" key is always removed from ``call_data``. The "project"
    value is replaced by the result of ``project_ids_with_children`` only when the flag
    was set and a project filter is present; a single project id is wrapped in a list
    first.
    """
    wants_children = call_data.pop("include_subprojects", False)
    projects = call_data.get("project")
    if wants_children and projects:
        if not isinstance(projects, list):
            projects = [projects]
        call_data["project"] = project_ids_with_children(projects)
def _hidden_query(data: dict) -> Q:
"""
1. Add only non-hidden tasks search condition (unless specifically specified differently)
@@ -228,22 +216,21 @@ def _hidden_query(data: dict) -> Q:
return Q(system_tags__ne=EntityVisibility.hidden.value)
@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
@endpoint("tasks.get_all_ex")
def get_all_ex(call: APICall, company_id, request: GetAllReq):
conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call)
with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data)
ret_params = {}
tasks = Task.get_many_with_join(
company=company_id,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=True,
ret_params=ret_params,
)
process_include_subprojects(call_data)
ret_params = {}
tasks = Task.get_many_with_join(
company=company_id,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=request.allow_public,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks, **ret_params}
@@ -254,10 +241,9 @@ def get_by_id_ex(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@@ -269,16 +255,15 @@ def get_all(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with TimingContext("mongo", "task_get_all"):
ret_params = {}
tasks = Task.get_many(
company=company_id,
parameters=call_data,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=True,
ret_params=ret_params,
)
ret_params = {}
tasks = Task.get_many(
company=company_id,
parameters=call_data,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=True,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks, **ret_params}
@@ -308,6 +293,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
**stop_task(
task_id=req_model.task,
company_id=company_id,
user_id=call.identity.user,
user_name=call.identity.user_name,
status_reason=req_model.status_reason,
force=req_model.force,
@@ -325,6 +311,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
func=partial(
stop_task,
company_id=company_id,
user_id=call.identity.user,
user_name=call.identity.user_name,
status_reason=request.status_reason,
force=request.force,
@@ -346,7 +333,8 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
call.result.data_model = UpdateResponse(
**set_task_status_from_call(
req_model,
company_id,
company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.stopped,
completed=datetime.utcnow(),
)
@@ -362,7 +350,8 @@ def started(call: APICall, company_id, req_model: UpdateRequest):
res = StartedResponse(
**set_task_status_from_call(
req_model,
company_id,
company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.in_progress,
min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value
)
@@ -376,7 +365,12 @@ def started(call: APICall, company_id, req_model: UpdateRequest):
)
def failed(call: APICall, company_id, req_model: UpdateRequest):
call.result.data_model = UpdateResponse(
**set_task_status_from_call(req_model, company_id, new_status=TaskStatus.failed)
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.failed,
)
)
@@ -385,7 +379,11 @@ def failed(call: APICall, company_id, req_model: UpdateRequest):
)
def close(call: APICall, company_id, req_model: UpdateRequest):
call.result.data_model = UpdateResponse(
**set_task_status_from_call(req_model, company_id, new_status=TaskStatus.closed)
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.closed)
)
@@ -490,12 +488,13 @@ def prepare_create_fields(
def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
with translate_errors_context(
field_does_not_exist_cls=errors.bad_request.ValidationError
), TimingContext("code", "parse_call"):
):
fields = prepare_create_fields(call, **kwargs)
task = task_bll.create(call, fields)
task = task_bll.create(
company=call.identity.company, user=call.identity.user, fields=fields
)
with TimingContext("code", "validate"):
task_bll.validate(task)
task_bll.validate(task)
return task, fields
@@ -528,7 +527,7 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
def create(call: APICall, company_id, req_model: CreateRequest):
task, fields = _validate_and_get_task_from_call(call)
with translate_errors_context(), TimingContext("mongo", "save_task"):
with translate_errors_context():
task.save()
_update_cached_tags(company_id, project=task.project, fields=fields)
update_project_time(task.project)
@@ -596,7 +595,9 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
company_id=company_id,
id=task_id,
partial_update_dict=partial_update_dict,
injected_update=dict(last_change=datetime.utcnow()),
injected_update=dict(
last_change=datetime.utcnow(), last_changed_by=call.identity.user,
),
)
if updated_count:
new_project = updated_fields.get("project", task.project)
@@ -629,7 +630,11 @@ def set_requirements(call: APICall, company_id, req_model: SetRequirementsReques
raise errors.bad_request.MissingTaskFields(
"Task has no script field", task=task.id
)
res = update_task(task, update_cmds=dict(script__requirements=requirements))
res = update_task(
task,
user_id=call.identity.user,
update_cmds=dict(script__requirements=requirements),
)
call.result.data_model = UpdateResponse(updated=res)
if res:
call.result.data_model.fields = {"script.requirements": requirements}
@@ -664,7 +669,9 @@ def update_batch(call: APICall, company_id, _):
partial_update_dict = Task.get_safe_update_dict(fields)
if not partial_update_dict:
continue
partial_update_dict.update(last_change=now)
partial_update_dict.update(
last_change=now, last_changed_by=call.identity.user,
)
update_op = UpdateOne(
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
)
@@ -711,7 +718,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
with translate_errors_context(
field_does_not_exist_cls=errors.bad_request.ValidationError
), TimingContext("code", "parse_and_validate"):
):
fields = prepare_create_fields(
call, valid_fields=edit_fields, output=task.output, previous_task=task
)
@@ -728,7 +735,11 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
d.update(value)
fields[key] = d
task_bll.validate(task_bll.create(call, fields))
task_bll.validate(
task_bll.create(
company=call.identity.company, user=call.identity.user, fields=fields
)
)
# make sure field names do not end in mongoengine comparison operators
fixed_fields = {
@@ -737,7 +748,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
}
if fixed_fields:
now = datetime.utcnow()
last_change = dict(last_change=now)
last_change = dict(last_change=now, last_changed_by=call.identity.user)
if not set(fields).issubset(Task.user_set_allowed()):
last_change.update(last_update=now)
fields.update(**last_change)
@@ -774,6 +785,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest
call.result.data = {
"updated": HyperParams.edit_params(
company_id,
user_id=call.identity.user,
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
@@ -787,6 +799,7 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
call.result.data = {
"deleted": HyperParams.delete_params(
company_id,
user_id=call.identity.user,
task_id=request.task,
hyperparams=request.hyperparams,
force=request.force,
@@ -831,6 +844,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ
call.result.data = {
"updated": HyperParams.edit_configuration(
company_id,
user_id=call.identity.user,
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
@@ -846,6 +860,7 @@ def delete_configuration(
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id,
user_id=call.identity.user,
task_id=request.task,
configuration=request.configuration,
force=request.force,
@@ -862,12 +877,18 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
queued, res = enqueue_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
queue_name=request.queue_name,
force=request.force,
)
if request.verify_watched_queue:
res_queue = nested_get(res, ("fields", "execution.queue"))
if res_queue:
res["queue_watched"] = queue_bll.check_for_workers(company_id, res_queue)
call.result.data_model = EnqueueResponse(queued=queued, **res)
@@ -881,6 +902,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
func=partial(
enqueue_task,
company_id=company_id,
user_id=call.identity.user,
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
@@ -889,12 +911,20 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
),
ids=request.ids,
)
extra = {}
if request.verify_watched_queue and results:
_id, (queued, res) = results[0]
res_queue = nested_get(res, ("fields", "execution.queue"))
if res_queue:
extra["queue_watched"] = queue_bll.check_for_workers(company_id, res_queue)
call.result.data_model = EnqueueManyResponse(
succeeded=[
EnqueueBatchItem(id=_id, queued=bool(queued), **res)
for _id, (queued, res) in results
],
failed=failures,
**extra,
)
@@ -907,6 +937,7 @@ def dequeue(call: APICall, company_id, request: UpdateRequest):
dequeued, res = dequeue_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
status_message=request.status_message,
status_reason=request.status_reason,
)
@@ -923,6 +954,7 @@ def dequeue_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial(
dequeue_task,
company_id=company_id,
user_id=call.identity.user,
status_message=request.status_message,
status_reason=request.status_reason,
),
@@ -944,10 +976,12 @@ def reset(call: APICall, company_id, request: ResetRequest):
dequeued, cleanup_res, updates = reset_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
clear_all=request.clear_all,
delete_external_artifacts=request.delete_external_artifacts,
)
res = ResetResponse(**updates, dequeued=dequeued)
@@ -970,10 +1004,12 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
func=partial(
reset_task,
company_id=company_id,
user_id=call.identity.user,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
clear_all=request.clear_all,
delete_external_artifacts=request.delete_external_artifacts,
),
ids=request.ids,
)
@@ -1014,6 +1050,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
for task in tasks:
archived += archive_task(
company_id=company_id,
user_id=call.identity.user,
task=task,
status_message=request.status_message,
status_reason=request.status_reason,
@@ -1032,6 +1069,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial(
archive_task,
company_id=company_id,
user_id=call.identity.user,
status_message=request.status_message,
status_reason=request.status_reason,
),
@@ -1053,6 +1091,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial(
unarchive_task,
company_id=company_id,
user_id=call.identity.user,
status_message=request.status_message,
status_reason=request.status_reason,
),
@@ -1071,12 +1110,14 @@ def delete(call: APICall, company_id, request: DeleteRequest):
deleted, task, cleanup_res = delete_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
status_message=request.status_message,
status_reason=request.status_reason,
delete_external_artifacts=request.delete_external_artifacts,
)
if deleted:
if request.move_to_trash:
@@ -1091,12 +1132,14 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
func=partial(
delete_task,
company_id=company_id,
user_id=call.identity.user,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
status_message=request.status_message,
status_reason=request.status_reason,
delete_external_artifacts=request.delete_external_artifacts,
),
ids=request.ids,
)
@@ -1127,6 +1170,7 @@ def publish(call: APICall, company_id, request: PublishRequest):
updates = publish_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
force=request.force,
publish_model_func=ModelBLL.publish_model if request.publish_model else None,
status_reason=request.status_reason,
@@ -1145,6 +1189,7 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
func=partial(
publish_task,
company_id=company_id,
user_id=call.identity.user,
force=request.force,
publish_model_func=ModelBLL.publish_model
if request.publish_model
@@ -1171,7 +1216,8 @@ def completed(call: APICall, company_id, request: CompletedRequest):
res = CompletedResponse(
**set_task_status_from_call(
request,
company_id,
company_id=company_id,
user_id=call.identity.user,
new_status=TaskStatus.completed,
completed=datetime.utcnow(),
)
@@ -1181,6 +1227,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
publish_res = publish_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
force=request.force,
publish_model_func=ModelBLL.publish_model,
status_reason=request.status_reason,
@@ -1212,6 +1259,7 @@ def add_or_update_artifacts(
call.result.data = {
"updated": Artifacts.add_or_update_artifacts(
company_id=company_id,
user_id=call.identity.user,
task_id=request.task,
artifacts=request.artifacts,
force=True,
@@ -1228,6 +1276,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
call.result.data = {
"deleted": Artifacts.delete_artifacts(
company_id=company_id,
user_id=call.identity.user,
task_id=request.task,
artifact_ids=request.artifacts,
force=True,
@@ -1251,7 +1300,7 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
@endpoint("tasks.move", request_data_model=MoveRequest)
def move(call: APICall, company_id: str, request: MoveRequest):
if not (request.project or request.project_name):
if not ("project" in call.data or request.project_name):
raise errors.bad_request.MissingRequiredFields(
"project or project_name is required"
)
@@ -1295,7 +1344,7 @@ def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequ
@endpoint("tasks.delete_models", min_version="2.13")
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
delete_names = {
@@ -1308,5 +1357,5 @@ def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
if names
}
updated = task.update(last_change=datetime.utcnow(), **commands,)
updated = task.update(last_change=datetime.utcnow(), last_changed_by=call.identity.user, **commands,)
return {"updated": updated}

View File

@@ -12,7 +12,7 @@ from apiserver.bll.project import ProjectBLL
from apiserver.bll.user import UserBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import Role, User as AuthUser
from apiserver.database.model.auth import Role
from apiserver.database.model.company import Company
from apiserver.database.model.user import User
from apiserver.database.utils import parse_from_call
@@ -114,12 +114,6 @@ def get_current_user(call: APICall, company_id, _):
user = res[0]
user["role"] = call.identity.role
auth_user: AuthUser = AuthUser.objects(id=user_id, company=company_id).first()
if not auth_user:
raise errors.bad_request.InvalidUser("failed loading user")
user["created"] = auth_user.created
resp = {
"user": user,
"getting_started": config.get("apiserver.getting_started_info", None),

View File

@@ -3,6 +3,7 @@ from typing import Union, Sequence, Tuple
from apiserver.apierrors import errors
from apiserver.apimodels.organization import Filter
from apiserver.bll.project import project_ids_with_children
from apiserver.database.model.base import GetMixin
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
from apiserver.database.utils import partition_tags
@@ -12,6 +13,17 @@ from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from apiserver.utilities.partial_version import PartialVersion
def process_include_subprojects(call_data: dict):
    """
    Expand the "project" filter of the call data to include sub-projects.

    The "include_subprojects" flag is always removed from call_data. When the
    flag is set and a "project" value is present, the value (a single id or a
    list of ids) is replaced by the result of project_ids_with_children.
    """
    expand = call_data.pop("include_subprojects", False)
    requested = call_data.get("project")
    if not (requested and expand):
        return

    as_list = requested if isinstance(requested, list) else [requested]
    call_data["project"] = project_ids_with_children(as_list)
def get_tags_filter_dictionary(input_: Filter) -> dict:
if not input_:
return {}

View File

@@ -42,7 +42,10 @@ worker_bll = WorkerBLL()
def get_all(call: APICall, company_id: str, request: GetAllRequest):
call.result.data_model = GetAllResponse(
workers=worker_bll.get_all_with_projection(
company_id, request.last_seen, tags=request.tags
company_id,
request.last_seen,
tags=request.tags,
system_tags=request.system_tags,
)
)
@@ -53,6 +56,9 @@ def register(call: APICall, company_id, request: RegisterRequest):
timeout = request.timeout
queues = request.queues
if not timeout:
timeout = config.get("apiserver.workers.default_timeout", 10 * 60)
if not timeout or timeout <= 0:
raise bad_request.WorkerRegistrationFailed(
"invalid timeout", timeout=timeout, worker=worker
@@ -66,6 +72,7 @@ def register(call: APICall, company_id, request: RegisterRequest):
queues=queues,
timeout=timeout,
tags=request.tags,
system_tags=request.system_tags,
)
@@ -84,6 +91,7 @@ def status_report(call: APICall, company_id, request: StatusReportRequest):
ip=call.real_ip,
report=request,
tags=request.tags,
system_tags=request.system_tags,
)

View File

@@ -60,12 +60,12 @@ class TestService(TestCase, TestServiceInterface):
def update_missing(target: dict, **update):
target.update({k: v for k, v in update.items() if k not in target})
def create_temp(self, service, *, client=None, delete_params=None, **kwargs) -> str:
def create_temp(self, service, *, client=None, delete_params=None, object_name="", **kwargs) -> str:
return self._create_temp_helper(
service=service,
create_endpoint="create",
delete_endpoint="delete",
object_name=service.rstrip("s"),
object_name=object_name or service.rstrip("s"),
create_params=kwargs,
client=client,
delete_params=delete_params,

View File

@@ -7,9 +7,6 @@ class TestBatchOperations(TestService):
comment = "this is a comment"
delete_params = dict(can_fail=True, force=True)
def setUp(self, version="2.13"):
super().setUp(version=version)
def test_tasks(self):
tasks = [self._temp_task() for _ in range(2)]
models = [

View File

@@ -30,6 +30,11 @@ class TestMoveUnderProject(TestService):
self.assertEqual(p2_name, projects[0].name)
self.api.projects.delete(project=project2, force=True)
# move to the root project
self.assertEqual(None, self.api.tasks.move(ids=[task], project=None).project_id)
tasks = self.api.tasks.get_all_ex(id=[task]).tasks
self.assertEqual(None, tasks[0].get("project"))
# model move into existing project referenced by name
model = self._temp_model()
self.api.models.move(ids=[model], project_name=self.entity_name)

View File

@@ -0,0 +1,50 @@
from typing import Tuple
from apiserver.tests.automated import TestService
class TestPipelines(TestService):
    """Automated tests for the pipelines service."""

    def test_start_pipeline(self):
        """Start a pipeline from a "pipeline" tagged task and check the task it produces."""
        default_queue = self.api.queues.get_default().id
        base_name = "pipelines test"
        project_id, source_task = self._temp_project_and_task(name=base_name)
        pipeline_args = [{"name": "hello", "value": "test"}]

        response = self.api.pipelines.start_pipeline(
            task=source_task, queue=default_queue, args=pipeline_args
        )
        started_task = response.pipeline
        try:
            self.assertTrue(response.enqueued)
            started = self.api.tasks.get_all_ex(id=[started_task]).tasks[0]
            self.assertTrue(started.name.startswith(base_name))
            self.assertEqual(started.status, "queued")
            self.assertEqual(started.project.id, project_id)

            # each passed argument is expected under the "Args" hyper parameters section
            expected_args = {}
            for arg in pipeline_args:
                expected_args[arg["name"]] = {
                    "section": "Args",
                    "name": arg["name"],
                    "value": arg["value"],
                }
            self.assertEqual(started.hyperparams.Args, expected_args)
        finally:
            # always remove the task returned by start_pipeline
            self.api.tasks.delete(task=started_task, force=True)

    def _temp_project_and_task(self, name) -> Tuple[str, str]:
        """Create a temp project and a temp "pipeline" tagged task inside it.

        :return: tuple of (project, task) as returned by create_temp
        """
        project = self.create_temp(
            "projects", name=name, description="test", delete_params={"force": True},
        )
        task = self.create_temp(
            "tasks",
            name=name,
            type="testing",
            input={"view": {}},
            project=project,
            system_tags=["pipeline"],
        )
        return project, task

View File

@@ -21,7 +21,7 @@ class TestQueues(TestService):
def test_queue_metrics(self):
queue_id = self._temp_queue("TestTempQueue")
task1 = self._create_temp_queued_task("temp task 1", queue_id)
self._create_temp_queued_task("temp task 1", queue_id)
time.sleep(1)
task2 = self._create_temp_queued_task("temp task 2", queue_id)
self.api.queues.get_next_task(queue=queue_id)
@@ -36,6 +36,27 @@ class TestQueues(TestService):
)
self.assertMetricQueues(res["queues"], queue_id)
def test_hidden_queues(self):
    """A queue tagged "k8s-glue" is listed only with search_hidden or when addressed by name/id."""
    hidden_name = "TestHiddenQueue"
    hidden_queue = self._temp_queue(hidden_name, system_tags=["k8s-glue"])
    non_hidden_queue = self._temp_queue("TestNonHiddenQueue")

    # default listing: the hidden queue is filtered out
    ids = {q.id for q in self.api.queues.get_all_ex().queues}
    self.assertNotIn(hidden_queue, ids)
    self.assertIn(non_hidden_queue, ids)

    # explicit request for hidden queues returns both
    ids = {q.id for q in self.api.queues.get_all_ex(search_hidden=True).queues}
    self.assertIn(hidden_queue, ids)
    self.assertIn(non_hidden_queue, ids)

    # addressing the hidden queue by exact name or by id returns it first
    queues = self.api.queues.get_all_ex(name=f"^{hidden_name}$").queues
    self.assertEqual(hidden_queue, queues[0].id)

    queues = self.api.queues.get_all_ex(id=[hidden_queue]).queues
    self.assertEqual(hidden_queue, queues[0].id)
def test_reset_task(self):
queue = self._temp_queue("TestTempQueue")
task = self._temp_task("TempTask", is_development=True)
@@ -60,6 +81,19 @@ class TestQueues(TestService):
self.assertQueueTasks(res.queue, [task])
self.assertTaskTags(task, system_tags=[])
def test_dequeue_from_deleted_queue(self):
    """Force deleting a queue moves its queued task back to the "created" status."""
    queue_id = self._temp_queue("TestTempQueue")
    task_id = self._temp_task("TempDevTask")

    def current_status():
        return self.api.tasks.get_by_id(task=task_id).task.status

    self.api.tasks.enqueue(task=task_id, queue=queue_id)
    self.assertEqual(current_status(), "queued")

    self.api.queues.delete(queue=queue_id, force=True)
    self.assertEqual(current_status(), "created")
def test_max_queue_entries(self):
queue = self._temp_queue("TestTempQueue")
tasks = [
@@ -193,8 +227,8 @@ class TestQueues(TestService):
sorted(queue.workers, key=sort_key), sorted(workers, key=sort_key)
)
def _temp_queue(self, queue_name, **kwargs):
    """Create a temporary queue; extra keyword arguments (e.g. tags, system_tags) are passed to create_temp."""
    return self.create_temp("queues", name=queue_name, **kwargs)
def _temp_task(self, task_name, is_testing=False, is_development=False):
task_input = dict(

View File

@@ -0,0 +1,208 @@
import re
from boltons.iterutils import first
from apiserver.apierrors import errors
from apiserver.es_factory import es_factory
from apiserver.tests.automated import TestService
from apiserver.utilities.dicts import nested_get
class TestReports(TestService):
    """Automated tests for the reports service."""

    def _delete_project(self, name):
        """Force delete (with contents) the first project whose name matches `name` exactly, if any."""
        existing_project = first(
            self.api.projects.get_all_ex(
                name=f"^{re.escape(name)}$", search_hidden=True
            ).projects
        )
        if existing_project:
            self.api.projects.delete(
                project=existing_project.id, force=True, delete_contents=True
            )

    def test_create_update_move(self):
        """Create a report, update it before and after publishing, then move it between projects."""
        task_name = "Rep1"
        comment = "My report"
        tags = ["hello"]

        # report creates a hidden task under hidden .reports subproject
        self._delete_project(".reports")
        task_id = self._temp_report(name=task_name, comment=comment, tags=tags)
        task = self.api.tasks.get_all_ex(id=[task_id]).tasks[0]
        self.assertEqual(task.name, task_name)
        self.assertEqual(task.comment, comment)
        self.assertEqual(set(task.tags), set(tags))
        self.assertEqual(task.type, "report")
        self.assertEqual(set(task.system_tags), {"hidden", "reports"})
        projects = self.api.projects.get_all_ex(name=r"^\.reports$").projects
        self.assertEqual(len(projects), 0)
        project = self.api.projects.get_all_ex(
            name=r"^\.reports$", search_hidden=True
        ).projects[0]
        self.assertEqual(project.id, task.project.id)
        self.assertEqual(set(project.system_tags), {"hidden", "reports"})
        ret = self.api.reports.get_tags()
        self.assertEqual(ret.tags, sorted(tags))

        # update is working on draft reports
        new_comment = "My new comment"
        res = self.api.reports.update(
            task=task_id,
            comment=new_comment,
            tags=[],
            report_assets=["file://test.jpg"],
        )
        self.assertEqual(res.updated, 1)
        task = self.api.tasks.get_all_ex(id=[task_id]).tasks[0]
        self.assertEqual(task.name, task_name)
        self.assertEqual(task.comment, new_comment)
        self.assertEqual(task.tags, [])
        ret = self.api.reports.get_tags()
        self.assertEqual(ret.tags, [])
        self.assertEqual(task.report_assets, ["file://test.jpg"])

        # the report text cannot be updated once the report is published
        self.api.reports.publish(task=task_id)
        with self.api.raises(errors.bad_request.InvalidTaskStatus):
            self.api.reports.update(task=task_id, report="New report text")

        # update on tags or rename can be done for published report too
        self.api.reports.update(
            task=task_id, name="new name", tags=["test"], comment="Yet another comment"
        )
        task = self.api.tasks.get_all_ex(id=[task_id]).tasks[0]
        self.assertEqual(task.tags, ["test"])
        self.assertEqual(task.name, "new name")
        self.assertEqual(task.comment, "Yet another comment")

        # move under another project autodeletes the empty project
        new_project_name = "Reports Test"
        self._delete_project(new_project_name)
        task2_id = self._temp_report(name="Rep2")
        new_project_id = self.api.reports.move(
            task=task_id, project_name=new_project_name
        ).project_id
        new_project = self.api.projects.get_all_ex(id=[new_project_id]).projects[0]
        self.assertEqual(new_project.name, f"{new_project_name}/.reports")
        self.assertEqual(set(new_project.system_tags), {"hidden", "reports"})
        # the original project still holds the second report at this point
        self.assertEqual(len(self.api.projects.get_all_ex(id=project.id).projects), 1)
        self.api.reports.move(task=task2_id, project=new_project_id)
        self.assertEqual(len(self.api.projects.get_all_ex(id=project.id).projects), 0)
        tasks = self.api.tasks.get_all_ex(
            project=new_project_id, search_hidden=True
        ).tasks
        self.assertTrue({task_id, task2_id}.issubset({t.id for t in tasks}))

        # moving with project=None puts the report under a parentless .reports project
        project_id = self.api.reports.move(task=task2_id, project=None).project_id
        project = self.api.projects.get_all_ex(id=[project_id]).projects[0]
        self.assertIsNone(project.get("parent"))
        self.assertEqual(project.name, ".reports")

    def test_reports_search(self):
        """reports.get_all_ex matches on the report text and does not return non-report tasks."""
        report_task = self._temp_report(name="Rep1")
        # a regular task named "hello" exists only to verify that it is not
        # returned by the reports search below
        self._temp_task(name="hello")
        res = self.api.reports.get_all_ex(
            _any_={"pattern": "hello", "fields": ["name", "id", "tags", "report"]}
        ).tasks
        self.assertEqual(len(res), 0)

        self.api.reports.update(task=report_task, report="hello world")
        res = self.api.reports.get_all_ex(
            _any_={"pattern": "hello", "fields": ["name", "id", "tags", "report"]}
        ).tasks
        self.assertEqual(len(res), 1)
        self.assertEqual(res[0].id, report_task)

    def test_reports_task_data(self):
        """reports.get_task_data returns debug images and plots only when they are requested."""
        self._temp_report(name="Rep1")
        non_report_task = self._temp_task(name="hello")
        debug_image_events = [
            self._create_task_event(
                task=non_report_task,
                type_="training_debug_image",
                iteration=1,
                metric=f"Metric_{m}",
                variant=f"Variant_{v}",
                url=f"{m}_{v}",
            )
            for m in range(2)
            for v in range(2)
        ]
        plot_events = [
            self._create_task_event(
                task=non_report_task,
                type_="plot",
                iteration=1,
                metric=f"Metric_{m}",
                variant=f"Variant_{v}",
                plot_str="Hello plot",
            )
            for m in range(2)
            for v in range(2)
        ]
        self.send_batch([*debug_image_events, *plot_events])

        # no events sections requested: only the tasks are returned
        res = self.api.reports.get_task_data(
            id=[non_report_task], only_fields=["name"],
        )
        self.assertEqual(len(res.tasks), 1)
        self.assertEqual(res.tasks[0].id, non_report_task)
        self.assertFalse(any(field in res for field in ("plots", "debug_images")))

        # all debug images and the plots of a single metric are requested
        res = self.api.reports.get_task_data(
            id=[non_report_task],
            only_fields=["name"],
            debug_images={"metrics": []},
            plots={"metrics": [{"metric": "Metric_1"}]},
        )
        self.assertEqual(len(res.debug_images), 1)
        task_events = res.debug_images[0]
        self.assertEqual(task_events.task, non_report_task)
        self.assertEqual(len(task_events.iterations), 1)
        self.assertEqual(len(task_events.iterations[0].events), 4)

        self.assertEqual(len(res.plots), 1)
        for m, v in (("Metric_1", "Variant_0"), ("Metric_1", "Variant_1")):
            tasks = nested_get(res.plots, (m, v))
            self.assertEqual(len(tasks), 1)
            task_plots = tasks[non_report_task]
            self.assertEqual(len(task_plots), 1)
            iter_plots = task_plots["1"]
            self.assertEqual(iter_plots.name, "hello")
            self.assertEqual(len(iter_plots.plots), 1)
            ev = iter_plots.plots[0]
            self.assertEqual(ev["metric"], m)
            self.assertEqual(ev["variant"], v)
            self.assertEqual(ev["task"], non_report_task)
            self.assertEqual(ev["iter"], 1)

    @staticmethod
    def _create_task_event(type_, task, iteration, **kwargs):
        """Build an event dict; the timestamp defaults to the current time in milliseconds."""
        return {
            "worker": "test",
            "type": type_,
            "task": task,
            "iter": iteration,
            "timestamp": kwargs.get("timestamp") or es_factory.get_timestamp_millis(),
            **kwargs,
        }

    def _temp_report(self, name, **kwargs):
        """Create a temporary report through the "reports" service with force delete params."""
        return self.create_temp(
            "reports",
            name=name,
            object_name="task",
            delete_params={"force": True},
            **kwargs,
        )

    def _temp_task(self, name, **kwargs):
        """Create a temporary task of type "training" with force delete params."""
        return self.create_temp(
            "tasks",
            name=name,
            type="training",
            delete_params={"force": True},
            **kwargs,
        )

    def send_batch(self, events):
        """Send the events through events.add_batch and return the response data."""
        _, data = self.api.send_batch("events.add_batch", events)
        return data

View File

@@ -14,17 +14,91 @@ from apiserver.tests.automated import TestService
class TestSubProjects(TestService):
def test_dataset_stats(self):
    """Project dataset_stats is empty until a task in the project carries dataset runtime info."""
    project = self._temp_project(name="Dataset test", system_tags=["dataset"])
    res = self.api.organization.get_entities_count(
        datasets={"system_tags": ["dataset"]}
    )
    self.assertEqual(res.datasets, 1)

    task = self._temp_task(project=project)
    data = self.api.projects.get_all_ex(
        id=[project], include_dataset_stats=True
    ).projects[0]
    self.assertIsNone(data.dataset_stats)

    # the stats are expected to reflect the ds_* runtime values of the task
    self.api.tasks.edit(
        task=task, runtime={"ds_file_count": 2, "ds_total_size": 1000}
    )
    data = self.api.projects.get_all_ex(
        id=[project], include_dataset_stats=True
    ).projects[0]
    self.assertEqual(data.dataset_stats, {"file_count": 2, "total_size": 1000})
def test_query_children(self):
# Creates three sub-trees under one root project (a "dataset" tagged project,
# a "pipeline" tagged project and a project with a report created under it)
# and checks projects.get_all_ex results for the direct children of the root,
# first unfiltered and then filtered by children_type.
# NOTE(review): indentation is not preserved in this listing; the nesting of
# the last assertion (inside the else branch or after the if/else) should be
# confirmed against the real file.
test_root_name = "TestQueryChildren"
test_root = self._temp_project(name=test_root_name)
dataset_tags = ["hello", "world"]
# Project1: a dataset project holding one dataset task and one regular task
dataset_project = self._temp_project(
name=f"{test_root_name}/Project1/Dataset",
system_tags=["dataset"],
tags=dataset_tags,
)
self._temp_task(
name="dataset task",
type="data_processing",
system_tags=["dataset"],
project=dataset_project,
)
self._temp_task(name="regular task", project=dataset_project)
# Project2: a pipeline project holding one pipeline task and one regular task
pipeline_project = self._temp_project(
name=f"{test_root_name}/Project2/Pipeline", system_tags=["pipeline"]
)
self._temp_task(
name="pipeline task",
type="controller",
system_tags=["pipeline"],
project=pipeline_project,
)
self._temp_task(name="regular task", project=pipeline_project)
# Project3: one report and one regular task created with this project
report_project = self._temp_project(name=f"{test_root_name}/Project3")
self._temp_report(name="test report", project=report_project)
self._temp_task(name="regular task", project=report_project)
# unfiltered query: all three direct children are returned with their stats
projects = self.api.projects.get_all_ex(
parent=[test_root], shallow_search=True, include_stats=True
).projects
self.assertEqual(
{p.basename for p in projects}, {f"Project{idx+1}" for idx in range(3)}
)
for p in projects:
self.assertEqual(
p.stats.active.total_tasks,
2
if p.basename in ("Project1", "Project2")
else 1
)
# filtered query: each children_type is expected to return exactly one child,
# in the same order in which the projects were created above
for i, type_ in enumerate(("dataset", "pipeline", "report")):
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type=type_,
shallow_search=True,
include_stats=True,
check_own_contents=True,
).projects
self.assertEqual({p.basename for p in projects}, {f"Project{i+1}"})
p = projects[0]
if type_ in ("dataset",):
self.assertEqual(p.own_datasets, 1)
self.assertIsNone(p.get("own_tasks"))
self.assertEqual(p.stats.datasets.count, 1)
self.assertEqual(p.stats.datasets.tags, dataset_tags)
else:
self.assertEqual(p.own_tasks, 0)
self.assertIsNone(p.get("own_datasets"))
# NOTE(review): no "Project4" is created in this test, so the expected
# value below always evaluates to 1
self.assertEqual(
p.stats.active.total_tasks, 1 if p.basename != "Project4" else 0
)
def test_project_aggregations(self):
"""This test requires user with user_auth_only... credentials in db"""
user2_client = APIClient(
@@ -52,6 +126,10 @@ class TestSubProjects(TestService):
self.assertEqual(res.types, [])
res = self.api.projects.get_task_parents(projects=[project])
self.assertEqual(res.parents, [])
res = self.api.organization.get_entities_count(
projects={"id": [project]}, active_users=[user]
)
self.assertEqual(res.projects, 0)
# test aggregations with non-empty subprojects
task1 = self._temp_task(project=child)
@@ -64,7 +142,9 @@ class TestSubProjects(TestService):
res = self.api.projects.get_all_ex(id=[project], include_stats=True)
self._assert_ids(res.projects, [project])
self.assertEqual(res.projects[0].stats.active.total_tasks, 3)
res = self.api.projects.get_all_ex(id=[project], active_users=[user], include_stats=True)
res = self.api.projects.get_all_ex(
id=[project], active_users=[user], include_stats=True
)
self._assert_ids(res.projects, [project])
self.assertEqual(res.projects[0].stats.active.total_tasks, 2)
res = self.api.projects.get_task_parents(projects=[project])
@@ -73,6 +153,10 @@ class TestSubProjects(TestService):
self.assertEqual(res.frameworks, [framework])
res = self.api.tasks.get_types(projects=[project])
self.assertEqual(res.types, ["testing"])
res = self.api.organization.get_entities_count(
projects={"id": [project]}, active_users=[user]
)
self.assertEqual(res.projects, 1)
def _assert_ids(self, actual: Sequence[dict], expected: Sequence[str]):
    """Assert that the "id" values of `actual`, in order, are equal to `expected`."""
    actual_ids = [entry["id"] for entry in actual]
    self.assertEqual(actual_ids, expected)
@@ -280,12 +364,21 @@ class TestSubProjects(TestService):
**kwargs,
)
def _temp_task(self, client=None, **kwargs):
def _temp_report(self, name, **kwargs):
    """Create a temporary report through the "reports" service using the class delete params.

    object_name="task" is passed to create_temp (presumably the key under which
    the reports service returns the created object - confirm against create_temp).
    """
    report = self.create_temp(
        "reports",
        name=name,
        object_name="task",
        delete_params=self.delete_params,
        **kwargs,
    )
    return report
def _temp_task(self, client=None, name=None, type=None, **kwargs):
return self.create_temp(
"tasks",
delete_params=self.delete_params,
type="testing",
name=db_id(),
type=type or "testing",
name=name or db_id(),
input=dict(view=dict()),
client=client,
**kwargs,

View File

@@ -6,17 +6,26 @@ from typing import Sequence, Optional, Tuple
from boltons.iterutils import first
from apiserver.apierrors import errors
from apiserver.es_factory import es_factory
from apiserver.apierrors.errors.bad_request import EventsNotAdded
from apiserver.tests.automated import TestService
class TestTaskEvents(TestService):
delete_params = dict(can_fail=True, force=True)
def _temp_task(self, name="test task events"):
    """Create a temporary training task using the class delete params."""
    task_input = dict(
        name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
    )
    # the keyword must be spelled "delete_params": create_temp forwards any
    # unrecognized keyword into the parameters of the create call
    return self.create_temp(
        "tasks", delete_params=self.delete_params, **task_input
    )
def _temp_model(self, name="test model events", **kwargs):
    """Create a temporary model, filling in name, uri and labels when they are not passed."""
    defaults = dict(name=name, uri="file:///a/b", labels={})
    self.update_missing(kwargs, **defaults)
    return self.create_temp("models", delete_params=self.delete_params, **kwargs)
@staticmethod
def _create_task_event(type_, task, iteration, **kwargs):
@@ -97,9 +106,7 @@ class TestTaskEvents(TestService):
res.variants[variant]["iter"],
[x or special_iteration for x in range(iter_count)],
)
self.assertEqual(
res.variants[variant]["y"], list(range(iter_count))
)
self.assertEqual(res.variants[variant]["y"], list(range(iter_count)))
# but not in the histogram
data = self.api.events.scalar_metrics_iter_histogram(task=task)
@@ -133,8 +140,7 @@ class TestTaskEvents(TestService):
task=task, batch_size=100, metric=metric_param, count_total=True
)
self.assertEqual(
res.variants[variant]["y"],
[y or new_value for y in range(iter_count)],
res.variants[variant]["y"], [y or new_value for y in range(iter_count)],
)
task_data = self.api.tasks.get_by_id(task=task).task
@@ -170,7 +176,49 @@ class TestTaskEvents(TestService):
metric_data = first(first(task_data.last_metrics.values()).values())
self.assertEqual(iter_count - 1, metric_data.value)
self.assertEqual(iter_count - 1, metric_data.max_value)
self.assertEqual(iter_count - 1, metric_data.max_value_iteration)
self.assertEqual(0, metric_data.min_value)
self.assertEqual(0, metric_data.min_value_iteration)
res = self.api.events.get_task_latest_scalar_values(task=task)
self.assertEqual(iter_count - 1, res.last_iter)
def test_model_events(self):
    """Model events: a log event is refused, scalar events are returned by the histogram.

    First sends a "log" event flagged with ``model_event=True`` and expects
    ``EventsNotAdded``. Then sends scalar events for 2 iterations x 5 metrics
    x 5 variants and checks the metric names, the variant names and the x/y
    series of ``Metric0``/``Variant0`` in ``scalar_metrics_iter_histogram``.
    """
    temp_model = self._temp_model(ready=False)

    # task log events are not allowed
    rejected_event = self._create_task_event(
        "log",
        task=temp_model,
        iteration=0,
        msg="This is a log message",
        model_event=True,
    )
    with self.api.raises(errors.bad_request.EventsNotAdded):
        self.send(rejected_event)

    iter_count, metric_count, variant_count = 2, 5, 5
    scalar_events = []
    for it in range(iter_count):
        for m_idx in range(metric_count):
            for v_idx in range(variant_count):
                scalar_events.append(
                    {
                        **self._create_task_event(
                            "training_stats_scalar", temp_model, it
                        ),
                        "metric": f"Metric{m_idx}",
                        "variant": f"Variant{v_idx}",
                        "value": it,
                        "model_event": True,
                    }
                )
    self.send_batch(scalar_events)

    histogram = self.api.events.scalar_metrics_iter_histogram(
        task=temp_model, model_events=True
    )
    expected_metrics = [f"Metric{m_idx}" for m_idx in range(metric_count)]
    self.assertEqual(list(histogram), expected_metrics)

    first_metric = histogram.Metric0
    expected_variants = [f"Variant{v_idx}" for v_idx in range(variant_count)]
    self.assertEqual(list(first_metric), expected_variants)

    first_variant = first_metric.Variant0
    self.assertEqual(first_variant.x, [0, 1])
    self.assertEqual(first_variant.y, [0.0, 1.0])
def test_error_events(self):
task = self._temp_task()
@@ -555,7 +603,8 @@ class TestTaskEvents(TestService):
return data
def send(self, event):
    """Post a single event to the ``events.add`` endpoint.

    :param event: the event payload to send
    :return: the second element of the pair returned by ``self.api.send``
    """
    # Send exactly once; an extra call would post the same event a second time.
    _, data = self.api.send("events.add", event)
    return data
if __name__ == "__main__":

View File

@@ -26,57 +26,55 @@ class TestTaskPlots(TestService):
def test_get_plot_sample(self):
    """Check ``events.get_plot_sample`` for one metric that has two variants.

    Covers three cases: a task without events, the default request (expects
    the events of the last iteration), and a request for a specific
    iteration that reuses the scroll id of the previous response.
    """
    task = self._temp_task()
    metric = "Metric1"
    variants = ["Variant1", "Variant2"]

    # test empty
    res = self.api.events.get_plot_sample(task=task, metric=metric)
    self.assertEqual(res.min_iteration, None)
    self.assertEqual(res.max_iteration, None)
    self.assertEqual(res.events, [])

    # test existing events
    # one event per (iteration, variant); events are ordered by iteration,
    # then by variant, so each iteration owns a contiguous slice of the list
    iterations = 5
    events = [
        self._create_task_event(
            task=task,
            iteration=n // len(variants),
            metric=metric,
            variant=variants[n % len(variants)],
            plot_str=f"Test plot str {n}",
        )
        for n in range(iterations * len(variants))
    ]
    self.send_batch(events)

    # if iteration is not specified then return the events from the last one
    res = self.api.events.get_plot_sample(task=task, metric=metric)
    self._assertEqualEvents(res.events, events[-len(variants):])
    self.assertEqual(res.max_iteration, iterations - 1)
    self.assertEqual(res.min_iteration, 0)
    self.assertTrue(res.scroll_id)

    # else from the specific iteration
    iteration = 3
    res = self.api.events.get_plot_sample(
        task=task, metric=metric, iteration=iteration, scroll_id=res.scroll_id,
    )
    self._assertEqualEvents(
        res.events,
        events[iteration * len(variants):(iteration + 1) * len(variants)],
    )
def test_next_plot_sample(self):
task = self._temp_task()
metric1 = "Metric1"
variant1 = "Variant1"
metric2 = "Metric2"
variant2 = "Variant2"
metrics = [(metric1, variant1), (metric2, variant2)]
metrics = [
(metric1, "variant1"),
(metric1, "variant2"),
(metric2, "variant3"),
(metric2, "variant4"),
]
# test existing events
events = [
self._create_task_event(
@@ -93,57 +91,72 @@ class TestTaskPlots(TestService):
# single metric navigation
# init scroll
res = self.api.events.get_plot_sample(
task=task, metric=metric1, variant=variant1
)
self._assertEqualEvent(res.event, events[-2])
res = self.api.events.get_plot_sample(task=task, metric=metric1)
self._assertEqualEvents(res.events, events[-4:-2])
# navigate forwards
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, navigate_earlier=False
)
self.assertEqual(res.event, None)
self.assertEqual(res.events, [])
# navigate backwards
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id
)
self._assertEqualEvent(res.event, events[-4])
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id
)
self._assertEqualEvent(res.event, None)
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
self._assertEqualEvents(res.events, events[-8:-6])
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
self._assertEqualEvents(res.events, [])
# all metrics navigation
# init scroll
res = self.api.events.get_plot_sample(
task=task, metric=metric1, variant=variant1, navigate_current_metric=False
task=task, metric=metric1, navigate_current_metric=False
)
self._assertEqualEvent(res.event, events[-2])
self._assertEqualEvents(res.events, events[-4:-2])
# navigate forwards
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, navigate_earlier=False
)
self._assertEqualEvent(res.event, events[-1])
self._assertEqualEvents(res.events, events[-2:])
# navigate backwards
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id
)
self._assertEqualEvent(res.event, events[-2])
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id
)
self._assertEqualEvent(res.event, events[-3])
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
self._assertEqualEvents(res.events, events[-4:-2])
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
self._assertEqualEvents(res.events, events[-6:-4])
def _assertEqualEvent(self, ev1: dict, ev2: Optional[dict]):
    """Assert that ``ev1`` matches the expected event ``ev2``.

    When ``ev2`` is None, ``ev1`` must be None as well. Otherwise ``ev1``
    must not be None and both events must agree on the compared fields.
    """
    if ev2 is not None:
        self.assertIsNotNone(ev1)
        compared_fields = ("iter", "timestamp", "metric", "variant", "plot_str", "task")
        for key in compared_fields:
            self.assertEqual(ev1[key], ev2[key])
        return

    self.assertIsNone(ev1)
# next_iteration
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, next_iteration=True
)
self._assertEqualEvents(res.events, [])
res = self.api.events.next_plot_sample(
task=task,
scroll_id=res.scroll_id,
next_iteration=True,
navigate_earlier=False,
)
self._assertEqualEvents(res.events, events[-4:-2])
self.assertTrue(all(ev.iter == 1 for ev in res.events))
res = self.api.events.next_plot_sample(
task=task,
scroll_id=res.scroll_id,
next_iteration=True,
navigate_earlier=False,
)
self._assertEqualEvents(res.events, [])
def _assertEqualEvents(
    self, ev_source: Sequence[dict], ev_target: Sequence[Optional[dict]]
):
    """Assert that two event sequences have equal length and match pairwise.

    Events are compared position by position on the fields "iter",
    "timestamp", "metric", "variant", "plot_str" and "task".
    """
    self.assertEqual(len(ev_source), len(ev_target))

    compared_fields = ("iter", "timestamp", "metric", "variant", "plot_str", "task")
    for actual, expected in zip(ev_source, ev_target):
        for key in compared_fields:
            self.assertEqual(actual[key], expected[key])
def test_task_plots(self):
task = self._temp_task()
@@ -223,12 +236,15 @@ class TestTaskPlots(TestService):
self.assertTrue(all(m.iterations == [] for m in res.metrics))
return res.scroll_id
expected_variants = set((m, var) for m, vars_ in expected_metrics.items() for var in vars_)
expected_variants = set(
(m, var) for m, vars_ in expected_metrics.items() for var in vars_
)
for metric_data in res.metrics:
self.assertEqual(len(metric_data.iterations), iterations)
for it_data in metric_data.iterations:
self.assertEqual(
set((e.metric, e.variant) for e in it_data.events), expected_variants
set((e.metric, e.variant) for e in it_data.events),
expected_variants,
)
return res.scroll_id
@@ -266,7 +282,7 @@ class TestTaskPlots(TestService):
task=task,
metric=metric,
iterations=iterations,
variants=len(variants)
variants=len(variants),
)
# test forward navigation

Some files were not shown because too many files have changed in this diff Show More