mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
89 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8e4d9eeac | ||
|
|
b51aa5c29b | ||
|
|
e7c9daa42b | ||
|
|
7357654249 | ||
|
|
a6f671b46a | ||
|
|
17a8b440bd | ||
|
|
eb2b9cbd9a | ||
|
|
797e503e67 | ||
|
|
30cfdac8f2 | ||
|
|
24bb87aaee | ||
|
|
dd49ba180a | ||
|
|
bda903d0d8 | ||
|
|
9739eb2d5a | ||
|
|
cfbb37238f | ||
|
|
6664c6237e | ||
|
|
74200a24bd | ||
|
|
2fb9288a6c | ||
|
|
5d014d81af | ||
|
|
3a2675abe1 | ||
|
|
f0d68b1ce9 | ||
|
|
15db9cdaef | ||
|
|
a45d47f5d7 | ||
|
|
b1a50c1370 | ||
|
|
22a2a02760 | ||
|
|
ab798e4170 | ||
|
|
f09ac672d2 | ||
|
|
2149b76f63 | ||
|
|
d96420aa67 | ||
|
|
ed6c7b7bcb | ||
|
|
a392bc0bd7 | ||
|
|
7e97ec5555 | ||
|
|
9c41124b81 | ||
|
|
14ff639bb0 | ||
|
|
e66257761a | ||
|
|
0ffde24dc2 | ||
|
|
d4fdcd9b32 | ||
|
|
18570bfccb | ||
|
|
54ce6c34c6 | ||
|
|
ae4c33fa0e | ||
|
|
c7cd949fd0 | ||
|
|
1ce4058157 | ||
|
|
7b6f24b24d | ||
|
|
d03a931d84 | ||
|
|
5cc7199661 | ||
|
|
6537e9ef69 | ||
|
|
930aaff791 | ||
|
|
1999fb2479 | ||
|
|
9db14cc31d | ||
|
|
e3cc689528 | ||
|
|
9e0adc77dd | ||
|
|
58d9a64537 | ||
|
|
d397d2ae20 | ||
|
|
2d711e1500 | ||
|
|
97992b0d9e | ||
|
|
bc23f1b0cf | ||
|
|
6b3eff1426 | ||
|
|
caaf801cd0 | ||
|
|
c23e8a90d0 | ||
|
|
fa5b28ca0e | ||
|
|
bfb55a9463 | ||
|
|
37e485e1f2 | ||
|
|
3451ff441f | ||
|
|
53c9b5525e | ||
|
|
e5230edac3 | ||
|
|
a54dd8030c | ||
|
|
482a5c34bc | ||
|
|
ee2a72c70f | ||
|
|
a0d8aaf3b9 | ||
|
|
de1f823213 | ||
|
|
0c9e2f92ee | ||
|
|
6c49e96ff0 | ||
|
|
81e3fc6577 | ||
|
|
e6dc4b7557 | ||
|
|
238a47a197 | ||
|
|
04e7076628 | ||
|
|
0531612bf4 | ||
|
|
3ae410a1e9 | ||
|
|
98ed3075dd | ||
|
|
b871bf4224 | ||
|
|
8d4c02fc3c | ||
|
|
b986980c75 | ||
|
|
a4fa567be2 | ||
|
|
ddb91f226a | ||
|
|
7772f47773 | ||
|
|
9c118d14e0 | ||
|
|
efd56e085e | ||
|
|
4dff163af4 | ||
|
|
242a78a0fe | ||
|
|
78989fea91 |
@@ -50,6 +50,9 @@
|
||||
130: ["task_not_found", "task not found"]
|
||||
131: ["events_not_added", "events not added"]
|
||||
|
||||
# Reports
|
||||
150: ["operation_supported_on_reports_only", "passed task is not report"]
|
||||
|
||||
# Models
|
||||
200: ["model_error", "general task error"]
|
||||
201: ["invalid_model_id", "invalid model id"]
|
||||
|
||||
@@ -61,12 +61,6 @@ class ListField(fields.ListField):
|
||||
item.validate()
|
||||
|
||||
|
||||
# since there is no distinction between None and empty DictField
|
||||
# this value can be used as sentinel in order to distinguish
|
||||
# between not set and empty DictField
|
||||
DictFieldNotSet = {}
|
||||
|
||||
|
||||
class DictField(fields.BaseField):
|
||||
types = (dict,)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ class MetricVariants(Base):
|
||||
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
task: str = StringField(required=True)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
@@ -35,11 +36,12 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
Length(
|
||||
minimum_value=1,
|
||||
maximum_value=config.get(
|
||||
"services.tasks.multi_task_histogram_limit", 10
|
||||
"services.tasks.multi_task_histogram_limit", 100
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class TaskMetric(Base):
|
||||
@@ -56,25 +58,36 @@ class MetricEventsRequest(Base):
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
model_events: bool = BoolField()
|
||||
|
||||
|
||||
class TaskMetricVariant(Base):
|
||||
class GetVariantSampleRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
variant: str = StringField(required=True)
|
||||
|
||||
|
||||
class GetHistorySampleRequest(TaskMetricVariant):
|
||||
iteration: Optional[int] = IntField()
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_current_metric: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class GetMetricSamplesRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
iteration: Optional[int] = IntField()
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_current_metric: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class NextHistorySampleRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
next_iteration: bool = BoolField(default=False)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class LogOrderEnum(StringEnum):
|
||||
@@ -93,6 +106,7 @@ class TaskEventsRequest(TaskEventsRequestBase):
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
|
||||
scroll_id: str = StringField()
|
||||
count_total: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class LogEventsRequest(TaskEventsRequestBase):
|
||||
@@ -108,6 +122,7 @@ class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
|
||||
metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
|
||||
count_total: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class IterationEvents(Base):
|
||||
@@ -129,6 +144,7 @@ class MultiTasksRequestBase(Base):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class SingleValueMetricsRequest(MultiTasksRequestBase):
|
||||
@@ -145,6 +161,7 @@ class TaskPlotsRequest(Base):
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class ClearScrollRequest(Base):
|
||||
|
||||
@@ -79,3 +79,4 @@ class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
||||
|
||||
class ModelsGetRequest(models.Base):
|
||||
include_stats = fields.BoolField(default=False)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
@@ -19,3 +19,7 @@ class EntitiesCountRequest(models.Base):
|
||||
models = DictField()
|
||||
pipelines = DictField()
|
||||
datasets = DictField()
|
||||
reports = DictField()
|
||||
active_users = fields.ListField(str)
|
||||
search_hidden = fields.BoolField(default=False)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from enum import Enum
|
||||
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apiserver.apimodels import ListField, ActualEnumField, DictField
|
||||
@@ -56,14 +58,22 @@ class ProjectModelMetadataValuesRequest(MultiProjectRequest):
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectChildrenType(Enum):
|
||||
pipeline = "pipeline"
|
||||
report = "report"
|
||||
dataset = "dataset"
|
||||
|
||||
|
||||
class ProjectsGetRequest(models.Base):
|
||||
include_dataset_stats = fields.BoolField(default=False)
|
||||
include_stats = fields.BoolField(default=False)
|
||||
include_stats_filter = DictField()
|
||||
stats_with_children = fields.BoolField(default=True)
|
||||
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
|
||||
non_public = fields.BoolField(default=False)
|
||||
non_public = fields.BoolField(default=False) # legacy, use allow_public instead
|
||||
active_users = fields.ListField(str)
|
||||
check_own_contents = fields.BoolField(default=False)
|
||||
shallow_search = fields.BoolField(default=False)
|
||||
search_hidden = fields.BoolField(default=False)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
children_type = ActualEnumField(ProjectChildrenType)
|
||||
|
||||
@@ -30,9 +30,15 @@ class GetByIdRequest(QueueRequest):
|
||||
max_task_entries = IntField()
|
||||
|
||||
|
||||
class GetAllRequest(Base):
|
||||
max_task_entries = IntField()
|
||||
search_hidden = BoolField(default=False)
|
||||
|
||||
|
||||
class GetNextTaskRequest(QueueRequest):
|
||||
queue = StringField(required=True)
|
||||
get_task_info = BoolField(default=False)
|
||||
task = StringField()
|
||||
|
||||
|
||||
class DeleteRequest(QueueRequest):
|
||||
|
||||
72
apiserver/apimodels/reports.py
Normal file
72
apiserver/apimodels/reports.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, ListField, BoolField, EmbeddedField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length
|
||||
|
||||
from apiserver.apimodels.events import MetricVariants, HistogramRequestBase
|
||||
|
||||
|
||||
class UpdateReportRequest(Base):
|
||||
task = StringField(required=True)
|
||||
name = StringField(nullable=True, validators=Length(minimum_value=3))
|
||||
tags = ListField(items_types=[str])
|
||||
comment = StringField()
|
||||
report = StringField()
|
||||
report_assets = ListField(items_types=[str])
|
||||
|
||||
|
||||
class CreateReportRequest(Base):
|
||||
name = StringField(required=True, validators=Length(minimum_value=3))
|
||||
tags = ListField(items_types=[str])
|
||||
comment = StringField()
|
||||
report = StringField()
|
||||
project = StringField()
|
||||
report_assets = ListField(items_types=[str])
|
||||
|
||||
|
||||
class PublishReportRequest(Base):
|
||||
task = StringField(required=True)
|
||||
message = StringField(default="")
|
||||
|
||||
|
||||
class ArchiveReportRequest(Base):
|
||||
task = StringField(required=True)
|
||||
message = StringField(default="")
|
||||
|
||||
|
||||
class ShareReportRequest(Base):
|
||||
task = StringField(required=True)
|
||||
share = BoolField(default=True)
|
||||
|
||||
|
||||
class DeleteReportRequest(Base):
|
||||
task = StringField(required=True)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class MoveReportRequest(Base):
|
||||
task = StringField(required=True)
|
||||
project = StringField()
|
||||
project_name = StringField()
|
||||
|
||||
|
||||
class EventsRequest(Base):
|
||||
iters = IntField(default=1, validators=validators.Min(1))
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
|
||||
class ScalarMetricsIterHistogram(HistogramRequestBase):
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
|
||||
class GetTasksDataRequest(Base):
|
||||
debug_images: EventsRequest = EmbeddedField(EventsRequest)
|
||||
plots: EventsRequest = EmbeddedField(EventsRequest)
|
||||
scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(ScalarMetricsIterHistogram)
|
||||
allow_public = BoolField(default=True)
|
||||
|
||||
|
||||
class GetAllRequest(Base):
|
||||
allow_public = BoolField(default=True)
|
||||
@@ -42,6 +42,7 @@ class StartedResponse(UpdateResponse):
|
||||
|
||||
class EnqueueResponse(UpdateResponse):
|
||||
queued = IntField()
|
||||
queue_watched = BoolField()
|
||||
|
||||
|
||||
class EnqueueBatchItem(UpdateBatchItem):
|
||||
@@ -50,6 +51,7 @@ class EnqueueBatchItem(UpdateBatchItem):
|
||||
|
||||
class EnqueueManyResponse(BatchResponse):
|
||||
succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem)
|
||||
queue_watched = BoolField()
|
||||
|
||||
|
||||
class DequeueResponse(UpdateResponse):
|
||||
@@ -97,12 +99,14 @@ class UpdateRequest(TaskUpdateRequest):
|
||||
class EnqueueRequest(UpdateRequest):
|
||||
queue = StringField()
|
||||
queue_name = StringField()
|
||||
verify_watched_queue = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteRequest(UpdateRequest):
|
||||
move_to_trash = BoolField(default=True)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class SetRequirementsRequest(TaskRequest):
|
||||
@@ -180,6 +184,7 @@ class ResetRequest(UpdateRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class MultiTaskRequest(models.Base):
|
||||
@@ -273,6 +278,7 @@ class EnqueueManyRequest(TaskBatchRequest):
|
||||
queue = StringField()
|
||||
queue_name = StringField()
|
||||
validate_tasks = BoolField(default=False)
|
||||
verify_watched_queue = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteManyRequest(TaskBatchRequest):
|
||||
@@ -280,6 +286,7 @@ class DeleteManyRequest(TaskBatchRequest):
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
force = BoolField(default=False)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class ResetManyRequest(TaskBatchRequest):
|
||||
@@ -287,6 +294,7 @@ class ResetManyRequest(TaskBatchRequest):
|
||||
return_file_urls = BoolField(default=False)
|
||||
delete_output_models = BoolField(default=True)
|
||||
force = BoolField(default=False)
|
||||
delete_external_artifacts = BoolField(default=True)
|
||||
|
||||
|
||||
class PublishManyRequest(TaskBatchRequest):
|
||||
@@ -310,3 +318,8 @@ class DeleteModelsRequest(TaskRequest):
|
||||
models: Sequence[ModelItemKey] = ListField(
|
||||
[ModelItemKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
|
||||
|
||||
class GetAllReq(models.Base):
|
||||
allow_public = BoolField(default=True)
|
||||
search_hidden = BoolField(default=False)
|
||||
|
||||
@@ -20,12 +20,11 @@ DEFAULT_TIMEOUT = 10 * 60
|
||||
class WorkerRequest(Base):
|
||||
worker = StringField(required=True)
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class RegisterRequest(WorkerRequest):
|
||||
timeout = make_default(
|
||||
IntField, DEFAULT_TIMEOUT
|
||||
)() # registration timeout in seconds (default is 10min)
|
||||
timeout = IntField(default=0) # registration timeout in seconds (if not specified, default is 10min)
|
||||
queues = ListField(six.string_types) # list of queues this worker listens to
|
||||
|
||||
|
||||
@@ -76,6 +75,7 @@ class WorkerEntry(Base, JsonSerializableMixin):
|
||||
last_activity_time = DateTimeField(required=True)
|
||||
last_report_time = DateTimeField()
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class CurrentTaskEntry(IdNameEntry):
|
||||
@@ -97,6 +97,7 @@ class WorkerResponseEntry(WorkerEntry):
|
||||
class GetAllRequest(Base):
|
||||
last_seen = IntField(default=3600)
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class GetAllResponse(Base):
|
||||
|
||||
@@ -9,6 +9,7 @@ from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
|
||||
|
||||
import elasticsearch
|
||||
from boltons.iterutils import chunked_iter
|
||||
from elasticsearch.helpers import BulkIndexError
|
||||
from mongoengine import Q
|
||||
from nested_dict import nested_dict
|
||||
@@ -26,11 +27,12 @@ from apiserver.bll.event.event_common import (
|
||||
)
|
||||
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
|
||||
from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIterator
|
||||
from apiserver.bll.event.history_plot_iterator import HistoryPlotIterator
|
||||
from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
|
||||
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
|
||||
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
|
||||
from apiserver.bll.util import parallel_chunked_decorator
|
||||
from apiserver.database import utils as dbutils
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event.event_metrics import EventMetrics
|
||||
@@ -39,9 +41,8 @@ from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import flatten_nested_items, nested_get
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.utilities.json import loads
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
@@ -63,12 +64,21 @@ class PlotFields:
|
||||
|
||||
|
||||
class EventBLL(object):
|
||||
id_fields = ("task", "iter", "metric", "variant", "key")
|
||||
event_id_fields = ("task", "iter", "metric", "variant", "key")
|
||||
empty_scroll = "FFFF"
|
||||
img_source_regex = re.compile(
|
||||
r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
_task_event_query = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{"term": {"model_event": False}},
|
||||
{"bool": {"must_not": [{"exists": {"field": "model_event"}}]}},
|
||||
]
|
||||
}
|
||||
}
|
||||
_model_event_query = {"term": {"model_event": True}}
|
||||
|
||||
def __init__(self, events_es=None, redis=None):
|
||||
self.es = events_es or es_factory.connect("events")
|
||||
@@ -84,7 +94,7 @@ class EventBLL(object):
|
||||
es=self.es, redis=self.redis
|
||||
)
|
||||
self.plots_iterator = MetricPlotsIterator(es=self.es, redis=self.redis)
|
||||
self.plot_sample_history = HistoryPlotIterator(es=self.es, redis=self.redis)
|
||||
self.plot_sample_history = HistoryPlotsIterator(es=self.es, redis=self.redis)
|
||||
self.events_iterator = EventsIterator(es=self.es)
|
||||
|
||||
@property
|
||||
@@ -97,18 +107,42 @@ class EventBLL(object):
|
||||
if not task_ids:
|
||||
return set()
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
|
||||
with translate_errors_context():
|
||||
query = Q(id__in=task_ids, company=company_id)
|
||||
if not allow_locked_tasks:
|
||||
query &= Q(status__nin=LOCKED_TASK_STATUSES)
|
||||
res = Task.objects(query).only("id")
|
||||
return {r.id for r in res}
|
||||
|
||||
@staticmethod
|
||||
def _get_valid_models(company_id, model_ids: Set, allow_locked_models=False) -> Set:
|
||||
"""Verify that task exists and can be updated"""
|
||||
if not model_ids:
|
||||
return set()
|
||||
|
||||
with translate_errors_context():
|
||||
query = Q(id__in=model_ids, company=company_id)
|
||||
if not allow_locked_models:
|
||||
query &= Q(ready__ne=True)
|
||||
res = Model.objects(query).only("id")
|
||||
return {r.id for r in res}
|
||||
|
||||
def add_events(
|
||||
self, company_id, events, worker, allow_locked_tasks=False
|
||||
self, company_id, events, worker, allow_locked=False
|
||||
) -> Tuple[int, int, dict]:
|
||||
model_events = events[0].get("model_event", False)
|
||||
for event in events:
|
||||
if event.get("model_event", model_events) != model_events:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Inconsistent model_event setting in the passed events"
|
||||
)
|
||||
if event.pop("allow_locked", allow_locked) != allow_locked:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Inconsistent allow_locked setting in the passed events"
|
||||
)
|
||||
|
||||
actions: List[dict] = []
|
||||
task_ids = set()
|
||||
task_or_model_ids = set()
|
||||
task_iteration = defaultdict(lambda: 0)
|
||||
task_last_scalar_events = nested_dict(
|
||||
3, dict
|
||||
@@ -118,13 +152,28 @@ class EventBLL(object):
|
||||
) # task_id -> metric_hash -> event_type -> MetricEvent
|
||||
errors_per_type = defaultdict(int)
|
||||
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
|
||||
valid_tasks = self._get_valid_tasks(
|
||||
company_id,
|
||||
task_ids={
|
||||
event["task"] for event in events if event.get("task") is not None
|
||||
},
|
||||
allow_locked_tasks=allow_locked_tasks,
|
||||
)
|
||||
if model_events:
|
||||
for event in events:
|
||||
model = event.pop("model", None)
|
||||
if model is not None:
|
||||
event["task"] = model
|
||||
valid_entities = self._get_valid_models(
|
||||
company_id,
|
||||
model_ids={
|
||||
event["task"] for event in events if event.get("task") is not None
|
||||
},
|
||||
allow_locked_models=allow_locked,
|
||||
)
|
||||
entity_name = "model"
|
||||
else:
|
||||
valid_entities = self._get_valid_tasks(
|
||||
company_id,
|
||||
task_ids={
|
||||
event["task"] for event in events if event.get("task") is not None
|
||||
},
|
||||
allow_locked_tasks=allow_locked,
|
||||
)
|
||||
entity_name = "task"
|
||||
|
||||
for event in events:
|
||||
# remove spaces from event type
|
||||
@@ -138,13 +187,17 @@ class EventBLL(object):
|
||||
errors_per_type[f"Invalid event type {event_type}"] += 1
|
||||
continue
|
||||
|
||||
task_id = event.get("task")
|
||||
if task_id is None:
|
||||
if model_events and event_type == EventType.task_log.value:
|
||||
errors_per_type[f"Task log events are not supported for models"] += 1
|
||||
continue
|
||||
|
||||
task_or_model_id = event.get("task")
|
||||
if task_or_model_id is None:
|
||||
errors_per_type["Event must have a 'task' field"] += 1
|
||||
continue
|
||||
|
||||
if task_id not in valid_tasks:
|
||||
errors_per_type["Invalid task id"] += 1
|
||||
if task_or_model_id not in valid_entities:
|
||||
errors_per_type[f"Invalid {entity_name} id {task_or_model_id}"] += 1
|
||||
continue
|
||||
|
||||
event["type"] = event_type
|
||||
@@ -179,6 +232,7 @@ class EventBLL(object):
|
||||
|
||||
event["metric"] = event.get("metric") or ""
|
||||
event["variant"] = event.get("variant") or ""
|
||||
event["model_event"] = model_events
|
||||
|
||||
index_name = get_index_name(company_id, event_type)
|
||||
es_action = {
|
||||
@@ -193,21 +247,26 @@ class EventBLL(object):
|
||||
else:
|
||||
es_action["_id"] = dbutils.id()
|
||||
|
||||
task_ids.add(task_id)
|
||||
task_or_model_ids.add(task_or_model_id)
|
||||
if (
|
||||
iter is not None
|
||||
and not model_events
|
||||
and event.get("metric") not in self._skip_iteration_for_metric
|
||||
):
|
||||
task_iteration[task_id] = max(iter, task_iteration[task_id])
|
||||
|
||||
self._update_last_metric_events_for_task(
|
||||
last_events=task_last_events[task_id], event=event,
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_id], event=event
|
||||
task_iteration[task_or_model_id] = max(
|
||||
iter, task_iteration[task_or_model_id]
|
||||
)
|
||||
|
||||
if not model_events:
|
||||
self._update_last_metric_events_for_task(
|
||||
last_events=task_last_events[task_or_model_id], event=event,
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_or_model_id],
|
||||
event=event,
|
||||
)
|
||||
|
||||
actions.append(es_action)
|
||||
|
||||
plot_actions = [
|
||||
@@ -228,39 +287,41 @@ class EventBLL(object):
|
||||
with translate_errors_context():
|
||||
if actions:
|
||||
chunk_size = 500
|
||||
with TimingContext("es", "events_add_batch"):
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
elasticsearch.helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
# thread_count=8,
|
||||
refresh=True,
|
||||
)
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += 1
|
||||
else:
|
||||
errors_per_type["Error when indexing events batch"] += 1
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
elasticsearch.helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
# thread_count=8,
|
||||
refresh=True,
|
||||
)
|
||||
) as it:
|
||||
for success, info in it:
|
||||
if success:
|
||||
added += 1
|
||||
else:
|
||||
errors_per_type["Error when indexing events batch"] += 1
|
||||
|
||||
if not model_events:
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in task_ids:
|
||||
for task_or_model_id in task_or_model_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update
|
||||
# all of them and not only those who's events were successful
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
task_id=task_or_model_id,
|
||||
now=now,
|
||||
iter_max=task_iteration.get(task_id),
|
||||
last_scalar_events=task_last_scalar_events.get(task_id),
|
||||
last_events=task_last_events.get(task_id),
|
||||
iter_max=task_iteration.get(task_or_model_id),
|
||||
last_scalar_events=task_last_scalar_events.get(
|
||||
task_or_model_id
|
||||
),
|
||||
last_events=task_last_events.get(task_or_model_id),
|
||||
)
|
||||
|
||||
if not updated:
|
||||
remaining_tasks.add(task_id)
|
||||
remaining_tasks.add(task_or_model_id)
|
||||
continue
|
||||
|
||||
if remaining_tasks:
|
||||
@@ -361,8 +422,22 @@ class EventBLL(object):
|
||||
for k in ("value", "metric", "variant", "iter", "timestamp")
|
||||
if k in event
|
||||
}
|
||||
event_data["min_value"] = min(value, last_event.get("min_value", value))
|
||||
event_data["max_value"] = max(value, last_event.get("max_value", value))
|
||||
last_event_min_value = last_event.get("min_value", value)
|
||||
last_event_min_value_iter = last_event.get("min_value_iter", event_iter)
|
||||
if value < last_event_min_value:
|
||||
event_data["min_value"] = value
|
||||
event_data["min_value_iter"] = event_iter
|
||||
else:
|
||||
event_data["min_value"] = last_event_min_value
|
||||
event_data["min_value_iter"] = last_event_min_value_iter
|
||||
last_event_max_value = last_event.get("max_value", value)
|
||||
last_event_max_value_iter = last_event.get("max_value_iter", event_iter)
|
||||
if value > last_event_max_value:
|
||||
event_data["max_value"] = value
|
||||
event_data["max_value_iter"] = event_iter
|
||||
else:
|
||||
event_data["max_value"] = last_event_max_value
|
||||
event_data["max_value_iter"] = last_event_max_value_iter
|
||||
last_events[metric_hash][variant_hash] = event_data
|
||||
|
||||
def _update_last_metric_events_for_task(self, last_events, event):
|
||||
@@ -396,36 +471,20 @@ class EventBLL(object):
|
||||
as the latest metric/variant scalar values reported (according to the report timestamp) and the task's last
|
||||
update time.
|
||||
"""
|
||||
fields = {}
|
||||
|
||||
if iter_max is not None:
|
||||
fields["last_iteration_max"] = iter_max
|
||||
|
||||
if last_scalar_events:
|
||||
fields["last_scalar_values"] = list(
|
||||
flatten_nested_items(
|
||||
last_scalar_events,
|
||||
nesting=2,
|
||||
include_leaves=[
|
||||
"value",
|
||||
"min_value",
|
||||
"max_value",
|
||||
"metric",
|
||||
"variant",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
if last_events:
|
||||
fields["last_events"] = last_events
|
||||
|
||||
if not fields:
|
||||
if iter_max is None and not last_events and not last_scalar_events:
|
||||
return False
|
||||
|
||||
return TaskBLL.update_statistics(task_id, company_id, last_update=now, **fields)
|
||||
return TaskBLL.update_statistics(
|
||||
task_id,
|
||||
company_id,
|
||||
last_update=now,
|
||||
last_iteration_max=iter_max,
|
||||
last_scalar_events=last_scalar_events,
|
||||
last_events=last_events,
|
||||
)
|
||||
|
||||
def _get_event_id(self, event):
|
||||
id_values = (str(event[field]) for field in self.id_fields if field in event)
|
||||
id_values = (str(event[field]) for field in self.event_id_fields if field in event)
|
||||
return hashlib.md5("-".join(id_values).encode()).hexdigest()
|
||||
|
||||
def scroll_task_events(
|
||||
@@ -441,7 +500,7 @@ class EventBLL(object):
|
||||
return [], scroll_id, 0
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "task_log_events"):
|
||||
with translate_errors_context():
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
size = min(batch_size, 10000)
|
||||
@@ -454,7 +513,7 @@ class EventBLL(object):
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
@@ -469,31 +528,46 @@ class EventBLL(object):
|
||||
|
||||
return events, next_scroll_id, total_events
|
||||
|
||||
def get_last_iterations_per_event_metric_variant(
|
||||
def get_task_plots(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
num_last_iterations: int,
|
||||
event_type: EventType,
|
||||
last_iterations_per_plot: int,
|
||||
metric_variants: MetricVariants = None,
|
||||
):
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return []
|
||||
event_type = EventType.metrics_plot
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return TaskEventsResult()
|
||||
|
||||
must = [{"term": {"task": task_id}}]
|
||||
plot_valid_condition = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{"term": {PlotFields.valid_plot: True}},
|
||||
{
|
||||
"bool": {
|
||||
"must_not": {"exists": {"field": PlotFields.valid_plot}}
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
must = [plot_valid_condition, {"term": {"task": task_id}}]
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
|
||||
query = {"bool": {"must": must}}
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=event_type
|
||||
es=self.es, company_id=company_id, event_type=event_type,
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
max_variants = int(max_variants // num_last_iterations)
|
||||
max_variants = int(max_variants // last_iterations_per_plot)
|
||||
|
||||
es_req: dict = {
|
||||
es_req = {
|
||||
"sort": [{"iter": {"order": "desc"}}],
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
@@ -509,11 +583,10 @@ class EventBLL(object):
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": num_last_iterations,
|
||||
"order": {"_key": "desc"},
|
||||
"events": {
|
||||
"top_hits": {
|
||||
"sort": {"iter": {"order": "desc"}},
|
||||
"size": last_iterations_per_plot
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -521,117 +594,28 @@ class EventBLL(object):
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "task_last_iter_metric_variant"
|
||||
):
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
with translate_errors_context():
|
||||
es_response = search_company_events(
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
**search_args,
|
||||
)
|
||||
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
return [
|
||||
(metric["key"], variant["key"], iter["key"])
|
||||
for metric in es_res["aggregations"]["metrics"]["buckets"]
|
||||
for variant in metric["variants"]["buckets"]
|
||||
for iter in variant["iters"]["buckets"]
|
||||
]
|
||||
|
||||
def get_task_plots(
|
||||
self,
|
||||
company_id: str,
|
||||
tasks: Sequence[str],
|
||||
last_iterations_per_plot: int = None,
|
||||
sort=None,
|
||||
size: int = 500,
|
||||
scroll_id: str = None,
|
||||
no_scroll: bool = False,
|
||||
metric_variants: MetricVariants = None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
aggs_result = es_response.get("aggregations")
|
||||
if not aggs_result:
|
||||
return TaskEventsResult()
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
event_type = EventType.metrics_plot
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return TaskEventsResult()
|
||||
|
||||
plot_valid_condition = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{"term": {PlotFields.valid_plot: True}},
|
||||
{
|
||||
"bool": {
|
||||
"must_not": {"exists": {"field": PlotFields.valid_plot}}
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
must = [plot_valid_condition]
|
||||
|
||||
if last_iterations_per_plot is None:
|
||||
must.append({"terms": {"task": tasks}})
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
else:
|
||||
should = []
|
||||
for i, task_id in enumerate(tasks):
|
||||
last_iters = self.get_last_iterations_per_event_metric_variant(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
num_last_iterations=last_iterations_per_plot,
|
||||
event_type=event_type,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
if not last_iters:
|
||||
continue
|
||||
|
||||
for metric, variant, iter in last_iters:
|
||||
should.append(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
{"term": {"iter": iter}},
|
||||
]
|
||||
}
|
||||
}
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {
|
||||
"sort": sort,
|
||||
"size": min(size, 10000),
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_plots"):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
**({} if no_scroll else {"scroll": "1h"}),
|
||||
)
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
events = [
|
||||
hit["_source"]
|
||||
for metrics_bucket in aggs_result["metrics"]["buckets"]
|
||||
for variants_bucket in metrics_bucket["variants"]["buckets"]
|
||||
for hit in variants_bucket["events"]["hits"]["hits"]
|
||||
]
|
||||
self.uncompress_plots(events)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
events=events, total_events=len(events)
|
||||
)
|
||||
|
||||
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
|
||||
@@ -643,7 +627,7 @@ class EventBLL(object):
|
||||
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
if next_scroll_id and not events:
|
||||
self.es.clear_scroll(scroll_id=next_scroll_id)
|
||||
self.clear_scroll(next_scroll_id)
|
||||
next_scroll_id = self.empty_scroll
|
||||
|
||||
return events, total_events, next_scroll_id
|
||||
@@ -721,11 +705,10 @@ class EventBLL(object):
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
company_id: Union[str, Sequence[str]],
|
||||
task_id: Union[str, Sequence[str]],
|
||||
event_type: EventType,
|
||||
metric=None,
|
||||
variant=None,
|
||||
metrics: MetricVariants = None,
|
||||
last_iter_count=None,
|
||||
sort=None,
|
||||
size=500,
|
||||
@@ -736,28 +719,38 @@ class EventBLL(object):
|
||||
return TaskEventsResult()
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
with translate_errors_context():
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
company_ids = [company_id] if isinstance(company_id, str) else company_id
|
||||
company_ids = [
|
||||
c_id
|
||||
for c_id in set(company_ids)
|
||||
if not check_empty_data(self.es, c_id, event_type)
|
||||
]
|
||||
if not company_ids:
|
||||
return TaskEventsResult()
|
||||
|
||||
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
||||
task_ids = (
|
||||
[task_id]
|
||||
if isinstance(task_id, str)
|
||||
else task_id
|
||||
)
|
||||
|
||||
|
||||
must = []
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
if metrics:
|
||||
must.append(get_metric_variants_condition(metrics))
|
||||
|
||||
if last_iter_count is None:
|
||||
must.append({"terms": {"task": task_ids}})
|
||||
else:
|
||||
tasks_iters = self.get_last_iters(
|
||||
company_id=company_id,
|
||||
company_id=company_ids,
|
||||
event_type=event_type,
|
||||
task_id=task_ids,
|
||||
iters=last_iter_count,
|
||||
metrics=metrics,
|
||||
)
|
||||
should = [
|
||||
{
|
||||
@@ -784,10 +777,10 @@ class EventBLL(object):
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
company_id=company_ids,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
@@ -809,9 +802,7 @@ class EventBLL(object):
|
||||
return {}
|
||||
|
||||
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=event_type
|
||||
)
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
@@ -838,9 +829,7 @@ class EventBLL(object):
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
metrics = {}
|
||||
@@ -867,9 +856,7 @@ class EventBLL(object):
|
||||
]
|
||||
}
|
||||
}
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=event_type
|
||||
)
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
@@ -915,9 +902,7 @@ class EventBLL(object):
|
||||
},
|
||||
"_source": {"excludes": []},
|
||||
}
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
metrics = []
|
||||
@@ -963,7 +948,7 @@ class EventBLL(object):
|
||||
"_source": ["iter", "value"],
|
||||
"sort": ["iter"],
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
@@ -978,15 +963,26 @@ class EventBLL(object):
|
||||
|
||||
def get_last_iters(
|
||||
self,
|
||||
company_id: str,
|
||||
company_id: Union[str, Sequence[str]],
|
||||
event_type: EventType,
|
||||
task_id: Union[str, Sequence[str]],
|
||||
iters: int,
|
||||
metrics: MetricVariants = None
|
||||
) -> Mapping[str, Sequence]:
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
company_ids = [company_id] if isinstance(company_id, str) else company_id
|
||||
company_ids = [
|
||||
c_id
|
||||
for c_id in set(company_ids)
|
||||
if not check_empty_data(self.es, c_id, event_type)
|
||||
]
|
||||
if not company_ids:
|
||||
return {}
|
||||
|
||||
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
||||
must = [{"terms": {"task": task_ids}}]
|
||||
if metrics:
|
||||
must.append(get_metric_variants_condition(metrics))
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
@@ -1003,12 +999,12 @@ class EventBLL(object):
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"terms": {"task": task_ids}}]}},
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_last_iter"):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
self.es, company_id=company_ids, event_type=event_type, body=es_req,
|
||||
)
|
||||
|
||||
if "aggregations" not in es_res:
|
||||
@@ -1019,6 +1015,21 @@ class EventBLL(object):
|
||||
for tb in es_res["aggregations"]["tasks"]["buckets"]
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _validate_model_state(
|
||||
company_id: str, model_id: str, allow_locked: bool = False
|
||||
):
|
||||
extra_msg = None
|
||||
query = Q(id=model_id, company=company_id)
|
||||
if not allow_locked:
|
||||
query &= Q(ready__ne=True)
|
||||
extra_msg = "or model published"
|
||||
res = Model.objects(query).only("id").first()
|
||||
if not res:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
extra_msg, company=company_id, id=model_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_task_state(company_id: str, task_id: str, allow_locked: bool = False):
|
||||
extra_msg = None
|
||||
@@ -1032,22 +1043,42 @@ class EventBLL(object):
|
||||
extra_msg, company=company_id, id=task_id
|
||||
)
|
||||
|
||||
def delete_task_events(self, company_id, task_id, allow_locked=False):
|
||||
self._validate_task_state(
|
||||
company_id=company_id, task_id=task_id, allow_locked=allow_locked
|
||||
)
|
||||
@staticmethod
|
||||
def _get_events_deletion_params(async_delete: bool) -> dict:
|
||||
if async_delete:
|
||||
return {
|
||||
"wait_for_completion": False,
|
||||
"requests_per_second": config.get(
|
||||
"services.events.max_async_deleted_events_per_sec", 1000
|
||||
),
|
||||
}
|
||||
|
||||
return {"refresh": True}
|
||||
|
||||
def delete_task_events(
|
||||
self, company_id, task_id, allow_locked=False, model=False, async_delete=False,
|
||||
):
|
||||
if model:
|
||||
self._validate_model_state(
|
||||
company_id=company_id, model_id=task_id, allow_locked=allow_locked,
|
||||
)
|
||||
else:
|
||||
self._validate_task_state(
|
||||
company_id=company_id, task_id=task_id, allow_locked=allow_locked
|
||||
)
|
||||
|
||||
es_req = {"query": {"term": {"task": task_id}}}
|
||||
with translate_errors_context(), TimingContext("es", "delete_task_events"):
|
||||
with translate_errors_context():
|
||||
es_res = delete_company_events(
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=EventType.all,
|
||||
body=es_req,
|
||||
refresh=True,
|
||||
**self._get_events_deletion_params(async_delete),
|
||||
)
|
||||
|
||||
return es_res.get("deleted", 0)
|
||||
if not async_delete:
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
def clear_task_log(
|
||||
self,
|
||||
@@ -1064,7 +1095,7 @@ class EventBLL(object):
|
||||
):
|
||||
return 0
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "clear_task_log"):
|
||||
with translate_errors_context():
|
||||
must = [{"term": {"task": task_id}}]
|
||||
sort = None
|
||||
if threshold_sec:
|
||||
@@ -1092,24 +1123,29 @@ class EventBLL(object):
|
||||
)
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
def delete_multi_task_events(self, company_id: str, task_ids: Sequence[str]):
|
||||
def delete_multi_task_events(
|
||||
self, company_id: str, task_ids: Sequence[str], async_delete=False
|
||||
):
|
||||
"""
|
||||
Delete mutliple task events. No check is done for tasks write access
|
||||
so it should be checked by the calling code
|
||||
"""
|
||||
es_req = {"query": {"terms": {"task": task_ids}}}
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "delete_multi_tasks_events"
|
||||
):
|
||||
es_res = delete_company_events(
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=EventType.all,
|
||||
body=es_req,
|
||||
refresh=True,
|
||||
)
|
||||
deleted = 0
|
||||
with translate_errors_context():
|
||||
for tasks in chunked_iter(task_ids, 100):
|
||||
es_req = {"query": {"terms": {"task": tasks}}}
|
||||
es_res = delete_company_events(
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=EventType.all,
|
||||
body=es_req,
|
||||
**self._get_events_deletion_params(async_delete),
|
||||
)
|
||||
if not async_delete:
|
||||
deleted += es_res.get("deleted", 0)
|
||||
|
||||
return es_res.get("deleted", 0)
|
||||
if not async_delete:
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
def clear_scroll(self, scroll_id: str):
|
||||
if scroll_id == self.empty_scroll:
|
||||
|
||||
@@ -8,7 +8,7 @@ from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
|
||||
|
||||
@@ -21,8 +21,9 @@ class EventType(Enum):
|
||||
all = "*"
|
||||
|
||||
|
||||
SINGLE_SCALAR_ITERATION = -2**31
|
||||
SINGLE_SCALAR_ITERATION = -(2 ** 31)
|
||||
MetricVariants = Mapping[str, Sequence[str]]
|
||||
TaskCompanies = Mapping[str, Sequence[Task]]
|
||||
|
||||
|
||||
class EventSettings:
|
||||
@@ -53,9 +54,12 @@ class EventSettings:
|
||||
return int(self._max_es_allowed_aggregation_buckets * percentage)
|
||||
|
||||
|
||||
def get_index_name(company_id: str, event_type: str):
|
||||
def get_index_name(company_id: Union[str, Sequence[str]], event_type: str):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
return f"events-{event_type}-{company_id.lower()}"
|
||||
if isinstance(company_id, str):
|
||||
company_id = [company_id]
|
||||
|
||||
return ",".join(f"events-{event_type}-{(c_id or '').lower()}" for c_id in company_id)
|
||||
|
||||
|
||||
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
|
||||
@@ -80,9 +84,7 @@ def delete_company_events(
|
||||
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.delete_by_query(
|
||||
index=es_index, body=body, conflicts="proceed", **kwargs
|
||||
)
|
||||
return es.delete_by_query(index=es_index, body=body, conflicts="proceed", **kwargs)
|
||||
|
||||
|
||||
def count_company_events(
|
||||
@@ -116,9 +118,7 @@ def get_max_metric_and_variant_counts(
|
||||
"query": query,
|
||||
"aggs": {"metrics_count": {"cardinality": {"field": "metric"}}},
|
||||
}
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_max_metric_and_variant_counts"
|
||||
):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
|
||||
)
|
||||
|
||||
@@ -8,9 +8,7 @@ from typing import Sequence, Tuple, Mapping
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventType,
|
||||
EventSettings,
|
||||
@@ -20,12 +18,11 @@ from apiserver.bll.event.event_common import (
|
||||
get_metric_variants_condition,
|
||||
get_max_metric_and_variant_counts,
|
||||
SINGLE_SCALAR_ITERATION,
|
||||
TaskCompanies,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
@@ -111,57 +108,51 @@ class EventMetrics:
|
||||
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self,
|
||||
company_id,
|
||||
task_ids: Sequence[str],
|
||||
companies: TaskCompanies,
|
||||
samples,
|
||||
key: ScalarKeyEnum,
|
||||
allow_public=True,
|
||||
metric_variants: MetricVariants = None,
|
||||
):
|
||||
"""
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id__in=task_ids),
|
||||
allow_public=allow_public,
|
||||
override_projection=("id", "name", "company", "company_origin"),
|
||||
return_dicts=False,
|
||||
)
|
||||
if len(task_objs) < len(task_ids):
|
||||
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
|
||||
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
|
||||
task_name_by_id = {t.id: t.name for t in task_objs}
|
||||
|
||||
companies = {t.get_index_company() for t in task_objs}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
event_type = EventType.metrics_scalar
|
||||
company_id = next(iter(companies))
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
companies = {
|
||||
company_id: tasks
|
||||
for company_id, tasks in companies.items()
|
||||
if not check_empty_data(
|
||||
self.es, company_id=company_id, event_type=event_type
|
||||
)
|
||||
}
|
||||
if not companies:
|
||||
return {}
|
||||
|
||||
get_scalar_average_per_iter = partial(
|
||||
self._get_scalar_average_per_iter_core,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
metric_variants=metric_variants,
|
||||
run_parallel=False,
|
||||
)
|
||||
task_ids, company_ids = zip(
|
||||
*(
|
||||
(t.id, t.company)
|
||||
for t in itertools.chain.from_iterable(companies.values())
|
||||
)
|
||||
)
|
||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||
task_metrics = zip(
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids, company_ids)
|
||||
)
|
||||
|
||||
task_names = {
|
||||
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
|
||||
}
|
||||
res = defaultdict(lambda: defaultdict(dict))
|
||||
for task_id, task_data in task_metrics:
|
||||
task_name = task_name_by_id[task_id]
|
||||
task_name = task_names[task_id]
|
||||
for metric_key, metric_data in task_data.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
variant_data["name"] = task_name
|
||||
@@ -170,18 +161,27 @@ class EventMetrics:
|
||||
return res
|
||||
|
||||
def get_task_single_value_metrics(
|
||||
self, company_id: str, task_ids: Sequence[str]
|
||||
self, companies: TaskCompanies
|
||||
) -> Mapping[str, dict]:
|
||||
"""
|
||||
For the requested tasks return all the events delivered for the single iteration (-2**31)
|
||||
"""
|
||||
if check_empty_data(
|
||||
self.es, company_id=company_id, event_type=EventType.metrics_scalar
|
||||
):
|
||||
companies = {
|
||||
company_id: [t.id for t in tasks]
|
||||
for company_id, tasks in companies.items()
|
||||
if not check_empty_data(
|
||||
self.es, company_id=company_id, event_type=EventType.metrics_scalar
|
||||
)
|
||||
}
|
||||
if not companies:
|
||||
return {}
|
||||
|
||||
with TimingContext("es", "get_task_single_value_metrics"):
|
||||
task_events = self._get_task_single_value_metrics(company_id, task_ids)
|
||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||
task_events = list(
|
||||
itertools.chain.from_iterable(
|
||||
pool.map(self._get_task_single_value_metrics, companies.items())
|
||||
),
|
||||
)
|
||||
|
||||
def _get_value(event: dict):
|
||||
return {
|
||||
@@ -195,8 +195,9 @@ class EventMetrics:
|
||||
}
|
||||
|
||||
def _get_task_single_value_metrics(
|
||||
self, company_id: str, task_ids: Sequence[str]
|
||||
self, tasks: Tuple[str, Sequence[str]]
|
||||
) -> Sequence[dict]:
|
||||
company_id, task_ids = tasks
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
@@ -277,9 +278,7 @@ class EventMetrics:
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
query = {"bool": {"must": must}}
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=event_type
|
||||
)
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
@@ -312,8 +311,7 @@ class EventMetrics:
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
@@ -366,9 +364,7 @@ class EventMetrics:
|
||||
interval, metrics = metrics_interval
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=event_type
|
||||
)
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
@@ -493,10 +489,9 @@ class EventMetrics:
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
|
||||
return [
|
||||
metric["key"]
|
||||
|
||||
@@ -17,7 +17,6 @@ from apiserver.bll.event.event_common import (
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
@@ -76,7 +75,7 @@ class EventsIterator:
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "count_task_events"):
|
||||
with translate_errors_context():
|
||||
es_result = count_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
@@ -113,7 +112,7 @@ class EventsIterator:
|
||||
if from_key_value:
|
||||
es_req["search_after"] = [from_key_value]
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
with translate_errors_context():
|
||||
es_result = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
|
||||
@@ -1,56 +1,455 @@
|
||||
from typing import Sequence, Tuple, Callable
|
||||
import operator
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first, bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, IntField, BoolField, ListField
|
||||
from jsonmodels.models import Base
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .event_common import EventType
|
||||
from .history_sample_iterator import HistorySampleIterator, VariantState
|
||||
from .event_common import (
|
||||
EventType,
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.apierrors import errors
|
||||
|
||||
|
||||
class HistoryDebugImageIterator(HistorySampleIterator):
|
||||
class VariantState(Base):
|
||||
name: str = StringField(required=True)
|
||||
metric: str = StringField(default=None)
|
||||
min_iteration: int = IntField()
|
||||
max_iteration: int = IntField()
|
||||
|
||||
|
||||
class DebugImageSampleState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
iteration: int = IntField()
|
||||
variant: str = StringField()
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
variant_states: Sequence[VariantState] = ListField([VariantState])
|
||||
warning: str = StringField()
|
||||
navigate_current_metric = BoolField(default=True)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class VariantSampleResult(object):
|
||||
scroll_id: str = None
|
||||
event: dict = None
|
||||
min_iteration: int = None
|
||||
max_iteration: int = None
|
||||
|
||||
|
||||
class HistoryDebugImageIterator:
|
||||
event_type = EventType.metrics_image
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
super().__init__(redis, es, EventType.metrics_image)
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugImageSampleState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
return [{"exists": {"field": "url"}}]
|
||||
def get_next_sample(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
state_id: str,
|
||||
navigate_earlier: bool,
|
||||
next_iteration: bool,
|
||||
) -> VariantSampleResult:
|
||||
"""
|
||||
Get the sample for next/prev variant on the current iteration
|
||||
If does not exist then try getting sample for the first/last variant from next/prev iteration
|
||||
"""
|
||||
res = VariantSampleResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
|
||||
variants_conditions = [
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
if next_iteration:
|
||||
event = self._get_next_for_another_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
)
|
||||
else:
|
||||
# noinspection PyArgumentList
|
||||
event = first(
|
||||
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
|
||||
for f in (
|
||||
self._get_next_for_current_iteration,
|
||||
self._get_next_for_another_iteration,
|
||||
)
|
||||
)
|
||||
if not event:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(event=event, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def _fill_res_and_update_state(
|
||||
event: dict, res: VariantSampleResult, state: DebugImageSampleState
|
||||
):
|
||||
state.variant = event["variant"]
|
||||
state.metric = event["metric"]
|
||||
state.iteration = event["iter"]
|
||||
res.event = event
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == state.variant and vs.metric == state.metric
|
||||
)
|
||||
if var_state:
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
@staticmethod
|
||||
def _get_metric_conditions(variants: Sequence[VariantState]) -> dict:
|
||||
metrics = bucketize(variants, key=attrgetter("metric"))
|
||||
|
||||
def _get_variants_conditions(metric_variants: Sequence[VariantState]) -> dict:
|
||||
variants_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"variant": v.name}},
|
||||
{"range": {"iter": {"gte": v.min_iteration}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for v in metric_variants
|
||||
]
|
||||
return {"bool": {"should": variants_conditions}}
|
||||
|
||||
metrics_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"variant": v.name}},
|
||||
{"range": {"iter": {"gte": v.min_iteration}}},
|
||||
{"term": {"metric": metric}},
|
||||
_get_variants_conditions(metric_variants),
|
||||
]
|
||||
}
|
||||
}
|
||||
for v in variants
|
||||
for metric, metric_variants in metrics.items()
|
||||
]
|
||||
return {"bool": {"should": variants_conditions}}
|
||||
return {"bool": {"should": metrics_conditions}}
|
||||
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
return event
|
||||
def _get_next_for_current_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
|
||||
Only variants for which the iteration falls into their valid range are considered
|
||||
Return None if no such variant or sample is found
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
|
||||
# The min iteration is the lowest iteration that contains non-recycled image url
|
||||
aggs = {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"urls": {
|
||||
# group by urls and choose the minimal iteration
|
||||
# from all the maximal iterations per url
|
||||
"terms": {"field": "url", "order": {"max_iter": "asc"}, "size": 1},
|
||||
"aggs": {
|
||||
# find max iteration for each url
|
||||
"max_iter": {"max": {"field": "iter"}}
|
||||
},
|
||||
cmp = operator.lt if navigate_earlier else operator.gt
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
|
||||
and var_state.min_iteration <= state.iteration
|
||||
]
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"iter": state.iteration}},
|
||||
self._get_metric_conditions(variants),
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
order = "desc" if navigate_earlier else "asc"
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def _get_next_for_another_iteration(
    self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
) -> Optional[dict]:
    """
    Fetch the closest sample that belongs to a different iteration than
    the one stored in the state.

    With navigate_earlier=False the search covers iterations greater than
    state.iteration, sorted ascending by (iter, metric, variant).
    With navigate_earlier=True it covers smaller iterations, sorted descending,
    and only variants whose min_iteration is below state.iteration are used.
    When state.navigate_current_metric is set only the variants of
    state.metric take part in the search.

    :return: the "_source" of the first hit or None if no variant qualifies
        or the search returned no hits
    """
    candidates = state.variant_states
    if state.navigate_current_metric:
        candidates = [vs for vs in candidates if vs.metric == state.metric]

    if navigate_earlier:
        range_operator, order = "lt", "desc"
        # moving backwards: a variant that starts at or after the current
        # iteration cannot have an earlier valid sample
        candidates = [vs for vs in candidates if vs.min_iteration < state.iteration]
    else:
        range_operator, order = "gt", "asc"

    if not candidates:
        return None

    query = {
        "bool": {
            "must": [
                {"term": {"task": state.task}},
                self._get_metric_conditions(candidates),
                {"range": {"iter": {range_operator: state.iteration}}},
                {"exists": {"field": "url"}},
            ]
        }
    }
    sort_keys = [{field: order} for field in ("iter", "metric", "variant")]
    es_res = search_company_events(
        self.es,
        company_id=company_id,
        event_type=self.event_type,
        body={"size": 1, "sort": sort_keys, "query": query},
    )

    hits = nested_get(es_res, ("hits", "hits"))
    return hits[0]["_source"] if hits else None
|
||||
|
||||
def get_sample_for_variant(
    self,
    company_id: str,
    task: str,
    metric: str,
    variant: str,
    iteration: Optional[int] = None,
    refresh: bool = False,
    state_id: str = None,
    navigate_current_metric: bool = True,
) -> VariantSampleResult:
    """
    Look up the sample of the task/metric/variant at the requested iteration
    or the latest one before it. If no iteration is passed the latest
    sample is looked up.

    The navigation state is obtained from the cache manager by state_id
    (initialized through init_state, checked through validate_state) and its
    id is returned as the result scroll_id. The search never goes below
    the min_iteration stored in the state for the variant.

    :raise errors.bad_request.InvalidScrollId: if the cached state does not
        match the passed task, metric or navigate_current_metric
    :return: the result object; it is returned without an event when there
        is no data, no state for the variant or no matching hit
    """
    result = VariantSampleResult()
    if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
        return result

    def init_state(state_: DebugImageSampleState):
        # a fresh state: remember what is navigated and load variant ranges
        state_.task = task
        state_.metric = metric
        state_.navigate_current_metric = navigate_current_metric
        self._reset_variant_states(company_id=company_id, state=state_)

    def validate_state(state_: DebugImageSampleState):
        same_navigation = state_.navigate_current_metric == navigate_current_metric
        metric_conflict = state_.navigate_current_metric and state_.metric != metric
        if state_.task != task or not same_navigation or metric_conflict:
            raise errors.bad_request.InvalidScrollId(
                "Task and metric stored in the state do not match the passed ones",
                scroll_id=state_.id,
            )
        # fix old variant states: entries stored without a metric
        # get the requested metric assigned
        for stored in state_.variant_states:
            if stored.metric is None:
                stored.metric = metric
        if refresh:
            self._reset_variant_states(company_id=company_id, state=state_)

    state: DebugImageSampleState
    with self.cache_manager.get_or_create_state(
        state_id=state_id, init_state=init_state, validate_state=validate_state,
    ) as state:
        result.scroll_id = state.id

        var_state = first(
            candidate
            for candidate in state.variant_states
            if candidate.name == variant and candidate.metric == metric
        )
        if not var_state:
            return result

        result.min_iteration = var_state.min_iteration
        result.max_iteration = var_state.max_iteration

        # upper bound only when an iteration was requested,
        # lower bound is always the variant min iteration
        iter_range = {"lte": iteration} if iteration is not None else {}
        iter_range["gte"] = var_state.min_iteration

        query = {
            "bool": {
                "must": [
                    {"term": {"task": task}},
                    {"term": {"metric": metric}},
                    {"term": {"variant": variant}},
                    {"exists": {"field": "url"}},
                    {"range": {"iter": iter_range}},
                ]
            }
        }
        es_res = search_company_events(
            self.es,
            company_id=company_id,
            event_type=self.event_type,
            body={"size": 1, "sort": {"iter": "desc"}, "query": query},
        )

        hits = nested_get(es_res, ("hits", "hits"))
        if hits:
            self._fill_res_and_update_state(
                event=hits[0]["_source"], res=result, state=state
            )
        return result
|
||||
|
||||
def _reset_variant_states(self, company_id: str, state: DebugImageSampleState):
    """
    Rebuild state.variant_states from the iteration ranges returned by
    _get_metric_variant_iterations for the state task.

    The ranges are requested for state.metric only when
    state.navigate_current_metric is set, otherwise for all the metrics
    (metric=None is passed).
    """
    metric_filter = state.metric if state.navigate_current_metric else None
    iterations_by_metric = self._get_metric_variant_iterations(
        company_id=company_id, task=state.task, metric=metric_filter,
    )

    rebuilt = []
    for metric_name, variant_ranges in iterations_by_metric.items():
        for var_name, min_iter, max_iter in variant_ranges:
            rebuilt.append(
                VariantState(
                    metric=metric_name,
                    name=var_name,
                    min_iteration=min_iter,
                    max_iteration=max_iter,
                )
            )
    state.variant_states = rebuilt
|
||||
|
||||
def _get_metric_variant_iterations(
|
||||
self, company_id: str, task: str, metric: str,
|
||||
) -> Mapping[str, Sequence[Tuple[str, int, int]]]:
|
||||
"""
|
||||
Return valid min and max iterations that the task reported events of the required type
|
||||
"""
|
||||
must = [
|
||||
{"term": {"task": task}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if metric is not None:
|
||||
must.append({"term": {"metric": metric}})
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=self.event_type,
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"urls": {
|
||||
# group by urls and choose the minimal iteration
|
||||
# from all the maximal iterations per url
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "asc"},
|
||||
"size": 1,
|
||||
},
|
||||
"aggs": {
|
||||
# find max iteration for each url
|
||||
"max_iter": {"max": {"field": "iter"}}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def get_min_max_data(variant_bucket: dict) -> Tuple[int, int]:
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
|
||||
variant = variant_bucket["key"]
|
||||
urls = nested_get(variant_bucket, ("urls", "buckets"))
|
||||
min_iter = int(urls[0]["max_iter"]["value"])
|
||||
max_iter = int(variant_bucket["last_iter"]["value"])
|
||||
return min_iter, max_iter
|
||||
return variant, min_iter, max_iter
|
||||
|
||||
return aggs, get_min_max_data
|
||||
return {
|
||||
metric_bucket["key"]: [
|
||||
get_variant_data(variant_bucket)
|
||||
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
|
||||
]
|
||||
for metric_bucket in nested_get(
|
||||
es_res, ("aggregations", "metrics", "buckets")
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
from typing import Sequence, Tuple, Callable
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from .event_common import EventType, uncompress_plot
|
||||
from .history_sample_iterator import HistorySampleIterator, VariantState
|
||||
|
||||
|
||||
class HistoryPlotIterator(HistorySampleIterator):
    """
    HistorySampleIterator bound to the EventType.metrics_plot event type
    """

    def __init__(self, redis: StrictRedis, es: Elasticsearch):
        super().__init__(redis, es, EventType.metrics_plot)

    def _get_extra_conditions(self) -> Sequence[dict]:
        """Plots add no extra conditions to the queries"""
        return []

    def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
        """Build a terms condition matching the names of the passed variants"""
        names = [variant_state.name for variant_state in variants]
        return {"terms": {"variant": names}}

    def _process_event(self, event: dict) -> dict:
        """Pass the event through uncompress_plot and return the same event"""
        uncompress_plot(event)
        return event

    def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
        """
        Return the per-variant aggregations together with the function
        that reads (min_iter, max_iter) from a variant bucket.
        The range is simply the min and the max of the "iter" field
        """
        iter_field = {"field": "iter"}
        range_aggs = {
            "last_iter": {"max": dict(iter_field)},
            "first_iter": {"min": dict(iter_field)},
        }

        def read_range(variant_bucket: dict) -> Tuple[int, int]:
            first_iter = int(variant_bucket["first_iter"]["value"])
            last_iter = int(variant_bucket["last_iter"]["value"])
            return first_iter, last_iter

        return range_aggs, read_range
|
||||
316
apiserver/bll/event/history_plots_iterator.py
Normal file
316
apiserver/bll/event/history_plots_iterator.py
Normal file
@@ -0,0 +1,316 @@
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, IntField, ListField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from .event_common import (
|
||||
EventType,
|
||||
uncompress_plot,
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
)
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.apierrors import errors
|
||||
|
||||
|
||||
class MetricState(Base):
    """
    Iteration range of a single metric, kept as an item
    of PlotsSampleState.metric_states
    """

    # the metric name, None unless assigned
    name: str = StringField(default=None)
    # bounds of the iterations found for the metric
    min_iteration: int = IntField()
    max_iteration: int = IntField()
|
||||
|
||||
|
||||
class PlotsSampleState(Base, JsonSerializableMixin):
    """
    Navigation state of the plots history, stored through RedisCacheManager
    by HistoryPlotsIterator
    """

    # the state id, returned to the caller as scroll_id
    id: str = StringField(required=True)
    # iteration and metric of the events that were returned last
    iteration: int = IntField()
    task: str = StringField()
    metric: str = StringField()
    # iteration ranges per metric
    metric_states: Sequence[MetricState] = ListField([MetricState])
    warning: str = StringField()
    # when set the navigation stays inside the state metric
    navigate_current_metric: bool = BoolField(default=True)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
class MetricSamplesResult:
    """
    Result of a plots history request: the scroll id of the navigation state,
    the found events and the iteration range of the metric they belong to
    """

    scroll_id: str = None
    # a factory is used so that every result gets its own list;
    # a literal [] default would be one list object shared by all the instances
    events: list = attr.Factory(list)
    min_iteration: int = None
    max_iteration: int = None
|
||||
|
||||
|
||||
class HistoryPlotsIterator:
    """
    Navigation over the plot events (EventType.metrics_plot) of a task.
    Every request returns all the plot events of one metric at one iteration;
    the current position is kept in a PlotsSampleState stored through
    RedisCacheManager
    """

    event_type = EventType.metrics_plot

    def __init__(self, redis: StrictRedis, es: Elasticsearch):
        self.es = es
        self.cache_manager = RedisCacheManager(
            state_class=PlotsSampleState,
            redis=redis,
            expiration_interval=EventSettings.state_expiration_sec,
        )

    def get_next_sample(
        self,
        company_id: str,
        task: str,
        state_id: str,
        navigate_earlier: bool,
        next_iteration: bool,
    ) -> MetricSamplesResult:
        """
        Get the samples for next/prev metric on the current iteration
        If does not exist then try getting sample for the first/last metric from next/prev iteration

        When next_iteration is set, or the state navigates inside the current
        metric only, the search goes directly to the next/prev iteration.

        :raise errors.bad_request.InvalidScrollId: if there is no state
            for state_id or the state belongs to another task
        :return: the result with the found events; when nothing is found
            the result carries only the scroll id and the state is not updated
        """
        res = MetricSamplesResult(scroll_id=state_id)
        state = self.cache_manager.get_state(state_id)
        if not state or state.task != task:
            raise errors.bad_request.InvalidScrollId(scroll_id=state_id)

        if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
            return res

        range_operator, order = ("lt", "desc") if navigate_earlier else ("gt", "asc")

        must_conditions = [{"term": {"task": state.task}}]
        if state.navigate_current_metric:
            must_conditions.append({"term": {"metric": state.metric}})

        other_iteration = {"range": {"iter": {range_operator: state.iteration}}}
        if not (next_iteration or state.navigate_current_metric):
            # either another metric on the same iteration or any metric
            # on another iteration
            same_iteration_other_metric = {
                "bool": {
                    "must": [
                        {"term": {"iter": state.iteration}},
                        {"range": {"metric": {range_operator: state.metric}}},
                    ]
                }
            }
            must_conditions.append(
                {"bool": {"should": [same_iteration_other_metric, other_iteration]}}
            )
        else:
            must_conditions.append(other_iteration)

        events = self._get_metric_events_for_condition(
            company_id=company_id,
            task=state.task,
            order=order,
            must_conditions=must_conditions,
        )
        if not events:
            return res

        self._fill_res_and_update_state(events=events, res=res, state=state)
        self.cache_manager.set_state(state=state)
        return res

    def get_samples_for_metric(
        self,
        company_id: str,
        task: str,
        metric: str,
        iteration: Optional[int] = None,
        refresh: bool = False,
        state_id: str = None,
        navigate_current_metric: bool = True,
    ) -> MetricSamplesResult:
        """
        Get the plot events of the metric for the requested iteration
        or for the latest iteration before it.
        If the iteration is not passed then the latest events are returned

        :raise errors.bad_request.InvalidScrollId: if the cached state does
            not match the passed task, metric or navigate_current_metric
        """
        res = MetricSamplesResult()
        if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
            return res

        def init_state(state_: PlotsSampleState):
            state_.task = task
            state_.metric = metric
            state_.navigate_current_metric = navigate_current_metric
            self._reset_metric_states(company_id=company_id, state=state_)

        def validate_state(state_: PlotsSampleState):
            same_navigation = (
                state_.navigate_current_metric == navigate_current_metric
            )
            metric_conflict = state_.navigate_current_metric and state_.metric != metric
            if state_.task != task or not same_navigation or metric_conflict:
                raise errors.bad_request.InvalidScrollId(
                    "Task and metric stored in the state do not match the passed ones",
                    scroll_id=state_.id,
                )
            if refresh:
                self._reset_metric_states(company_id=company_id, state=state_)

        state: PlotsSampleState
        with self.cache_manager.get_or_create_state(
            state_id=state_id, init_state=init_state, validate_state=validate_state,
        ) as state:
            res.scroll_id = state.id

            metric_state = first(
                candidate for candidate in state.metric_states if candidate.name == metric
            )
            if not metric_state:
                return res

            res.min_iteration = metric_state.min_iteration
            res.max_iteration = metric_state.max_iteration

            must_conditions = [
                {"term": {"task": task}},
                {"term": {"metric": metric}},
            ]
            if iteration is not None:
                must_conditions.append({"range": {"iter": {"lte": iteration}}})

            events = self._get_metric_events_for_condition(
                company_id=company_id,
                task=state.task,
                order="desc",
                must_conditions=must_conditions,
            )
            if events:
                self._fill_res_and_update_state(events=events, res=res, state=state)
            return res

    def _reset_metric_states(self, company_id: str, state: PlotsSampleState):
        """
        Rebuild state.metric_states from the ranges returned by
        _get_metric_iterations; only state.metric is queried when
        state.navigate_current_metric is set
        """
        metric_filter = state.metric if state.navigate_current_metric else None
        ranges = self._get_metric_iterations(
            company_id=company_id, task=state.task, metric=metric_filter,
        )
        rebuilt = []
        for metric_name, (min_iter, max_iter) in ranges.items():
            rebuilt.append(
                MetricState(
                    name=metric_name, min_iteration=min_iter, max_iteration=max_iter
                )
            )
        state.metric_states = rebuilt

    def _get_metric_iterations(
        self, company_id: str, task: str, metric: Optional[str],
    ) -> Mapping[str, Tuple[int, int]]:
        """
        Return valid min and max iterations that the task reported events of the required type

        :param metric: limit the search to this metric, None for all the metrics
        :return: mapping from the metric name to (min iteration, max iteration)
        """
        must = [{"term": {"task": task}}]
        if metric is not None:
            must.append({"term": {"metric": metric}})

        range_aggs = {
            "last_iter": {"max": {"field": "iter"}},
            "first_iter": {"min": {"field": "iter"}},
        }
        metrics_terms = {"field": "metric", "size": 5000, "order": {"_key": "asc"}}
        es_req: dict = {
            "size": 0,
            "query": {"bool": {"must": must}},
            "aggs": {"metrics": {"terms": metrics_terms, "aggs": range_aggs}},
        }

        es_res = search_company_events(
            body=es_req,
            es=self.es,
            company_id=company_id,
            event_type=self.event_type,
        )

        ranges = {}
        for bucket in nested_get(es_res, ("aggregations", "metrics", "buckets")):
            ranges[bucket["key"]] = (
                int(bucket["first_iter"]["value"]),
                int(bucket["last_iter"]["value"]),
            )
        return ranges

    @staticmethod
    def _fill_res_and_update_state(
        events: Sequence[dict], res: MetricSamplesResult, state: PlotsSampleState
    ):
        """
        Pass every event through uncompress_plot, store the events in the result
        and move the state to the metric and iteration of the first event.
        The events sequence is expected to be non-empty (the first item is read)
        """
        for plot_event in events:
            uncompress_plot(plot_event)

        head = events[0]
        state.metric = head["metric"]
        state.iteration = head["iter"]
        res.events = events

        current = first(
            candidate
            for candidate in state.metric_states
            if candidate.name == state.metric
        )
        if current:
            res.min_iteration = current.min_iteration
            res.max_iteration = current.max_iteration

    def _get_metric_events_for_condition(
        self, company_id: str, task: str, order: str, must_conditions: Sequence
    ) -> Sequence:
        """
        Run the query built from must_conditions and return the events
        (up to 100, sorted by variant) of the first metric of the first iteration,
        where both "first" are taken according to the passed order.
        The task parameter is not used by this method

        :return: list of the events "_source", empty if nothing was found
        """
        events_agg = {
            "events": {
                "top_hits": {"sort": {"variant": {"order": "asc"}}, "size": 100}
            }
        }
        metrics_agg = {
            "metrics": {
                "terms": {"field": "metric", "size": 1, "order": {"_key": order}},
                "aggs": events_agg,
            }
        }
        iters_agg = {
            "iters": {
                "terms": {"field": "iter", "size": 1, "order": {"_key": order}},
                "aggs": metrics_agg,
            }
        }
        es_req = {
            "size": 0,
            "query": {"bool": {"must": must_conditions}},
            "aggs": iters_agg,
        }
        es_res = search_company_events(
            self.es,
            company_id=company_id,
            event_type=self.event_type,
            body=es_req,
        )

        aggregations = es_res.get("aggregations")
        if not aggregations:
            return []

        # descend to the single iteration bucket and then
        # to its single metric bucket
        iter_buckets = aggregations["iters"]["buckets"]
        if not iter_buckets:
            return []
        metric_buckets = iter_buckets[0]["metrics"]["buckets"]
        if not metric_buckets:
            return []

        top_hits = nested_get(metric_buckets[0], ("events", "hits", "hits"))
        return [hit["_source"] for hit in top_hits]
|
||||
@@ -1,442 +0,0 @@
|
||||
import abc
|
||||
import operator
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Tuple, Optional, Callable, Mapping
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first, bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
EventType,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
|
||||
class VariantState(Base):
    """
    Iteration range of a single metric variant, kept as an item
    of HistorySampleState.variant_states
    """

    name: str = StringField(required=True)
    # may be None in the states that were stored without a metric;
    # get_sample_for_variant assigns the requested metric to such entries
    metric: str = StringField(default=None)
    # bounds of the iterations for the variant
    min_iteration: int = IntField()
    max_iteration: int = IntField()
|
||||
|
||||
|
||||
class HistorySampleState(Base, JsonSerializableMixin):
    """
    Navigation state of the history samples, stored through RedisCacheManager
    by HistorySampleIterator
    """

    # the state id, returned to the caller as scroll_id
    id: str = StringField(required=True)
    # position of the event that was returned last
    iteration: int = IntField()
    variant: str = StringField()
    task: str = StringField()
    metric: str = StringField()
    reached_first: bool = BoolField()
    reached_last: bool = BoolField()
    # iteration ranges per metric variant
    variant_states: Sequence[VariantState] = ListField([VariantState])
    warning: str = StringField()
    # when set the navigation stays inside the state metric
    navigate_current_metric: bool = BoolField(default=True)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
class HistorySampleResult:
    """
    Result of a history sample request: the scroll id of the navigation state,
    the found event and the iteration range of its variant
    """

    scroll_id: str = None
    event: dict = None
    min_iteration: int = None
    max_iteration: int = None
|
||||
|
||||
|
||||
class HistorySampleIterator(abc.ABC):
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
|
||||
self.es = es
|
||||
self.event_type = event_type
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=HistorySampleState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_next_sample(
|
||||
self, company_id: str, task: str, state_id: str, navigate_earlier: bool
|
||||
) -> HistorySampleResult:
|
||||
"""
|
||||
Get the sample for next/prev variant on the current iteration
|
||||
If does not exist then try getting sample for the first/last variant from next/prev iteration
|
||||
"""
|
||||
res = HistorySampleResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
event = self._get_next_for_current_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
) or self._get_next_for_another_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
)
|
||||
if not event:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(event=event, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
def _fill_res_and_update_state(
|
||||
self, event: dict, res: HistorySampleResult, state: HistorySampleState
|
||||
):
|
||||
self._process_event(event)
|
||||
state.variant = event["variant"]
|
||||
state.metric = event["metric"]
|
||||
state.iteration = event["iter"]
|
||||
res.event = event
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == state.variant and vs.metric == state.metric
|
||||
)
|
||||
if var_state:
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
|
||||
pass
|
||||
|
||||
def _get_metric_variants_condition(self, variants: Sequence[VariantState]) -> dict:
|
||||
metrics = bucketize(variants, key=attrgetter("metric"))
|
||||
metrics_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
self._get_variants_conditions(vs),
|
||||
]
|
||||
}
|
||||
}
|
||||
for metric, vs in metrics.items()
|
||||
]
|
||||
return {"bool": {"should": metrics_conditions}}
|
||||
|
||||
def _get_next_for_current_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: HistorySampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
|
||||
Only variants for which the iteration falls into their valid range are considered
|
||||
Return None if no such variant or sample is found
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
cmp = operator.lt if navigate_earlier else operator.gt
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
|
||||
and var_state.min_iteration <= state.iteration
|
||||
]
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"iter": state.iteration}},
|
||||
self._get_metric_variants_condition(variants),
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
order = "desc" if navigate_earlier else "asc"
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_next_for_current_iteration"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def _get_next_for_another_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: HistorySampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
|
||||
or from the last variant for the previous iteration (otherwise)
|
||||
The variants for which the sample falls in invalid range are discarded
|
||||
If no suitable sample is found then None is returned
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
if navigate_earlier:
|
||||
range_operator = "lt"
|
||||
order = "desc"
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if var_state.min_iteration < state.iteration
|
||||
]
|
||||
else:
|
||||
range_operator = "gt"
|
||||
order = "asc"
|
||||
variants = variants
|
||||
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
self._get_metric_variants_condition(variants),
|
||||
{"range": {"iter": {range_operator: state.iteration}}},
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_next_for_another_iteration"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def get_sample_for_variant(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
variant: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
navigate_current_metric: bool = True,
|
||||
) -> HistorySampleResult:
|
||||
"""
|
||||
Get the sample for the requested iteration or the latest before it
|
||||
If the iteration is not passed then get the latest event
|
||||
"""
|
||||
res = HistorySampleResult()
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
def init_state(state_: HistorySampleState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
state_.navigate_current_metric = navigate_current_metric
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
def validate_state(state_: HistorySampleState):
|
||||
if (
|
||||
state_.task != task
|
||||
or state_.navigate_current_metric != navigate_current_metric
|
||||
or (state_.navigate_current_metric and state_.metric != metric)
|
||||
):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
# fix old variant states:
|
||||
for vs in state_.variant_states:
|
||||
if vs.metric is None:
|
||||
vs.metric = metric
|
||||
if refresh:
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
state: HistorySampleState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res.scroll_id = state.id
|
||||
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == variant and vs.metric == metric
|
||||
)
|
||||
if not var_state:
|
||||
return res
|
||||
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
if iteration is not None:
|
||||
must_conditions.append(
|
||||
{
|
||||
"range": {
|
||||
"iter": {"lte": iteration, "gte": var_state.min_iteration}
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
must_conditions.append(
|
||||
{"range": {"iter": {"gte": var_state.min_iteration}}}
|
||||
)
|
||||
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": {"iter": "desc"},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_history_sample_for_variant"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(
|
||||
event=hits[0]["_source"], res=res, state=state
|
||||
)
|
||||
return res
|
||||
|
||||
def _reset_variant_states(self, company_id: str, state: HistorySampleState):
|
||||
metrics = self._get_metric_variant_iterations(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
metric=state.metric if state.navigate_current_metric else None,
|
||||
)
|
||||
state.variant_states = [
|
||||
VariantState(
|
||||
metric=metric,
|
||||
name=var_name,
|
||||
min_iteration=min_iter,
|
||||
max_iteration=max_iter,
|
||||
)
|
||||
for metric, variants in metrics.items()
|
||||
for var_name, min_iter, max_iter in variants
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
|
||||
pass
|
||||
|
||||
def _get_metric_variant_iterations(
    self, company_id: str, task: str, metric: str,
) -> Mapping[str, Tuple[str, str, int, int]]:
    """
    Collect the iteration range of the reported events per metric and variant of the task

    :param company_id: company whose events are searched
    :param task: id of the task whose events are examined
    :param metric: limit the search to this metric. None (as passed by
        _reset_variant_states) means that no metric condition is added
    :return: dict keyed by the metric bucket key. Each value is a list of
        (variant, min_iteration, max_iteration) tuples, one per variant bucket.
        Both terms aggregations request ascending key order.
        NOTE(review): the declared return annotation does not match this shape
    """
    conditions = [
        {"term": {"task": task}},
        *self._get_extra_conditions(),
    ]
    if metric is not None:
        conditions.append({"term": {"metric": metric}})
    query = {"bool": {"must": conditions}}

    search_args = dict(
        es=self.es, company_id=company_id, event_type=self.event_type
    )
    max_metrics, max_variants = get_max_metric_and_variant_counts(
        query=query, **search_args
    )
    max_variants = int(max_variants // 2)
    min_max_aggs, get_min_max_data = self._get_min_max_aggs()

    # the request is assembled inside out: variants are nested under metrics
    variants_agg = {
        "terms": {
            "field": "variant",
            "size": max_variants,
            "order": {"_key": "asc"},
        },
        "aggs": min_max_aggs,
    }
    metrics_agg = {
        "terms": {
            "field": "metric",
            "size": max_metrics,
            "order": {"_key": "asc"},
        },
        "aggs": {"variants": variants_agg},
    }
    es_req: dict = {
        "size": 0,
        "query": query,
        "aggs": {"metrics": metrics_agg},
    }

    with translate_errors_context(), TimingContext(
        "es", "get_history_sample_iterations"
    ):
        es_res = search_company_events(body=es_req, **search_args)

    iterations = {}
    for metric_bucket in nested_get(es_res, ("aggregations", "metrics", "buckets")):
        metric_name = metric_bucket["key"]
        variant_ranges = []
        for variant_bucket in nested_get(metric_bucket, ("variants", "buckets")):
            variant_name = variant_bucket["key"]
            min_iter, max_iter = get_min_max_data(variant_bucket)
            variant_ranges.append((variant_name, min_iter, max_iter))
        iterations[metric_name] = variant_ranges
    return iterations
|
||||
@@ -19,14 +19,14 @@ from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
get_metric_variants_condition, get_max_metric_and_variant_counts,
|
||||
get_metric_variants_condition,
|
||||
get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.metrics import MetricEventStats
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
class VariantState(Base):
|
||||
@@ -75,18 +75,25 @@ class MetricEventsIterator:
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
companies: Mapping[str, str],
|
||||
task_metrics: Mapping[str, dict],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> MetricEventsResult:
|
||||
if check_empty_data(self.es, company_id, self.event_type):
|
||||
companies = {
|
||||
task_id: company_id
|
||||
for task_id, company_id in companies.items()
|
||||
if not check_empty_data(
|
||||
self.es, company_id=company_id, event_type=EventType.metrics_scalar
|
||||
)
|
||||
}
|
||||
if not companies:
|
||||
return MetricEventsResult()
|
||||
|
||||
def init_state(state_: MetricEventsScrollState):
|
||||
state_.tasks = self._init_task_states(company_id, task_metrics)
|
||||
state_.tasks = self._init_task_states(companies, task_metrics)
|
||||
|
||||
def validate_state(state_: MetricEventsScrollState):
|
||||
"""
|
||||
@@ -95,7 +102,7 @@ class MetricEventsIterator:
|
||||
Refresh the state if requested
|
||||
"""
|
||||
if refresh:
|
||||
self._reinit_outdated_task_states(company_id, state_, task_metrics)
|
||||
self._reinit_outdated_task_states(companies, state_, task_metrics)
|
||||
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
@@ -112,7 +119,7 @@ class MetricEventsIterator:
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_task_metric_events,
|
||||
company_id=company_id,
|
||||
companies=companies,
|
||||
iter_count=iter_count,
|
||||
navigate_earlier=navigate_earlier,
|
||||
specific_variants_requested=specific_variants_requested,
|
||||
@@ -125,7 +132,7 @@ class MetricEventsIterator:
|
||||
|
||||
def _reinit_outdated_task_states(
|
||||
self,
|
||||
company_id,
|
||||
companies: Mapping[str, str],
|
||||
state: MetricEventsScrollState,
|
||||
task_metrics: Mapping[str, dict],
|
||||
):
|
||||
@@ -133,9 +140,7 @@ class MetricEventsIterator:
|
||||
Determine the metrics for which new event_type events were added
|
||||
since their states were initialized and re-init these states
|
||||
"""
|
||||
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
|
||||
"id", "metric_stats"
|
||||
)
|
||||
tasks = Task.objects(id__in=list(task_metrics)).only("id", "metric_stats")
|
||||
|
||||
def get_last_update_times_for_task_metrics(
|
||||
task: Task,
|
||||
@@ -175,7 +180,7 @@ class MetricEventsIterator:
|
||||
if metrics_to_recalc:
|
||||
task_metrics_to_recalc[task] = metrics_to_recalc
|
||||
|
||||
updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
|
||||
updated_task_states = self._init_task_states(companies, task_metrics_to_recalc)
|
||||
|
||||
def merge_with_updated_task_states(
|
||||
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
|
||||
@@ -205,14 +210,14 @@ class MetricEventsIterator:
|
||||
]
|
||||
|
||||
def _init_task_states(
|
||||
self, company_id: str, task_metrics: Mapping[str, dict]
|
||||
self, companies: Mapping[str, str], task_metrics: Mapping[str, dict]
|
||||
) -> Sequence[TaskScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
"""
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
task_metric_states = pool.map(
|
||||
partial(self._init_metric_states_for_task, company_id=company_id),
|
||||
partial(self._init_metric_states_for_task, companies=companies),
|
||||
task_metrics.items(),
|
||||
)
|
||||
|
||||
@@ -226,17 +231,20 @@ class MetricEventsIterator:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_variant_state_aggs(self) -> Tuple[dict, Callable[[dict, VariantState], None]]:
|
||||
def _get_variant_state_aggs(
|
||||
self,
|
||||
) -> Tuple[dict, Callable[[dict, VariantState], None]]:
|
||||
pass
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, dict], company_id: str
|
||||
self, task_metrics: Tuple[str, dict], companies: Mapping[str, str]
|
||||
) -> Sequence[MetricState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
for the variants that reported any event_type events
|
||||
"""
|
||||
task, metrics = task_metrics
|
||||
company_id = companies[task]
|
||||
must = [{"term": {"task": task}}, *self._get_extra_conditions()]
|
||||
if metrics:
|
||||
must.append(get_metric_variants_condition(metrics))
|
||||
@@ -268,14 +276,18 @@ class MetricEventsIterator:
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
**({"aggs": variant_state_aggs} if variant_state_aggs else {}),
|
||||
**(
|
||||
{"aggs": variant_state_aggs}
|
||||
if variant_state_aggs
|
||||
else {}
|
||||
),
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
@@ -313,7 +325,7 @@ class MetricEventsIterator:
|
||||
def _get_task_metric_events(
|
||||
self,
|
||||
task_state: TaskScrollState,
|
||||
company_id: str,
|
||||
companies: Mapping[str, str],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool,
|
||||
specific_variants_requested: bool,
|
||||
@@ -383,10 +395,10 @@ class MetricEventsIterator:
|
||||
}
|
||||
},
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "_get_task_metric_events"):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
company_id=companies[task_state.task],
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Callable, Tuple, Sequence, Dict
|
||||
from typing import Callable, Tuple, Sequence, Dict, Optional
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.models import ModelTaskPublishResponse
|
||||
@@ -24,13 +26,41 @@ class ModelBLL:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
return model
|
||||
|
||||
@staticmethod
def assert_exists(
    company_id,
    model_ids,
    only=None,
    allow_public=False,
    return_models=True,
) -> Optional[Sequence[Model]]:
    """
    Verify that all the passed model ids can be found for the company

    :param company_id: company that the models are looked up for
    :param model_ids: a single model id or a collection of model ids
    :param only: optional field names that the query is limited to
    :param allow_public: passed as is to Model.get_many
    :param return_models: if set then the found models are returned as a list,
        otherwise nothing is returned
    :raise errors.bad_request.InvalidModelId: if the amount of the found models
        differs from the amount of the unique requested ids
    """
    if isinstance(model_ids, str):
        model_ids = [model_ids]
    unique_ids = set(model_ids)

    models = Model.get_many(
        company=company_id,
        query=Q(id__in=unique_ids),
        allow_public=allow_public,
        return_dicts=False,
    )
    if only:
        models = models.only(*only)

    if models.count() != len(unique_ids):
        raise errors.bad_request.InvalidModelId(ids=model_ids)

    if not return_models:
        return None
    return list(models)
|
||||
|
||||
@classmethod
|
||||
def publish_model(
|
||||
cls,
|
||||
model_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
force_publish_task: bool = False,
|
||||
publish_task_func: Callable[[str, str, bool], dict] = None,
|
||||
publish_task_func: Callable[[str, str, str, bool], dict] = None,
|
||||
) -> Tuple[int, ModelTaskPublishResponse]:
|
||||
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
if model.ready:
|
||||
@@ -45,7 +75,7 @@ class ModelBLL:
|
||||
)
|
||||
if task and task.status != TaskStatus.published:
|
||||
task_publish_res = publish_task_func(
|
||||
model.task, company_id, force_publish_task
|
||||
model.task, company_id, user_id, force_publish_task
|
||||
)
|
||||
published_task = ModelTaskPublishResponse(
|
||||
id=model.task, data=task_publish_res
|
||||
|
||||
@@ -11,7 +11,6 @@ from apiserver.utilities.parameter_key_escaper import (
|
||||
mongoengine_safe,
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -42,27 +41,25 @@ class Metadata:
|
||||
replace_metadata: bool,
|
||||
**more_updates,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "edit_metadata"):
|
||||
update_cmds = dict()
|
||||
metadata = cls.metadata_from_api(items)
|
||||
if replace_metadata:
|
||||
update_cmds["set__metadata"] = metadata
|
||||
else:
|
||||
for key, value in metadata.items():
|
||||
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
|
||||
update_cmds = dict()
|
||||
metadata = cls.metadata_from_api(items)
|
||||
if replace_metadata:
|
||||
update_cmds["set__metadata"] = metadata
|
||||
else:
|
||||
for key, value in metadata.items():
|
||||
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
|
||||
|
||||
return obj.update(**update_cmds, **more_updates)
|
||||
return obj.update(**update_cmds, **more_updates)
|
||||
|
||||
@classmethod
|
||||
def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
|
||||
with TimingContext("mongo", "delete_metadata"):
|
||||
return obj.update(
|
||||
**{
|
||||
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
|
||||
for key in set(keys)
|
||||
},
|
||||
**more_updates,
|
||||
)
|
||||
return obj.update(
|
||||
**{
|
||||
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
|
||||
for key in set(keys)
|
||||
},
|
||||
**more_updates,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_path(path: str):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from functools import reduce
|
||||
from itertools import groupby
|
||||
from itertools import groupby, chain
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Sequence,
|
||||
@@ -22,6 +21,7 @@ from mongoengine import Q, Document
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.projects import ProjectChildrenType
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility, AttributedDocument
|
||||
from apiserver.database.model.base import GetMixin
|
||||
@@ -29,7 +29,6 @@ from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
|
||||
from apiserver.database.utils import get_options, get_company_or_none_constraint
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .sub_projects import (
|
||||
_reposition_project_with_children,
|
||||
@@ -41,13 +40,22 @@ from .sub_projects import (
|
||||
_ids_with_children,
|
||||
_ids_with_parents,
|
||||
_get_project_depth,
|
||||
ProjectsChildren,
|
||||
)
|
||||
|
||||
log = config.logger(__file__)
|
||||
max_depth = config.get("services.projects.sub_projects.max_depth", 10)
|
||||
reports_project_name = ".reports"
|
||||
datasets_project_name = ".datasets"
|
||||
pipelines_project_name = ".pipelines"
|
||||
reports_tag = "reports"
|
||||
dataset_tag = "dataset"
|
||||
pipeline_tag = "pipeline"
|
||||
|
||||
|
||||
class ProjectBLL:
|
||||
child_classes = (Task, Model)
|
||||
|
||||
@classmethod
|
||||
def merge_project(
|
||||
cls, company, source_id: str, destination_id: str
|
||||
@@ -57,54 +65,53 @@ class ProjectBLL:
|
||||
Remove the source project
|
||||
Return the amounts of moved entities and subprojects + set of all the affected project ids
|
||||
"""
|
||||
with TimingContext("mongo", "move_project"):
|
||||
if source_id == destination_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
source=source_id
|
||||
if source_id == destination_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
source=source_id
|
||||
)
|
||||
source = Project.get(company, source_id)
|
||||
if destination_id:
|
||||
destination = Project.get(company, destination_id)
|
||||
if source_id in destination.path:
|
||||
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
|
||||
source=source_id, destination=destination_id
|
||||
)
|
||||
source = Project.get(company, source_id)
|
||||
if destination_id:
|
||||
destination = Project.get(company, destination_id)
|
||||
if source_id in destination.path:
|
||||
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
|
||||
source=source_id, destination=destination_id
|
||||
)
|
||||
else:
|
||||
destination = None
|
||||
else:
|
||||
destination = None
|
||||
|
||||
children = _get_sub_projects(
|
||||
[source.id], _only=("id", "name", "parent", "path")
|
||||
)[source.id]
|
||||
if destination:
|
||||
cls.validate_projects_depth(
|
||||
projects=children,
|
||||
old_parent_depth=len(source.path) + 1,
|
||||
new_parent_depth=len(destination.path) + 1,
|
||||
)
|
||||
children = _get_sub_projects(
|
||||
[source.id], _only=("id", "name", "parent", "path")
|
||||
)[source.id]
|
||||
if destination:
|
||||
cls.validate_projects_depth(
|
||||
projects=children,
|
||||
old_parent_depth=len(source.path) + 1,
|
||||
new_parent_depth=len(destination.path) + 1,
|
||||
)
|
||||
|
||||
moved_entities = 0
|
||||
for entity_type in (Task, Model):
|
||||
moved_entities += entity_type.objects(
|
||||
company=company,
|
||||
project=source_id,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).update(upsert=False, project=destination_id)
|
||||
moved_entities = 0
|
||||
for entity_type in cls.child_classes:
|
||||
moved_entities += entity_type.objects(
|
||||
company=company,
|
||||
project=source_id,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).update(upsert=False, project=destination_id)
|
||||
|
||||
moved_sub_projects = 0
|
||||
for child in Project.objects(company=company, parent=source_id):
|
||||
_reposition_project_with_children(
|
||||
project=child,
|
||||
children=[c for c in children if c.parent == child.id],
|
||||
parent=destination,
|
||||
)
|
||||
moved_sub_projects += 1
|
||||
moved_sub_projects = 0
|
||||
for child in Project.objects(company=company, parent=source_id):
|
||||
_reposition_project_with_children(
|
||||
project=child,
|
||||
children=[c for c in children if c.parent == child.id],
|
||||
parent=destination,
|
||||
)
|
||||
moved_sub_projects += 1
|
||||
|
||||
affected = {source.id, *(source.path or [])}
|
||||
source.delete()
|
||||
affected = {source.id, *(source.path or [])}
|
||||
source.delete()
|
||||
|
||||
if destination:
|
||||
destination.update(last_update=datetime.utcnow())
|
||||
affected.update({destination.id, *(destination.path or [])})
|
||||
if destination:
|
||||
destination.update(last_update=datetime.utcnow())
|
||||
affected.update({destination.id, *(destination.path or [])})
|
||||
|
||||
return moved_entities, moved_sub_projects, affected
|
||||
|
||||
@@ -127,78 +134,76 @@ class ProjectBLL:
|
||||
it should be writable. The source location should be writable too.
|
||||
Return the number of moved projects + set of all the affected project ids
|
||||
"""
|
||||
with TimingContext("mongo", "move_project"):
|
||||
project = Project.get(company, project_id)
|
||||
old_parent_id = project.parent
|
||||
old_parent = (
|
||||
Project.get_for_writing(company=project.company, id=old_parent_id)
|
||||
if old_parent_id
|
||||
else None
|
||||
project = Project.get(company, project_id)
|
||||
old_parent_id = project.parent
|
||||
old_parent = (
|
||||
Project.get_for_writing(company=project.company, id=old_parent_id)
|
||||
if old_parent_id
|
||||
else None
|
||||
)
|
||||
|
||||
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
|
||||
project.id
|
||||
]
|
||||
cls.validate_projects_depth(
|
||||
projects=[project, *children],
|
||||
old_parent_depth=len(project.path),
|
||||
new_parent_depth=_get_project_depth(new_location),
|
||||
)
|
||||
|
||||
new_parent = _ensure_project(company=company, user=user, name=new_location)
|
||||
new_parent_id = new_parent.id if new_parent else None
|
||||
if old_parent_id == new_parent_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
location=new_parent.name if new_parent else ""
|
||||
)
|
||||
|
||||
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
|
||||
project.id
|
||||
]
|
||||
cls.validate_projects_depth(
|
||||
projects=[project, *children],
|
||||
old_parent_depth=len(project.path),
|
||||
new_parent_depth=_get_project_depth(new_location),
|
||||
if new_parent and (
|
||||
project_id == new_parent.id or project_id in new_parent.path
|
||||
):
|
||||
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
|
||||
project=project_id, parent=new_parent.id
|
||||
)
|
||||
moved = _reposition_project_with_children(
|
||||
project, children=children, parent=new_parent
|
||||
)
|
||||
|
||||
new_parent = _ensure_project(company=company, user=user, name=new_location)
|
||||
new_parent_id = new_parent.id if new_parent else None
|
||||
if old_parent_id == new_parent_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
location=new_parent.name if new_parent else ""
|
||||
)
|
||||
if new_parent and (
|
||||
project_id == new_parent.id or project_id in new_parent.path
|
||||
):
|
||||
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
|
||||
project=project_id, parent=new_parent.id
|
||||
)
|
||||
moved = _reposition_project_with_children(
|
||||
project, children=children, parent=new_parent
|
||||
)
|
||||
now = datetime.utcnow()
|
||||
affected = set()
|
||||
for p in filter(None, (old_parent, new_parent)):
|
||||
p.update(last_update=now)
|
||||
affected.update({p.id, *(p.path or [])})
|
||||
|
||||
now = datetime.utcnow()
|
||||
affected = set()
|
||||
for p in filter(None, (old_parent, new_parent)):
|
||||
p.update(last_update=now)
|
||||
affected.update({p.id, *(p.path or [])})
|
||||
|
||||
return moved, affected
|
||||
return moved, affected
|
||||
|
||||
@classmethod
|
||||
def update(cls, company: str, project_id: str, **fields):
|
||||
with TimingContext("mongo", "projects_update"):
|
||||
project = Project.get_for_writing(company=company, id=project_id)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
project = Project.get_for_writing(company=company, id=project_id)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
new_name = fields.pop("name", None)
|
||||
if new_name:
|
||||
new_name, new_location = _validate_project_name(new_name)
|
||||
old_name, old_location = _validate_project_name(project.name)
|
||||
if new_location != old_location:
|
||||
raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
|
||||
fields["name"] = new_name
|
||||
fields["basename"] = new_name.split("/")[-1]
|
||||
new_name = fields.pop("name", None)
|
||||
if new_name:
|
||||
new_name, new_location = _validate_project_name(new_name)
|
||||
old_name, old_location = _validate_project_name(project.name)
|
||||
if new_location != old_location:
|
||||
raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
|
||||
fields["name"] = new_name
|
||||
fields["basename"] = new_name.split("/")[-1]
|
||||
|
||||
fields["last_update"] = datetime.utcnow()
|
||||
updated = project.update(upsert=False, **fields)
|
||||
fields["last_update"] = datetime.utcnow()
|
||||
updated = project.update(upsert=False, **fields)
|
||||
|
||||
if new_name:
|
||||
old_name = project.name
|
||||
project.name = new_name
|
||||
children = _get_sub_projects(
|
||||
[project.id], _only=("id", "name", "path")
|
||||
)[project.id]
|
||||
_update_subproject_names(
|
||||
project=project, children=children, old_name=old_name
|
||||
)
|
||||
if new_name:
|
||||
old_name = project.name
|
||||
project.name = new_name
|
||||
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
|
||||
project.id
|
||||
]
|
||||
_update_subproject_names(
|
||||
project=project, children=children, old_name=old_name
|
||||
)
|
||||
|
||||
return updated
|
||||
return updated
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -301,7 +306,7 @@ class ProjectBLL:
|
||||
"""
|
||||
Move a batch of entities to `project` or a project named `project_name` (create if does not exist)
|
||||
"""
|
||||
with TimingContext("mongo", "move_under_project"):
|
||||
if project_name or project:
|
||||
project = cls.find_or_create(
|
||||
user=user,
|
||||
company=company,
|
||||
@@ -309,16 +314,17 @@ class ProjectBLL:
|
||||
project_name=project_name,
|
||||
description="",
|
||||
)
|
||||
extra = (
|
||||
{"set__last_change": datetime.utcnow()}
|
||||
if hasattr(entity_cls, "last_change")
|
||||
else {}
|
||||
)
|
||||
entity_cls.objects(company=company, id__in=ids).update(
|
||||
set__project=project, **extra
|
||||
)
|
||||
|
||||
return project
|
||||
extra = (
|
||||
{"set__last_change": datetime.utcnow()}
|
||||
if hasattr(entity_cls, "last_change")
|
||||
else {}
|
||||
)
|
||||
entity_cls.objects(company=company, id__in=ids).update(
|
||||
set__project=project, **extra
|
||||
)
|
||||
|
||||
return project
|
||||
|
||||
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
|
||||
visibility_states = [EntityVisibility.archived, EntityVisibility.active]
|
||||
@@ -399,6 +405,18 @@ class ProjectBLL:
|
||||
"$completed",
|
||||
{"$gt": ["$completed", time_thresh]},
|
||||
additional_cond,
|
||||
{
|
||||
"$not": {
|
||||
"$in": [
|
||||
"$status",
|
||||
[
|
||||
TaskStatus.queued,
|
||||
TaskStatus.in_progress,
|
||||
TaskStatus.failed,
|
||||
],
|
||||
]
|
||||
}
|
||||
},
|
||||
]
|
||||
},
|
||||
"then": 1,
|
||||
@@ -512,7 +530,7 @@ class ProjectBLL:
|
||||
def aggregate_project_data(
|
||||
func: Callable[[T, T], T],
|
||||
project_ids: Sequence[str],
|
||||
child_projects: Mapping[str, Sequence[Project]],
|
||||
child_projects: ProjectsChildren,
|
||||
data: Mapping[str, T],
|
||||
) -> Dict[str, T]:
|
||||
"""
|
||||
@@ -564,6 +582,136 @@ class ProjectBLL:
|
||||
for r in Task.aggregate(task_runtime_pipeline)
|
||||
}
|
||||
|
||||
@staticmethod
def _get_projects_children(
    project_ids: Sequence[str], search_hidden: bool, allowed_ids: Sequence[str],
) -> Tuple[ProjectsChildren, Set[str]]:
    """
    Fetch the sub projects (id and name fields only) of the passed projects

    Returns the mapping produced by _get_sub_projects together with
    the set of ids of all the children that appear in this mapping
    """
    children_by_project = _get_sub_projects(
        project_ids,
        _only=("id", "name"),
        search_hidden=search_hidden,
        allowed_ids=allowed_ids,
    )

    children_ids = set()
    for children in children_by_project.values():
        children_ids.update(child.id for child in children)

    return children_by_project, children_ids
|
||||
|
||||
@staticmethod
|
||||
def _get_children_info(
|
||||
project_ids: Sequence[str], child_projects: ProjectsChildren
|
||||
) -> dict:
|
||||
return {
|
||||
project: sorted(
|
||||
[{"id": c.id, "name": c.name} for c in child_projects.get(project, [])],
|
||||
key=itemgetter("name"),
|
||||
)
|
||||
for project in project_ids
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_project_dataset_stats_core(
|
||||
cls,
|
||||
company: str,
|
||||
project_ids: Sequence[str],
|
||||
project_field: str,
|
||||
entity_class: Type[AttributedDocument],
|
||||
include_children: bool = True,
|
||||
filter_: Mapping[str, Any] = None,
|
||||
users: Sequence[str] = None,
|
||||
selected_project_ids: Sequence[str] = None,
|
||||
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
|
||||
if not project_ids:
|
||||
return {}, {}
|
||||
|
||||
child_projects = {}
|
||||
project_ids_with_children = set(project_ids)
|
||||
if include_children:
|
||||
child_projects, children_ids = cls._get_projects_children(
|
||||
project_ids, search_hidden=True, allowed_ids=selected_project_ids,
|
||||
)
|
||||
project_ids_with_children |= children_ids
|
||||
|
||||
pipeline = [
|
||||
{
|
||||
"$match": cls.get_match_conditions(
|
||||
company=company,
|
||||
project_ids=list(project_ids_with_children),
|
||||
filter_=filter_,
|
||||
users=users,
|
||||
project_field=project_field,
|
||||
)
|
||||
},
|
||||
{"$project": {project_field: 1, "tags": 1}},
|
||||
{
|
||||
"$group": {
|
||||
"_id": f"${project_field}",
|
||||
"count": {"$sum": 1},
|
||||
"tags": {"$push": "$tags"},
|
||||
}
|
||||
},
|
||||
]
|
||||
res = entity_class.aggregate(pipeline)
|
||||
|
||||
project_stats = {
|
||||
result["_id"]: {
|
||||
"count": result.get("count", 0),
|
||||
"tags": set(chain.from_iterable(result.get("tags", []))),
|
||||
}
|
||||
for result in res
|
||||
}
|
||||
|
||||
def concat_dataset_stats(a: dict, b: dict) -> dict:
|
||||
return {
|
||||
"count": a.get("count", 0) + b.get("count", 0),
|
||||
"tags": a.get("tags", {}) | b.get("tags", {}),
|
||||
}
|
||||
|
||||
top_project_stats = cls.aggregate_project_data(
|
||||
func=concat_dataset_stats,
|
||||
project_ids=project_ids,
|
||||
child_projects=child_projects,
|
||||
data=project_stats,
|
||||
)
|
||||
for _, stat in top_project_stats.items():
|
||||
stat["tags"] = sorted(list(stat.get("tags", {})))
|
||||
|
||||
empty_stats = {"count": 0, "tags": []}
|
||||
stats = {
|
||||
project: {"datasets": top_project_stats.get(project, empty_stats)}
|
||||
for project in project_ids
|
||||
}
|
||||
return stats, cls._get_children_info(project_ids, child_projects)
|
||||
|
||||
@classmethod
def get_project_dataset_stats(
    cls,
    company: str,
    project_ids: Sequence[str],
    include_children: bool = True,
    filter_: Mapping[str, Any] = None,
    users: Sequence[str] = None,
    selected_project_ids: Sequence[str] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
    """
    Calculate the dataset stats for the passed projects

    The passed filter is extended so that its "system_tags" list contains dataset_tag
    (a "system_tags" value that is not a list is replaced). The calculation itself
    is done by _get_project_dataset_stats_core over Project documents grouped
    by their "parent" field.

    :param company: company id
    :param project_ids: the projects to calculate the stats for
    :param include_children: passed as is to _get_project_dataset_stats_core
    :param filter_: optional filter. It is copied and never modified in place
    :param users: passed as is to _get_project_dataset_stats_core
    :param selected_project_ids: passed as is to _get_project_dataset_stats_core
    :return: the (stats, children) tuple of _get_project_dataset_stats_core
    """
    # work on copies: filter_ is declared as a read-only Mapping and neither it
    # nor its system_tags list should change under the caller
    filter_ = dict(filter_) if filter_ else {}
    requested_tags = filter_.get("system_tags")
    filter_system_tags = (
        list(requested_tags) if isinstance(requested_tags, list) else []
    )
    if dataset_tag not in filter_system_tags:
        filter_system_tags.append(dataset_tag)
    filter_["system_tags"] = filter_system_tags

    return cls._get_project_dataset_stats_core(
        company=company,
        project_ids=project_ids,
        project_field="parent",
        entity_class=Project,
        include_children=include_children,
        filter_=filter_,
        users=users,
        selected_project_ids=selected_project_ids,
    )
|
||||
|
||||
@classmethod
|
||||
def get_project_stats(
|
||||
cls,
|
||||
@@ -574,24 +722,21 @@ class ProjectBLL:
|
||||
search_hidden: bool = False,
|
||||
filter_: Mapping[str, Any] = None,
|
||||
users: Sequence[str] = None,
|
||||
user_active_project_ids: Sequence[str] = None,
|
||||
selected_project_ids: Sequence[str] = None,
|
||||
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
|
||||
if not project_ids:
|
||||
return {}, {}
|
||||
|
||||
child_projects = (
|
||||
_get_sub_projects(
|
||||
child_projects = {}
|
||||
project_ids_with_children = set(project_ids)
|
||||
if include_children:
|
||||
child_projects, children_ids = cls._get_projects_children(
|
||||
project_ids,
|
||||
_only=("id", "name"),
|
||||
search_hidden=search_hidden,
|
||||
allowed_ids=user_active_project_ids,
|
||||
allowed_ids=selected_project_ids,
|
||||
)
|
||||
if include_children
|
||||
else {}
|
||||
)
|
||||
project_ids_with_children = set(project_ids) | {
|
||||
c.id for c in itertools.chain.from_iterable(child_projects.values())
|
||||
}
|
||||
project_ids_with_children |= children_ids
|
||||
|
||||
status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines(
|
||||
company,
|
||||
project_ids=list(project_ids_with_children),
|
||||
@@ -695,14 +840,7 @@ class ProjectBLL:
|
||||
for project in project_ids
|
||||
}
|
||||
|
||||
children = {
|
||||
project: sorted(
|
||||
[{"id": c.id, "name": c.name} for c in child_projects.get(project, [])],
|
||||
key=itemgetter("name"),
|
||||
)
|
||||
for project in project_ids
|
||||
}
|
||||
return stats, children
|
||||
return stats, cls._get_children_info(project_ids, child_projects)
|
||||
|
||||
@classmethod
|
||||
def get_active_users(
|
||||
@@ -716,22 +854,21 @@ class ProjectBLL:
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
"""
|
||||
with TimingContext("mongo", "active_users_in_projects"):
|
||||
query = Q(company=company)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
query = Q(company=company)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
|
||||
projects_query = query
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
projects_query &= Q(id__in=project_ids)
|
||||
projects_query = query
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
projects_query &= Q(id__in=project_ids)
|
||||
|
||||
res = set(Project.objects(projects_query).distinct(field="user"))
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
res = set(Project.objects(projects_query).distinct(field="user"))
|
||||
for cls_ in cls.child_classes:
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
|
||||
return res
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_project_tags(
|
||||
@@ -741,63 +878,97 @@ class ProjectBLL:
|
||||
projects: Sequence[str] = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
with TimingContext("mongo", "get_tags_from_db"):
|
||||
query = Q(company=company_id)
|
||||
if filter_:
|
||||
for name, vals in filter_.items():
|
||||
if vals:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
query = Q(company=company_id)
|
||||
if filter_:
|
||||
for name, vals in filter_.items():
|
||||
if vals:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
|
||||
if projects:
|
||||
query &= Q(id__in=_ids_with_children(projects))
|
||||
if projects:
|
||||
query &= Q(id__in=_ids_with_children(projects))
|
||||
|
||||
tags = Project.objects(query).distinct("tags")
|
||||
system_tags = (
|
||||
Project.objects(query).distinct("system_tags") if include_system else []
|
||||
)
|
||||
return tags, system_tags
|
||||
tags = Project.objects(query).distinct("tags")
|
||||
system_tags = (
|
||||
Project.objects(query).distinct("system_tags") if include_system else []
|
||||
)
|
||||
return tags, system_tags
|
||||
|
||||
@classmethod
|
||||
def get_projects_with_active_user(
|
||||
def get_projects_with_selected_children(
|
||||
cls,
|
||||
company: str,
|
||||
users: Sequence[str],
|
||||
users: Sequence[str] = None,
|
||||
project_ids: Optional[Sequence[str]] = None,
|
||||
allow_public: bool = True,
|
||||
children_type: ProjectChildrenType = None,
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
"""
|
||||
Get the projects ids where user created any tasks including all the parents of these projects
|
||||
Get the projects ids matching children_condition (if passed) or where the passed user created any tasks
|
||||
including all the parents of these projects
|
||||
If project ids are specified then filter the results by these project ids
|
||||
"""
|
||||
query = Q(user__in=users)
|
||||
if not (users or children_type):
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Either active users or children_condition should be specified"
|
||||
)
|
||||
|
||||
if allow_public:
|
||||
query &= get_company_or_none_constraint(company)
|
||||
query = (
|
||||
get_company_or_none_constraint(company)
|
||||
if allow_public
|
||||
else Q(company=company)
|
||||
)
|
||||
if users:
|
||||
query &= Q(user__in=users)
|
||||
|
||||
project_query = None
|
||||
if children_type == ProjectChildrenType.dataset:
|
||||
child_queries = {
|
||||
Project: query
|
||||
& Q(system_tags__in=[dataset_tag], basename__ne=datasets_project_name)
|
||||
}
|
||||
elif children_type == ProjectChildrenType.pipeline:
|
||||
child_queries = {Task: query & Q(system_tags__in=[pipeline_tag])}
|
||||
elif children_type == ProjectChildrenType.report:
|
||||
child_queries = {Task: query & Q(system_tags__in=[reports_tag])}
|
||||
else:
|
||||
query &= Q(company=company)
|
||||
project_query = query
|
||||
child_queries = {entity_cls: query for entity_cls in cls.child_classes}
|
||||
|
||||
user_projects_query = query
|
||||
if project_ids:
|
||||
ids_with_children = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=ids_with_children)
|
||||
user_projects_query &= Q(id__in=ids_with_children)
|
||||
if project_query:
|
||||
project_query &= Q(id__in=ids_with_children)
|
||||
for child_cls in child_queries:
|
||||
child_queries[child_cls] &= (
|
||||
Q(parent__in=ids_with_children)
|
||||
if child_cls is Project
|
||||
else Q(project__in=ids_with_children)
|
||||
)
|
||||
|
||||
res = {p.id for p in Project.objects(user_projects_query).only("id")}
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="project"))
|
||||
res = (
|
||||
{p.id for p in Project.objects(project_query).only("id")}
|
||||
if project_query
|
||||
else set()
|
||||
)
|
||||
for cls_, query_ in child_queries.items():
|
||||
res |= set(
|
||||
cls_.objects(query_).distinct(
|
||||
field="parent" if cls_ is Project else "project"
|
||||
)
|
||||
)
|
||||
|
||||
res = list(res)
|
||||
if not res:
|
||||
return res, res
|
||||
|
||||
user_active_project_ids = _ids_with_parents(res)
|
||||
selected_project_ids = _ids_with_parents(res)
|
||||
filtered_ids = (
|
||||
list(set(user_active_project_ids) & set(project_ids))
|
||||
list(set(selected_project_ids) & set(project_ids))
|
||||
if project_ids
|
||||
else list(user_active_project_ids)
|
||||
else list(selected_project_ids)
|
||||
)
|
||||
|
||||
return filtered_ids, user_active_project_ids
|
||||
return filtered_ids, selected_project_ids
|
||||
|
||||
@classmethod
|
||||
def get_task_parents(
|
||||
@@ -870,10 +1041,11 @@ class ProjectBLL:
|
||||
project_ids: Sequence[str],
|
||||
filter_: Mapping[str, Any],
|
||||
users: Sequence[str],
|
||||
project_field: str = "project",
|
||||
):
|
||||
conditions = {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"project": {"$in": project_ids},
|
||||
project_field: {"$in": project_ids},
|
||||
}
|
||||
if users:
|
||||
conditions["user"] = {"$in": users}
|
||||
@@ -898,6 +1070,69 @@ class ProjectBLL:
|
||||
|
||||
return conditions
|
||||
|
||||
@classmethod
|
||||
def _calc_own_datasets_core(
|
||||
cls,
|
||||
company: str,
|
||||
project_ids: Sequence[str],
|
||||
project_field: str,
|
||||
entity_class: Type[AttributedDocument],
|
||||
filter_: Mapping[str, Any] = None,
|
||||
users: Sequence[str] = None,
|
||||
) -> Dict[str, dict]:
|
||||
"""
|
||||
Returns the amount of hyper datasets per requested project
|
||||
"""
|
||||
if not project_ids:
|
||||
return {}
|
||||
|
||||
pipeline = [
|
||||
{
|
||||
"$match": cls.get_match_conditions(
|
||||
company=company,
|
||||
project_ids=project_ids,
|
||||
filter_=filter_,
|
||||
users=users,
|
||||
project_field=project_field,
|
||||
)
|
||||
},
|
||||
{"$project": {project_field: 1}},
|
||||
{"$group": {"_id": f"${project_field}", "count": {"$sum": 1}}},
|
||||
]
|
||||
datasets = {
|
||||
data["_id"]: data["count"] for data in entity_class.aggregate(pipeline)
|
||||
}
|
||||
|
||||
return {pid: {"own_datasets": datasets.get(pid, 0)} for pid in project_ids}
|
||||
|
||||
@classmethod
|
||||
def calc_own_datasets(
|
||||
cls,
|
||||
company: str,
|
||||
project_ids: Sequence[str],
|
||||
filter_: Mapping[str, Any] = None,
|
||||
users: Sequence[str] = None,
|
||||
) -> Dict[str, dict]:
|
||||
"""
|
||||
Returns the amount of datasets per requested project
|
||||
"""
|
||||
filter_ = filter_ or {}
|
||||
filter_system_tags = filter_.get("system_tags")
|
||||
if not isinstance(filter_system_tags, list):
|
||||
filter_system_tags = []
|
||||
if dataset_tag not in filter_system_tags:
|
||||
filter_system_tags.append(dataset_tag)
|
||||
filter_["system_tags"] = filter_system_tags
|
||||
|
||||
return cls._calc_own_datasets_core(
|
||||
company=company,
|
||||
project_ids=project_ids,
|
||||
project_field="parent",
|
||||
entity_class=Project,
|
||||
filter_=filter_,
|
||||
users=users,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def calc_own_contents(
|
||||
cls,
|
||||
@@ -930,10 +1165,9 @@ class ProjectBLL:
|
||||
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
|
||||
return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
|
||||
|
||||
with TimingContext("mongo", "get_security_groups"):
|
||||
tasks = get_agrregate_res(Task)
|
||||
models = get_agrregate_res(Model)
|
||||
return {
|
||||
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
|
||||
for pid in project_ids
|
||||
}
|
||||
tasks = get_agrregate_res(Task)
|
||||
models = get_agrregate_res(Model)
|
||||
return {
|
||||
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
|
||||
for pid in project_ids
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections import defaultdict
|
||||
from typing import Tuple, Set, Sequence
|
||||
|
||||
import attr
|
||||
@@ -8,17 +9,19 @@ from apiserver.bll.task.task_cleanup import (
|
||||
collect_debug_image_urls,
|
||||
collect_plot_image_urls,
|
||||
TaskUrls,
|
||||
_schedule_for_delete,
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType
|
||||
from apiserver.timing_context import TimingContext
|
||||
from .project_bll import ProjectBLL
|
||||
from .sub_projects import _ids_with_children
|
||||
|
||||
log = config.logger(__file__)
|
||||
event_bll = EventBLL()
|
||||
async_events_delete = config.get("services.tasks.async_events_delete", False)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
@@ -39,9 +42,9 @@ def validate_project_delete(company: str, project_id: str):
|
||||
is_pipeline = "pipeline" in (project.system_tags or [])
|
||||
project_ids = _ids_with_children([project_id])
|
||||
ret = {}
|
||||
for cls in (Task, Model):
|
||||
for cls in ProjectBLL.child_classes:
|
||||
ret[f"{cls.__name__.lower()}s"] = cls.objects(project__in=project_ids).count()
|
||||
for cls in (Task, Model):
|
||||
for cls in ProjectBLL.child_classes:
|
||||
query = dict(
|
||||
project__in=project_ids, system_tags__nin=[EntityVisibility.archived.value]
|
||||
)
|
||||
@@ -59,13 +62,22 @@ def validate_project_delete(company: str, project_id: str):
|
||||
|
||||
|
||||
def delete_project(
|
||||
company: str, project_id: str, force: bool, delete_contents: bool
|
||||
company: str,
|
||||
user: str,
|
||||
project_id: str,
|
||||
force: bool,
|
||||
delete_contents: bool,
|
||||
delete_external_artifacts=True,
|
||||
) -> Tuple[DeleteProjectResult, Set[str]]:
|
||||
project = Project.get_for_writing(
|
||||
company=company, id=project_id, _only=("id", "path", "system_tags")
|
||||
)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
delete_external_artifacts = delete_external_artifacts and config.get(
|
||||
"services.async_urls_delete.enabled", False
|
||||
)
|
||||
is_pipeline = "pipeline" in (project.system_tags or [])
|
||||
project_ids = _ids_with_children([project_id])
|
||||
if not force:
|
||||
@@ -88,17 +100,28 @@ def delete_project(
|
||||
)
|
||||
|
||||
if not delete_contents:
|
||||
with TimingContext("mongo", "update_children"):
|
||||
for cls in (Model, Task):
|
||||
updated_count = cls.objects(project__in=project_ids).update(
|
||||
project=None
|
||||
)
|
||||
res = DeleteProjectResult(disassociated_tasks=updated_count)
|
||||
disassociated = defaultdict(int)
|
||||
for cls in ProjectBLL.child_classes:
|
||||
disassociated[cls] = cls.objects(project__in=project_ids).update(project=None)
|
||||
res = DeleteProjectResult(disassociated_tasks=disassociated[Task])
|
||||
else:
|
||||
deleted_models, model_urls = _delete_models(projects=project_ids)
|
||||
deleted_tasks, event_urls, artifact_urls = _delete_tasks(
|
||||
deleted_models, model_event_urls, model_urls = _delete_models(
|
||||
company=company, projects=project_ids
|
||||
)
|
||||
deleted_tasks, task_event_urls, artifact_urls = _delete_tasks(
|
||||
company=company, projects=project_ids
|
||||
)
|
||||
event_urls = task_event_urls | model_event_urls
|
||||
if delete_external_artifacts:
|
||||
scheduled = _schedule_for_delete(
|
||||
task_id=project_id,
|
||||
company=company,
|
||||
user=user,
|
||||
urls=event_urls | model_urls | artifact_urls,
|
||||
can_delete_folders=True,
|
||||
)
|
||||
for urls in (event_urls, model_urls, artifact_urls):
|
||||
urls.difference_update(scheduled)
|
||||
res = DeleteProjectResult(
|
||||
deleted_tasks=deleted_tasks,
|
||||
deleted_models=deleted_models,
|
||||
@@ -127,9 +150,8 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
|
||||
return 0, set(), set()
|
||||
|
||||
task_ids = {t.id for t in tasks}
|
||||
with TimingContext("mongo", "delete_tasks_update_children"):
|
||||
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
|
||||
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
|
||||
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
|
||||
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
|
||||
|
||||
event_urls, artifact_urls = set(), set()
|
||||
for task in tasks:
|
||||
@@ -144,46 +166,58 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
|
||||
}
|
||||
)
|
||||
|
||||
event_bll.delete_multi_task_events(company, list(task_ids))
|
||||
event_bll.delete_multi_task_events(
|
||||
company, list(task_ids), async_delete=async_events_delete
|
||||
)
|
||||
deleted = tasks.delete()
|
||||
return deleted, event_urls, artifact_urls
|
||||
|
||||
|
||||
def _delete_models(projects: Sequence[str]) -> Tuple[int, Set[str]]:
|
||||
def _delete_models(
|
||||
company: str, projects: Sequence[str]
|
||||
) -> Tuple[int, Set[str], Set[str]]:
|
||||
"""
|
||||
Delete project models and update the tasks from other projects
|
||||
that reference them to reference None.
|
||||
"""
|
||||
with TimingContext("mongo", "delete_models"):
|
||||
models = Model.objects(project__in=projects).only("task", "id", "uri")
|
||||
if not models:
|
||||
return 0, set()
|
||||
models = Model.objects(project__in=projects).only("task", "id", "uri")
|
||||
if not models:
|
||||
return 0, set(), set()
|
||||
|
||||
model_ids = list({m.id for m in models})
|
||||
model_ids = list({m.id for m in models})
|
||||
|
||||
Task._get_collection().update_many(
|
||||
filter={
|
||||
"project": {"$nin": projects},
|
||||
"models.input.model": {"$in": model_ids},
|
||||
},
|
||||
update={"$set": {"models.input.$[elem].model": None}},
|
||||
array_filters=[{"elem.model": {"$in": model_ids}}],
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
model_tasks = list({m.task for m in models if m.task})
|
||||
if model_tasks:
|
||||
Task._get_collection().update_many(
|
||||
filter={
|
||||
"_id": {"$in": model_tasks},
|
||||
"project": {"$nin": projects},
|
||||
"models.input.model": {"$in": model_ids},
|
||||
"models.output.model": {"$in": model_ids},
|
||||
},
|
||||
update={"$set": {"models.input.$[elem].model": None}},
|
||||
update={"$set": {"models.output.$[elem].model": None}},
|
||||
array_filters=[{"elem.model": {"$in": model_ids}}],
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
model_tasks = list({m.task for m in models if m.task})
|
||||
if model_tasks:
|
||||
Task._get_collection().update_many(
|
||||
filter={
|
||||
"_id": {"$in": model_tasks},
|
||||
"project": {"$nin": projects},
|
||||
"models.output.model": {"$in": model_ids},
|
||||
},
|
||||
update={"$set": {"models.output.$[elem].model": None}},
|
||||
array_filters=[{"elem.model": {"$in": model_ids}}],
|
||||
upsert=False,
|
||||
)
|
||||
event_urls, model_urls = set(), set()
|
||||
for m in models:
|
||||
event_urls.update(collect_debug_image_urls(company, m.id))
|
||||
event_urls.update(collect_plot_image_urls(company, m.id))
|
||||
if m.uri:
|
||||
model_urls.add(m.uri)
|
||||
|
||||
urls = {m.uri for m in models if m.uri}
|
||||
deleted = models.delete()
|
||||
return deleted, urls
|
||||
event_bll.delete_multi_task_events(
|
||||
company, model_ids, async_delete=async_events_delete
|
||||
)
|
||||
deleted = models.delete()
|
||||
return deleted, event_urls, model_urls
|
||||
|
||||
@@ -14,14 +14,16 @@ def _get_project_depth(project_name: str) -> int:
|
||||
return len(list(filter(None, project_name.split(name_separator))))
|
||||
|
||||
|
||||
def _validate_project_name(project_name: str) -> Tuple[str, str]:
|
||||
def _validate_project_name(project_name: str, raise_if_empty=True) -> Tuple[str, str]:
|
||||
"""
|
||||
Remove redundant '/' characters. Ensure that the project name is not empty
|
||||
Return the cleaned up project name and location
|
||||
"""
|
||||
name_parts = list(filter(None, project_name.split(name_separator)))
|
||||
name_parts = [p.strip() for p in project_name.split(name_separator) if p]
|
||||
if not name_parts:
|
||||
raise errors.bad_request.InvalidProjectName(name=project_name)
|
||||
if raise_if_empty:
|
||||
raise errors.bad_request.InvalidProjectName(name=project_name)
|
||||
return "", ""
|
||||
|
||||
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
|
||||
|
||||
@@ -34,7 +36,7 @@ def _ensure_project(
|
||||
If needed auto-create the project and all the missing projects in the path to it
|
||||
Return the project
|
||||
"""
|
||||
name = name.strip(name_separator)
|
||||
name, location = _validate_project_name(name, raise_if_empty=False)
|
||||
if not name:
|
||||
return None
|
||||
|
||||
@@ -43,7 +45,6 @@ def _ensure_project(
|
||||
return project
|
||||
|
||||
now = datetime.utcnow()
|
||||
name, location = _validate_project_name(name)
|
||||
project = Project(
|
||||
id=database.utils.id(),
|
||||
user=user,
|
||||
@@ -101,12 +102,15 @@ def _get_writable_project_from_name(
|
||||
return qs.first()
|
||||
|
||||
|
||||
ProjectsChildren = Mapping[str, Sequence[Project]]
|
||||
|
||||
|
||||
def _get_sub_projects(
|
||||
project_ids: Sequence[str],
|
||||
_only: Sequence[str] = ("id", "path"),
|
||||
search_hidden=True,
|
||||
allowed_ids: Sequence[str] = None,
|
||||
) -> Mapping[str, Sequence[Project]]:
|
||||
) -> ProjectsChildren:
|
||||
"""
|
||||
Return the list of child projects of all the levels for the parent project ids
|
||||
"""
|
||||
@@ -156,13 +160,17 @@ def _update_subproject_names(
|
||||
Optionally update the paths
|
||||
"""
|
||||
updated = 0
|
||||
now = datetime.utcnow()
|
||||
for child in children:
|
||||
child_suffix = name_separator.join(
|
||||
child.name.split(name_separator)[len(old_name.split(name_separator)) :]
|
||||
child.name.split(name_separator)[len(old_name.split(name_separator)):]
|
||||
)
|
||||
updates = {"name": name_separator.join((project.name, child_suffix))}
|
||||
updates = {
|
||||
"name": name_separator.join((project.name, child_suffix)),
|
||||
"last_update": now,
|
||||
}
|
||||
if update_path:
|
||||
updates["path"] = project.path + child.path[len(old_path) :]
|
||||
updates["path"] = project.path + child.path[len(old_path):]
|
||||
updated += child.update(upsert=False, **updates)
|
||||
|
||||
return updated
|
||||
@@ -177,6 +185,7 @@ def _reposition_project_with_children(
|
||||
project.name = name_separator.join(
|
||||
filter(None, (new_location, project.name.split(name_separator)[-1]))
|
||||
)
|
||||
project.last_update = datetime.utcnow()
|
||||
_save_under_parent(project, parent=parent)
|
||||
|
||||
moved = 1 + _update_subproject_names(
|
||||
|
||||
@@ -3,11 +3,13 @@ from datetime import datetime
|
||||
from typing import Callable, Sequence, Optional, Tuple
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.queue.queue_metrics import QueueMetrics, MetricsRefresher
|
||||
from apiserver.bll.queue.queue_metrics import QueueMetrics
|
||||
from apiserver.bll.workers import WorkerBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
@@ -131,7 +133,7 @@ class QueueBLL(object):
|
||||
self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
return Queue.safe_update(company_id, queue_id, update_fields)
|
||||
|
||||
def delete(self, company_id: str, queue_id: str, force: bool) -> None:
|
||||
def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> None:
|
||||
"""
|
||||
Delete the queue
|
||||
:raise errors.bad_request.InvalidQueueId: if the queue is not found
|
||||
@@ -139,16 +141,42 @@ class QueueBLL(object):
|
||||
"""
|
||||
with translate_errors_context():
|
||||
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
|
||||
if queue.entries and not force:
|
||||
raise errors.bad_request.QueueNotEmpty(
|
||||
"use force=true to delete", id=queue_id
|
||||
)
|
||||
if queue.entries:
|
||||
if not force:
|
||||
raise errors.bad_request.QueueNotEmpty(
|
||||
"use force=true to delete", id=queue_id
|
||||
)
|
||||
from apiserver.bll.task import ChangeStatusRequest
|
||||
|
||||
for item in queue.entries:
|
||||
try:
|
||||
task = Task.get_for_writing(
|
||||
company=company_id,
|
||||
id=item.task,
|
||||
_only=["id", "status", "enqueue_status", "project"],
|
||||
)
|
||||
if not task:
|
||||
continue
|
||||
|
||||
ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=task.enqueue_status or TaskStatus.created,
|
||||
status_reason="Queue was deleted",
|
||||
status_message="",
|
||||
user_id=user_id,
|
||||
).execute(enqueue_status=None)
|
||||
except Exception as ex:
|
||||
log.exception(
|
||||
f"Failed dequeuing task {item.task} from queue: {queue_id}"
|
||||
)
|
||||
|
||||
queue.delete()
|
||||
|
||||
def get_all(
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
query: Q = None,
|
||||
max_task_entries: int = None,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
@@ -158,16 +186,25 @@ class QueueBLL(object):
|
||||
company=company_id,
|
||||
parameters=query_dict,
|
||||
query_dict=query_dict,
|
||||
query=query,
|
||||
projection_fields=self._get_task_entries_projection(max_task_entries)
|
||||
if max_task_entries
|
||||
else None,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
def check_for_workers(self, company_id: str, queue_id: str) -> bool:
|
||||
for worker in self.worker_bll.get_all(company_id):
|
||||
if queue_id in worker.queues:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_queue_infos(
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
query: Q = None,
|
||||
max_task_entries: int = None,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
@@ -179,6 +216,7 @@ class QueueBLL(object):
|
||||
res = Queue.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=query_dict,
|
||||
query=query,
|
||||
override_projection=projection,
|
||||
projection_fields=self._get_task_entries_projection(max_task_entries)
|
||||
if max_task_entries
|
||||
@@ -230,16 +268,22 @@ class QueueBLL(object):
|
||||
|
||||
return res
|
||||
|
||||
def get_next_task(self, company_id: str, queue_id: str) -> Optional[Entry]:
|
||||
def get_next_task(
|
||||
self, company_id: str, queue_id: str, task_id: str = None
|
||||
) -> Optional[Entry]:
|
||||
"""
|
||||
Atomically pop and return the first task from the queue (or None)
|
||||
:raise errors.bad_request.InvalidQueueId: if the queue does not exist
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
queue = Queue.objects(**query).modify(pop__entries=-1, upsert=False)
|
||||
queue = Queue.objects(
|
||||
**query, **({"entries__0__task": task_id} if task_id else {})
|
||||
).modify(pop__entries=-1, upsert=False)
|
||||
if not queue:
|
||||
raise errors.bad_request.InvalidQueueId(**query)
|
||||
if not task_id or not Queue.objects(**query).first():
|
||||
raise errors.bad_request.InvalidQueueId(**query)
|
||||
return
|
||||
|
||||
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
|
||||
|
||||
@@ -334,6 +378,3 @@ class QueueBLL(object):
|
||||
if res is None:
|
||||
raise errors.bad_request.InvalidQueueId(queue_id=queue_id)
|
||||
return int(res.get("count"))
|
||||
|
||||
|
||||
MetricsRefresher.start(queue_metrics=QueueBLL().metrics)
|
||||
|
||||
@@ -14,7 +14,6 @@ from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.queue import Queue, Entry
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__file__)
|
||||
@@ -182,7 +181,7 @@ class QueueMetrics:
|
||||
"aggs": self._get_dates_agg(interval),
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_queue_metrics"):
|
||||
with translate_errors_context():
|
||||
res = self._search_company_metrics(company_id, es_req)
|
||||
|
||||
if "aggregations" not in res:
|
||||
@@ -279,12 +278,17 @@ class MetricsRefresher:
|
||||
|
||||
@classmethod
|
||||
@threads.register("queue_metrics_refresh_watchdog", daemon=True)
|
||||
def start(cls, queue_metrics: QueueMetrics):
|
||||
def start(cls, queue_metrics: QueueMetrics = None):
|
||||
if not cls.watch_interval_sec:
|
||||
return
|
||||
|
||||
if not queue_metrics:
|
||||
from .queue_bll import QueueBLL
|
||||
|
||||
queue_metrics = QueueBLL().metrics
|
||||
|
||||
sleep(10)
|
||||
while not ThreadsManager.terminating:
|
||||
while True:
|
||||
try:
|
||||
for queue in Queue.objects():
|
||||
timestamp = es_factory.get_timestamp_millis()
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Optional, TypeVar, Generic, Type, Callable
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -31,20 +30,17 @@ class RedisCacheManager(Generic[T]):
|
||||
|
||||
def set_state(self, state: T) -> None:
|
||||
redis_key = self._get_redis_key(state.id)
|
||||
with TimingContext("redis", "cache_set_state"):
|
||||
self.redis.set(redis_key, state.to_json())
|
||||
self.redis.expire(redis_key, self.expiration_interval)
|
||||
self.redis.set(redis_key, state.to_json())
|
||||
self.redis.expire(redis_key, self.expiration_interval)
|
||||
|
||||
def get_state(self, state_id) -> Optional[T]:
|
||||
redis_key = self._get_redis_key(state_id)
|
||||
with TimingContext("redis", "cache_get_state"):
|
||||
response = self.redis.get(redis_key)
|
||||
response = self.redis.get(redis_key)
|
||||
if response:
|
||||
return self.state_class.from_json(response)
|
||||
|
||||
def delete_state(self, state_id) -> None:
|
||||
with TimingContext("redis", "cache_delete_state"):
|
||||
self.redis.delete(self._get_redis_key(state_id))
|
||||
self.redis.delete(self._get_redis_key(state_id))
|
||||
|
||||
def _get_redis_key(self, state_id):
|
||||
return f"{self.state_class}/{state_id}"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
import operator
|
||||
from threading import Thread, Lock
|
||||
from threading import Lock
|
||||
from time import sleep
|
||||
|
||||
import attr
|
||||
@@ -9,76 +9,83 @@ import psutil
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
|
||||
class ResourceMonitor(Thread):
|
||||
@attr.s(auto_attribs=True)
|
||||
class Sample:
|
||||
cpu_usage: float = 0.0
|
||||
mem_used_gb: float = 0
|
||||
mem_free_gb: float = 0
|
||||
stat_threads = ThreadsManager("Statistics")
|
||||
|
||||
@classmethod
|
||||
def _apply(cls, op, *samples):
|
||||
return cls(
|
||||
**{
|
||||
field: op(*(getattr(sample, field) for sample in samples))
|
||||
for field in attr.fields_dict(cls)
|
||||
}
|
||||
)
|
||||
|
||||
def min(self, sample):
|
||||
return self._apply(min, self, sample)
|
||||
|
||||
def max(self, sample):
|
||||
return self._apply(max, self, sample)
|
||||
|
||||
def avg(self, sample, count):
|
||||
res = self._apply(lambda x: x * count, self)
|
||||
res = self._apply(operator.add, res, sample)
|
||||
res = self._apply(lambda x: x / (count + 1), res)
|
||||
return res
|
||||
|
||||
def __init__(self, sample_interval_sec=5):
|
||||
super(ResourceMonitor, self).__init__(daemon=True)
|
||||
self.sample_interval_sec = sample_interval_sec
|
||||
self._lock = Lock()
|
||||
self._clear()
|
||||
|
||||
def _clear(self):
|
||||
sample = self._get_sample()
|
||||
self._avg = sample
|
||||
self._min = sample
|
||||
self._max = sample
|
||||
self._clear_time = datetime.utcnow()
|
||||
self._count = 1
|
||||
@attr.s(auto_attribs=True)
|
||||
class Sample:
|
||||
cpu_usage: float = 0.0
|
||||
mem_used_gb: float = 0
|
||||
mem_free_gb: float = 0
|
||||
|
||||
@classmethod
|
||||
def _get_sample(cls) -> Sample:
|
||||
return cls.Sample(
|
||||
def _apply(cls, op, *samples):
|
||||
return cls(
|
||||
**{
|
||||
field: op(*(getattr(sample, field) for sample in samples))
|
||||
for field in attr.fields_dict(cls)
|
||||
}
|
||||
)
|
||||
|
||||
def min(self, sample):
|
||||
return self._apply(min, self, sample)
|
||||
|
||||
def max(self, sample):
|
||||
return self._apply(max, self, sample)
|
||||
|
||||
def avg(self, sample, count):
|
||||
res = self._apply(lambda x: x * count, self)
|
||||
res = self._apply(operator.add, res, sample)
|
||||
res = self._apply(lambda x: x / (count + 1), res)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_current_sample(cls) -> "Sample":
|
||||
return cls(
|
||||
cpu_usage=psutil.cpu_percent(),
|
||||
mem_used_gb=psutil.virtual_memory().used / (1024 ** 3),
|
||||
mem_free_gb=psutil.virtual_memory().free / (1024 ** 3),
|
||||
)
|
||||
|
||||
def run(self):
|
||||
while not ThreadsManager.terminating:
|
||||
sleep(self.sample_interval_sec)
|
||||
|
||||
sample = self._get_sample()
|
||||
class ResourceMonitor:
|
||||
class Accumulator:
|
||||
def __init__(self):
|
||||
sample = Sample.get_current_sample()
|
||||
self.avg = sample
|
||||
self.min = sample
|
||||
self.max = sample
|
||||
self.time = datetime.utcnow()
|
||||
self.count = 1
|
||||
|
||||
with self._lock:
|
||||
self._min = self._min.min(sample)
|
||||
self._max = self._max.max(sample)
|
||||
self._avg = self._avg.avg(sample, self._count)
|
||||
self._count += 1
|
||||
def add_sample(self, sample: Sample):
|
||||
self.min = self.min.min(sample)
|
||||
self.max = self.max.max(sample)
|
||||
self.avg = self.avg.avg(sample, self.count)
|
||||
self.count += 1
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
sample_interval_sec = 5
|
||||
_lock = Lock()
|
||||
accumulator = Accumulator()
|
||||
|
||||
@classmethod
|
||||
@stat_threads.register("resource_monitor", daemon=True)
|
||||
def start(cls):
|
||||
while True:
|
||||
sleep(cls.sample_interval_sec)
|
||||
sample = Sample.get_current_sample()
|
||||
with cls._lock:
|
||||
cls.accumulator.add_sample(sample)
|
||||
|
||||
@classmethod
|
||||
def get_stats(cls) -> dict:
|
||||
""" Returns current resource statistics and clears internal resource statistics """
|
||||
with self._lock:
|
||||
min_ = attr.asdict(self._min)
|
||||
max_ = attr.asdict(self._max)
|
||||
avg = attr.asdict(self._avg)
|
||||
interval = datetime.utcnow() - self._clear_time
|
||||
self._clear()
|
||||
with cls._lock:
|
||||
min_ = attr.asdict(cls.accumulator.min)
|
||||
max_ = attr.asdict(cls.accumulator.max)
|
||||
avg = attr.asdict(cls.accumulator.avg)
|
||||
interval = datetime.utcnow() - cls.accumulator.time
|
||||
cls.accumulator = cls.Accumulator()
|
||||
|
||||
return {
|
||||
"interval_sec": interval.total_seconds(),
|
||||
|
||||
@@ -21,9 +21,8 @@ from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.json import dumps
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
from apiserver.version import __version__ as current_version
|
||||
from .resource_monitor import ResourceMonitor
|
||||
from .resource_monitor import ResourceMonitor, stat_threads
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -31,17 +30,19 @@ worker_bll = WorkerBLL()
|
||||
|
||||
|
||||
class StatisticsReporter:
|
||||
threads = ThreadsManager("Statistics", resource_monitor=ResourceMonitor)
|
||||
send_queue = queue.Queue()
|
||||
supported = config.get("apiserver.statistics.supported", True)
|
||||
|
||||
@classmethod
|
||||
def start(cls):
|
||||
if not cls.supported:
|
||||
return
|
||||
ResourceMonitor.start()
|
||||
cls.start_sender()
|
||||
cls.start_reporter()
|
||||
|
||||
@classmethod
|
||||
@threads.register("reporter", daemon=True)
|
||||
@stat_threads.register("reporter", daemon=True)
|
||||
def start_reporter(cls):
|
||||
"""
|
||||
Periodically send statistics reports for companies who have opted in.
|
||||
@@ -54,7 +55,7 @@ class StatisticsReporter:
|
||||
hours=config.get("apiserver.statistics.report_interval_hours", 24)
|
||||
)
|
||||
sleep(report_interval.total_seconds())
|
||||
while not ThreadsManager.terminating:
|
||||
while True:
|
||||
try:
|
||||
for company in Company.objects(
|
||||
defaults__stats_option__enabled=True
|
||||
@@ -68,7 +69,7 @@ class StatisticsReporter:
|
||||
sleep(report_interval.total_seconds())
|
||||
|
||||
@classmethod
|
||||
@threads.register("sender", daemon=True)
|
||||
@stat_threads.register("sender", daemon=True)
|
||||
def start_sender(cls):
|
||||
if not cls.supported:
|
||||
return
|
||||
@@ -85,7 +86,7 @@ class StatisticsReporter:
|
||||
|
||||
WarningFilter.attach()
|
||||
|
||||
while not ThreadsManager.terminating:
|
||||
while True:
|
||||
try:
|
||||
report = cls.send_queue.get()
|
||||
|
||||
@@ -111,7 +112,7 @@ class StatisticsReporter:
|
||||
"uuid": get_server_uuid(),
|
||||
"queues": {"count": Queue.objects(company=company_id).count()},
|
||||
"users": {"count": User.objects(company=company_id).count()},
|
||||
"resources": cls.threads.resource_monitor.get_stats(),
|
||||
"resources": ResourceMonitor.get_stats(),
|
||||
"experiments": next(
|
||||
iter(cls._get_experiments_stats(company_id).values()), {}
|
||||
),
|
||||
|
||||
48
apiserver/bll/storage/__init__.py
Normal file
48
apiserver/bll/storage/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from copy import copy
|
||||
|
||||
from boltons.cacheutils import cachedproperty
|
||||
from clearml.backend_config.bucket_config import (
|
||||
S3BucketConfigurations,
|
||||
AzureContainerConfigurations,
|
||||
GSBucketConfigurations,
|
||||
)
|
||||
|
||||
from apiserver.config_repo import config
|
||||
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class StorageBLL:
|
||||
default_aws_configs: S3BucketConfigurations = None
|
||||
conf = config.get("services.storage_credentials")
|
||||
|
||||
@cachedproperty
|
||||
def _default_aws_configs(self) -> S3BucketConfigurations:
|
||||
return S3BucketConfigurations.from_config(self.conf.get("aws.s3"))
|
||||
|
||||
@cachedproperty
|
||||
def _default_azure_configs(self) -> AzureContainerConfigurations:
|
||||
return AzureContainerConfigurations.from_config(self.conf.get("azure.storage"))
|
||||
|
||||
@cachedproperty
|
||||
def _default_gs_configs(self) -> GSBucketConfigurations:
|
||||
return GSBucketConfigurations.from_config(self.conf.get("google.storage"))
|
||||
|
||||
def get_azure_settings_for_company(
|
||||
self,
|
||||
company_id: str,
|
||||
) -> AzureContainerConfigurations:
|
||||
return copy(self._default_azure_configs)
|
||||
|
||||
def get_gs_settings_for_company(
|
||||
self,
|
||||
company_id: str,
|
||||
) -> GSBucketConfigurations:
|
||||
return copy(self._default_gs_configs)
|
||||
|
||||
def get_aws_settings_for_company(
|
||||
self,
|
||||
company_id: str,
|
||||
) -> S3BucketConfigurations:
|
||||
return copy(self._default_aws_configs)
|
||||
@@ -5,7 +5,6 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
|
||||
from apiserver.database.utils import hash_field_name
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
|
||||
|
||||
@@ -49,49 +48,41 @@ class Artifacts:
|
||||
def add_or_update_artifacts(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
task_id: str,
|
||||
artifacts: Sequence[ApiArtifact],
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "update_artifacts"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
force=force,
|
||||
)
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
|
||||
|
||||
artifacts = {
|
||||
get_artifact_id(a): Artifact(**a)
|
||||
for a in (api_artifact.to_struct() for api_artifact in artifacts)
|
||||
}
|
||||
artifacts = {
|
||||
get_artifact_id(a): Artifact(**a)
|
||||
for a in (api_artifact.to_struct() for api_artifact in artifacts)
|
||||
}
|
||||
|
||||
update_cmds = {
|
||||
f"set__execution__artifacts__{mongoengine_safe(name)}": value
|
||||
for name, value in artifacts.items()
|
||||
}
|
||||
return update_task(task, update_cmds=update_cmds)
|
||||
update_cmds = {
|
||||
f"set__execution__artifacts__{mongoengine_safe(name)}": value
|
||||
for name, value in artifacts.items()
|
||||
}
|
||||
return update_task(task, user_id=user_id, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_artifacts(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
task_id: str,
|
||||
artifact_ids: Sequence[ArtifactId],
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "delete_artifacts"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
force=force,
|
||||
)
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
|
||||
|
||||
artifact_ids = [
|
||||
get_artifact_id(a)
|
||||
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
|
||||
]
|
||||
delete_cmds = {
|
||||
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
|
||||
}
|
||||
artifact_ids = [
|
||||
get_artifact_id(a)
|
||||
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
|
||||
]
|
||||
delete_cmds = {
|
||||
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
|
||||
}
|
||||
|
||||
return update_task(task, update_cmds=delete_cmds)
|
||||
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
|
||||
|
||||
@@ -15,7 +15,6 @@ from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
@@ -64,78 +63,82 @@ class HyperParams:
|
||||
def delete_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamKey],
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "delete_hyperparams"):
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
)
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
)
|
||||
|
||||
with_param, without_param = iterutils.partition(
|
||||
hyperparams, key=lambda p: bool(p.name)
|
||||
)
|
||||
sections_to_delete = {p.section for p in without_param}
|
||||
delete_cmds = {
|
||||
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
|
||||
for section in sections_to_delete
|
||||
}
|
||||
with_param, without_param = iterutils.partition(
|
||||
hyperparams, key=lambda p: bool(p.name)
|
||||
)
|
||||
sections_to_delete = {p.section for p in without_param}
|
||||
delete_cmds = {
|
||||
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
|
||||
for section in sections_to_delete
|
||||
}
|
||||
|
||||
for item in with_param:
|
||||
section = ParameterKeyEscaper.escape(item.section)
|
||||
if item.section in sections_to_delete:
|
||||
raise errors.bad_request.FieldsConflict(
|
||||
"Cannot delete section field if the whole section was scheduled for deletion"
|
||||
)
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
for item in with_param:
|
||||
section = ParameterKeyEscaper.escape(item.section)
|
||||
if item.section in sections_to_delete:
|
||||
raise errors.bad_request.FieldsConflict(
|
||||
"Cannot delete section field if the whole section was scheduled for deletion"
|
||||
)
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
|
||||
return update_task(
|
||||
task, update_cmds=delete_cmds, set_last_update=not properties_only
|
||||
)
|
||||
return update_task(
|
||||
task,
|
||||
user_id=user_id,
|
||||
update_cmds=delete_cmds,
|
||||
set_last_update=not properties_only,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def edit_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamItem],
|
||||
replace_hyperparams: str,
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "edit_hyperparams"):
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
)
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
hyperparams = cls._db_dicts_from_list(hyperparams)
|
||||
if replace_hyperparams == ReplaceHyperparams.all:
|
||||
update_cmds["set__hyperparams"] = hyperparams
|
||||
elif replace_hyperparams == ReplaceHyperparams.section:
|
||||
for section, value in hyperparams.items():
|
||||
update_cmds = dict()
|
||||
hyperparams = cls._db_dicts_from_list(hyperparams)
|
||||
if replace_hyperparams == ReplaceHyperparams.all:
|
||||
update_cmds["set__hyperparams"] = hyperparams
|
||||
elif replace_hyperparams == ReplaceHyperparams.section:
|
||||
for section, value in hyperparams.items():
|
||||
update_cmds[f"set__hyperparams__{mongoengine_safe(section)}"] = value
|
||||
else:
|
||||
for section, section_params in hyperparams.items():
|
||||
for name, value in section_params.items():
|
||||
update_cmds[
|
||||
f"set__hyperparams__{mongoengine_safe(section)}"
|
||||
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
|
||||
] = value
|
||||
else:
|
||||
for section, section_params in hyperparams.items():
|
||||
for name, value in section_params.items():
|
||||
update_cmds[
|
||||
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
|
||||
] = value
|
||||
|
||||
return update_task(
|
||||
task, update_cmds=update_cmds, set_last_update=not properties_only
|
||||
)
|
||||
return update_task(
|
||||
task,
|
||||
user_id=user_id,
|
||||
update_cmds=update_cmds,
|
||||
set_last_update=not properties_only,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
|
||||
@@ -191,57 +194,56 @@ class HyperParams:
|
||||
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
|
||||
]
|
||||
|
||||
with TimingContext("mongo", "get_configuration_names"):
|
||||
tasks = Task.aggregate(pipeline)
|
||||
tasks = Task.aggregate(pipeline)
|
||||
|
||||
return {
|
||||
task["_id"]: {
|
||||
"names": sorted(
|
||||
ParameterKeyEscaper.unescape(name) for name in task["names"]
|
||||
)
|
||||
}
|
||||
for task in tasks
|
||||
return {
|
||||
task["_id"]: {
|
||||
"names": sorted(
|
||||
ParameterKeyEscaper.unescape(name) for name in task["names"]
|
||||
)
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_configuration(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
task_id: str,
|
||||
configuration: Sequence[Configuration],
|
||||
replace_configuration: bool,
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "edit_configuration"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force
|
||||
)
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
|
||||
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
|
||||
for c in configuration
|
||||
}
|
||||
if replace_configuration:
|
||||
update_cmds["set__configuration"] = configuration
|
||||
else:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
|
||||
for c in configuration
|
||||
}
|
||||
if replace_configuration:
|
||||
update_cmds["set__configuration"] = configuration
|
||||
else:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
|
||||
|
||||
return update_task(task, update_cmds=update_cmds)
|
||||
return update_task(task, user_id=user_id, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_configuration(
|
||||
cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
task_id: str,
|
||||
configuration: Sequence[str],
|
||||
force: bool,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "delete_configuration"):
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force
|
||||
)
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
|
||||
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
|
||||
return update_task(task, update_cmds=delete_cmds)
|
||||
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
|
||||
|
||||
@@ -39,7 +39,7 @@ class NonResponsiveTasksWatchdog:
|
||||
@threads.register("non_responsive_tasks_watchdog", daemon=True)
|
||||
def start(cls):
|
||||
sleep(cls.settings.watch_interval_sec)
|
||||
while not ThreadsManager.terminating:
|
||||
while True:
|
||||
watch_interval = cls.settings.watch_interval_sec
|
||||
if cls.settings.enabled:
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
|
||||
from typing import Collection, Sequence, Tuple, Optional, Dict
|
||||
|
||||
import six
|
||||
from mongoengine import Q
|
||||
@@ -33,9 +33,7 @@ from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
||||
from apiserver.timing_context import TimingContext
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import (
|
||||
@@ -66,11 +64,10 @@ class TaskBLL:
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
with TimingContext("mongo", "task_with_access"):
|
||||
if requires_write_access:
|
||||
task = Task.get_for_writing(_only=only, **query)
|
||||
else:
|
||||
task = Task.get(_only=only, **query, include_public=allow_public)
|
||||
if requires_write_access:
|
||||
task = Task.get_for_writing(_only=only, **query)
|
||||
else:
|
||||
task = Task.get(_only=only, **query, include_public=allow_public)
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
@@ -88,15 +85,14 @@ class TaskBLL:
|
||||
only_fields = list(only_fields)
|
||||
only_fields = only_fields + ["status"]
|
||||
|
||||
with TimingContext("mongo", "task_by_id_all"):
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id=task_id),
|
||||
allow_public=allow_public,
|
||||
override_projection=only_fields,
|
||||
return_dicts=False,
|
||||
)
|
||||
task = None if not tasks else tasks[0]
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id=task_id),
|
||||
allow_public=allow_public,
|
||||
override_projection=only_fields,
|
||||
return_dicts=False,
|
||||
)
|
||||
task = None if not tasks else tasks[0]
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
@@ -111,7 +107,7 @@ class TaskBLL:
|
||||
company_id, task_ids, only=None, allow_public=False, return_tasks=True
|
||||
) -> Optional[Sequence[Task]]:
|
||||
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
|
||||
with translate_errors_context(), TimingContext("mongo", "task_exists"):
|
||||
with translate_errors_context():
|
||||
ids = set(task_ids)
|
||||
q = Task.get_many(
|
||||
company=company_id,
|
||||
@@ -131,16 +127,16 @@ class TaskBLL:
|
||||
return list(q)
|
||||
|
||||
@staticmethod
|
||||
def create(call: APICall, fields: dict):
|
||||
identity = call.identity
|
||||
def create(company: str, user: str, fields: dict):
|
||||
now = datetime.utcnow()
|
||||
return Task(
|
||||
id=create_id(),
|
||||
user=identity.user,
|
||||
company=identity.company,
|
||||
user=user,
|
||||
company=company,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by=user,
|
||||
**fields,
|
||||
)
|
||||
|
||||
@@ -260,58 +256,66 @@ class TaskBLL:
|
||||
not in [TaskSystemTags.development, EntityVisibility.archived.value]
|
||||
]
|
||||
|
||||
with TimingContext("mongo", "clone task"):
|
||||
parent_task = (
|
||||
task.parent
|
||||
if task.parent and not task.parent.startswith(deleted_prefix)
|
||||
else None
|
||||
)
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or parent_task,
|
||||
project=project or task.project,
|
||||
tags=tags or task.tags,
|
||||
system_tags=system_tags or clean_system_tags(task.system_tags),
|
||||
type=task.type,
|
||||
script=task.script,
|
||||
output=Output(destination=task.output.destination)
|
||||
if task.output
|
||||
else None,
|
||||
models=Models(input=input_models or task.models.input),
|
||||
container=escape_dict(container) or task.container,
|
||||
execution=execution_dict,
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_models=validate_references or input_models,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
new_task.save()
|
||||
def ensure_int_labels(execution: dict) -> dict:
|
||||
if not execution:
|
||||
return execution
|
||||
|
||||
if task.project == new_task.project:
|
||||
updated_tags = tags
|
||||
updated_system_tags = system_tags
|
||||
else:
|
||||
updated_tags = new_task.tags
|
||||
updated_system_tags = new_task.system_tags
|
||||
org_bll.update_tags(
|
||||
company_id,
|
||||
Tags.Task,
|
||||
project=new_task.project,
|
||||
tags=updated_tags,
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
update_project_time(new_task.project)
|
||||
model_labels = execution.get("model_labels")
|
||||
if model_labels:
|
||||
execution["model_labels"] = {k: int(v) for k, v in model_labels.items()}
|
||||
|
||||
return execution
|
||||
|
||||
parent_task = (
|
||||
task.parent
|
||||
if task.parent and not task.parent.startswith(deleted_prefix)
|
||||
else task.id
|
||||
)
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by=user_id,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or parent_task,
|
||||
project=project or task.project,
|
||||
tags=tags or task.tags,
|
||||
system_tags=system_tags or clean_system_tags(task.system_tags),
|
||||
type=task.type,
|
||||
script=task.script,
|
||||
output=Output(destination=task.output.destination) if task.output else None,
|
||||
models=Models(input=input_models or task.models.input),
|
||||
container=escape_dict(container) or task.container,
|
||||
execution=ensure_int_labels(execution_dict),
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_models=validate_references or input_models,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
new_task.save()
|
||||
|
||||
if task.project == new_task.project:
|
||||
updated_tags = tags
|
||||
updated_system_tags = system_tags
|
||||
else:
|
||||
updated_tags = new_task.tags
|
||||
updated_system_tags = new_task.system_tags
|
||||
org_bll.update_tags(
|
||||
company_id,
|
||||
Tags.Task,
|
||||
project=new_task.project,
|
||||
tags=updated_tags,
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
update_project_time(new_task.project)
|
||||
|
||||
return new_task, new_project_data
|
||||
|
||||
@@ -381,7 +385,7 @@ class TaskBLL:
|
||||
last_update: datetime = None,
|
||||
last_iteration: int = None,
|
||||
last_iteration_max: int = None,
|
||||
last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
|
||||
last_scalar_events: Dict[str, Dict[str, dict]] = None,
|
||||
last_events: Dict[str, Dict[str, dict]] = None,
|
||||
**extra_updates,
|
||||
):
|
||||
@@ -406,25 +410,90 @@ class TaskBLL:
|
||||
elif last_iteration_max is not None:
|
||||
extra_updates.update(max__last_iteration=last_iteration_max)
|
||||
|
||||
if last_scalar_values is not None:
|
||||
raw_updates = {}
|
||||
if last_scalar_events is not None:
|
||||
max_values = config.get("services.tasks.max_last_metrics", 2000)
|
||||
total_metrics = set()
|
||||
if max_values:
|
||||
query = dict(id=task_id)
|
||||
to_add = sum(len(v) for m, v in last_scalar_events.items())
|
||||
if to_add <= max_values:
|
||||
query[f"unique_metrics__{max_values-to_add}__exists"] = True
|
||||
task = Task.objects(**query).only("unique_metrics").first()
|
||||
if task and task.unique_metrics:
|
||||
total_metrics = set(task.unique_metrics)
|
||||
|
||||
def op_path(op, *path):
|
||||
return "__".join((op, "last_metrics") + path)
|
||||
new_metrics = []
|
||||
|
||||
for path, value in last_scalar_values:
|
||||
if path[-1] == "min_value":
|
||||
extra_updates[op_path("min", *path[:-1], "min_value")] = value
|
||||
elif path[-1] == "max_value":
|
||||
extra_updates[op_path("max", *path[:-1], "max_value")] = value
|
||||
def add_last_metric_conditional_update(
|
||||
metric_path: str, metric_value, iter_value: int, is_min: bool
|
||||
):
|
||||
"""
|
||||
Build an aggregation for an atomic update of the min or max value and the corresponding iteration
|
||||
"""
|
||||
if is_min:
|
||||
field_prefix = "min"
|
||||
op = "$gt"
|
||||
else:
|
||||
extra_updates[op_path("set", *path)] = value
|
||||
field_prefix = "max"
|
||||
op = "$lt"
|
||||
|
||||
value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
|
||||
condition = {
|
||||
"$or": [
|
||||
{"$lte": [f"${value_field}", None]},
|
||||
{op: [f"${value_field}", metric_value]},
|
||||
]
|
||||
}
|
||||
raw_updates[value_field] = {
|
||||
"$cond": [condition, metric_value, f"${value_field}"]
|
||||
}
|
||||
|
||||
value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
|
||||
"__", "."
|
||||
)
|
||||
raw_updates[value_iteration_field] = {
|
||||
"$cond": [
|
||||
condition,
|
||||
iter_value,
|
||||
f"${value_iteration_field}",
|
||||
]
|
||||
}
|
||||
|
||||
for metric_key, metric_data in last_scalar_events.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
metric = (
|
||||
f"{variant_data.get('metric')}/{variant_data.get('variant')}"
|
||||
)
|
||||
if max_values:
|
||||
if (
|
||||
len(total_metrics) >= max_values
|
||||
and metric not in total_metrics
|
||||
):
|
||||
continue
|
||||
total_metrics.add(metric)
|
||||
|
||||
new_metrics.append(metric)
|
||||
path = f"last_metrics__{metric_key}__{variant_key}"
|
||||
for key, value in variant_data.items():
|
||||
if key in ("min_value", "max_value"):
|
||||
add_last_metric_conditional_update(
|
||||
metric_path=path,
|
||||
metric_value=value,
|
||||
iter_value=variant_data.get(f"{key}_iter", 0),
|
||||
is_min=(key == "min_value"),
|
||||
)
|
||||
elif key in ("metric", "variant", "value"):
|
||||
extra_updates[f"set__{path}__{key}"] = value
|
||||
if new_metrics:
|
||||
extra_updates["add_to_set__unique_metrics"] = new_metrics
|
||||
|
||||
if last_events is not None:
|
||||
|
||||
def events_per_type(metric_data: Dict[str, dict]) -> Dict[str, EventStats]:
|
||||
def events_per_type(metric_data_: Dict[str, dict]) -> Dict[str, EventStats]:
|
||||
return {
|
||||
event_type: EventStats(last_update=event["timestamp"])
|
||||
for event_type, event in metric_data.items()
|
||||
for event_type, event in metric_data_.items()
|
||||
}
|
||||
|
||||
metric_stats = {
|
||||
@@ -435,24 +504,38 @@ class TaskBLL:
|
||||
}
|
||||
extra_updates["metric_stats"] = metric_stats
|
||||
|
||||
return TaskBLL.set_last_update(
|
||||
ret = TaskBLL.set_last_update(
|
||||
task_ids=[task_id],
|
||||
company_id=company_id,
|
||||
last_update=last_update,
|
||||
**extra_updates,
|
||||
)
|
||||
if ret and raw_updates:
|
||||
Task.objects(id=task_id).update_one(__raw__=[{"$set": raw_updates}])
|
||||
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def dequeue_and_change_status(
|
||||
cls, task: Task, company_id: str, status_message: str, status_reason: str,
|
||||
cls,
|
||||
task: Task,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
):
|
||||
cls.dequeue(task, company_id)
|
||||
try:
|
||||
cls.dequeue(task, company_id)
|
||||
except errors.bad_request.InvalidQueueOrTaskNotQueued:
|
||||
# dequeue may fail if the queue was deleted
|
||||
pass
|
||||
|
||||
return ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=task.enqueue_status or TaskStatus.created,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
user_id=user_id,
|
||||
).execute(enqueue_status=None)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,19 +1,32 @@
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import partition
|
||||
from boltons.iterutils import partition, bucketize, first
|
||||
from furl import furl
|
||||
from mongoengine import NotUniqueError
|
||||
from pymongo.errors import DuplicateKeyError
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_bll import PlotFields
|
||||
from apiserver.bll.task.utils import deleted_prefix
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.database.model.url_to_delete import (
|
||||
StorageType,
|
||||
UrlToDelete,
|
||||
FileType,
|
||||
DeletionStatus,
|
||||
)
|
||||
from apiserver.database.utils import id as db_id
|
||||
|
||||
log = config.logger(__file__)
|
||||
event_bll = EventBLL()
|
||||
async_events_delete = config.get("services.tasks.async_events_delete", False)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
@@ -56,25 +69,24 @@ class CleanupResult:
|
||||
)
|
||||
|
||||
|
||||
def collect_plot_image_urls(company: str, task: str) -> Set[str]:
|
||||
def collect_plot_image_urls(company: str, task_or_model: str) -> Set[str]:
|
||||
urls = set()
|
||||
next_scroll_id = None
|
||||
with TimingContext("es", "collect_plot_image_urls"):
|
||||
while True:
|
||||
events, next_scroll_id = event_bll.get_plot_image_urls(
|
||||
company_id=company, task_id=task, scroll_id=next_scroll_id
|
||||
)
|
||||
if not events:
|
||||
break
|
||||
for event in events:
|
||||
event_urls = event.get(PlotFields.source_urls)
|
||||
if event_urls:
|
||||
urls.update(set(event_urls))
|
||||
while True:
|
||||
events, next_scroll_id = event_bll.get_plot_image_urls(
|
||||
company_id=company, task_id=task_or_model, scroll_id=next_scroll_id
|
||||
)
|
||||
if not events:
|
||||
break
|
||||
for event in events:
|
||||
event_urls = event.get(PlotFields.source_urls)
|
||||
if event_urls:
|
||||
urls.update(set(event_urls))
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def collect_debug_image_urls(company: str, task: str) -> Set[str]:
|
||||
def collect_debug_image_urls(company: str, task_or_model: str) -> Set[str]:
|
||||
"""
|
||||
Return the set of unique image urls
|
||||
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
|
||||
@@ -83,9 +95,7 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
|
||||
urls = set()
|
||||
while True:
|
||||
res, after_key = event_bll.get_debug_image_urls(
|
||||
company_id=company,
|
||||
task_id=task,
|
||||
after_key=after_key,
|
||||
company_id=company, task_id=task_or_model, after_key=after_key,
|
||||
)
|
||||
urls.update(res)
|
||||
if not after_key:
|
||||
@@ -94,12 +104,87 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
|
||||
return urls
|
||||
|
||||
|
||||
supported_storage_types = {
|
||||
"https://": StorageType.fileserver,
|
||||
"http://": StorageType.fileserver,
|
||||
"s3://": StorageType.s3,
|
||||
"azure://": StorageType.azure,
|
||||
"gs://": StorageType.gs,
|
||||
}
|
||||
|
||||
|
||||
def _schedule_for_delete(
|
||||
company: str, user: str, task_id: str, urls: Set[str], can_delete_folders: bool,
|
||||
) -> Set[str]:
|
||||
urls_per_storage = bucketize(
|
||||
urls,
|
||||
key=lambda u: first(
|
||||
type_
|
||||
for prefix, type_ in supported_storage_types.items()
|
||||
if u.startswith(prefix)
|
||||
),
|
||||
)
|
||||
urls_per_storage.pop(None, None)
|
||||
|
||||
processed_urls = set()
|
||||
for storage_type, storage_urls in urls_per_storage.items():
|
||||
delete_folders = (storage_type == StorageType.fileserver) and can_delete_folders
|
||||
scheduled_to_delete = set()
|
||||
for url in storage_urls:
|
||||
folder = None
|
||||
if delete_folders:
|
||||
try:
|
||||
parsed = furl(url)
|
||||
if parsed.path and len(parsed.path.segments) > 1:
|
||||
folder = parsed.remove(
|
||||
args=True, fragment=True, path=parsed.path.segments[-1]
|
||||
).url.rstrip("/")
|
||||
except Exception as ex:
|
||||
pass
|
||||
|
||||
to_delete = folder or url
|
||||
if to_delete in scheduled_to_delete:
|
||||
processed_urls.add(url)
|
||||
continue
|
||||
|
||||
try:
|
||||
UrlToDelete(
|
||||
id=db_id(),
|
||||
company=company,
|
||||
user=user,
|
||||
url=to_delete,
|
||||
task=task_id,
|
||||
created=datetime.utcnow(),
|
||||
storage_type=storage_type,
|
||||
type=FileType.folder if folder else FileType.file,
|
||||
).save()
|
||||
except (DuplicateKeyError, NotUniqueError):
|
||||
existing = UrlToDelete.objects(company=company, url=to_delete).first()
|
||||
if existing:
|
||||
existing.update(
|
||||
user=user,
|
||||
task=task_id,
|
||||
created=datetime.utcnow(),
|
||||
retry_count=0,
|
||||
unset__last_failure_time=1,
|
||||
unset__last_failure_reason=1,
|
||||
status=DeletionStatus.created,
|
||||
)
|
||||
processed_urls.add(url)
|
||||
scheduled_to_delete.add(to_delete)
|
||||
|
||||
return processed_urls
|
||||
|
||||
|
||||
def cleanup_task(
|
||||
company: str,
|
||||
user: str,
|
||||
task: Task,
|
||||
force: bool = False,
|
||||
update_children=True,
|
||||
return_file_urls=False,
|
||||
delete_output_models=True,
|
||||
delete_external_artifacts=True,
|
||||
) -> CleanupResult:
|
||||
"""
|
||||
Validate task deletion and delete/modify all its output.
|
||||
@@ -110,9 +195,11 @@ def cleanup_task(
|
||||
published_models, draft_models, in_use_model_ids = verify_task_children_and_ouptuts(
|
||||
task, force
|
||||
)
|
||||
|
||||
delete_external_artifacts = delete_external_artifacts and config.get(
|
||||
"services.async_urls_delete.enabled", False
|
||||
)
|
||||
event_urls, artifact_urls, model_urls = set(), set(), set()
|
||||
if return_file_urls:
|
||||
if return_file_urls or delete_external_artifacts:
|
||||
event_urls = collect_debug_image_urls(task.company, task.id)
|
||||
event_urls.update(collect_plot_image_urls(task.company, task.id))
|
||||
if task.execution and task.execution.artifacts:
|
||||
@@ -128,10 +215,7 @@ def cleanup_task(
|
||||
deleted_task_id = f"{deleted_prefix}{task.id}"
|
||||
updated_children = 0
|
||||
if update_children:
|
||||
with TimingContext("mongo", "update_task_children"):
|
||||
updated_children = Task.objects(parent=task.id).update(
|
||||
parent=deleted_task_id
|
||||
)
|
||||
updated_children = Task.objects(parent=task.id).update(parent=deleted_task_id)
|
||||
|
||||
deleted_models = 0
|
||||
updated_models = 0
|
||||
@@ -139,9 +223,23 @@ def cleanup_task(
|
||||
if not models:
|
||||
continue
|
||||
if delete_output_models and allow_delete:
|
||||
deleted_models += Model.objects(
|
||||
id__in=[m.id for m in models if m.id not in in_use_model_ids]
|
||||
).delete()
|
||||
model_ids = set(m.id for m in models if m.id not in in_use_model_ids)
|
||||
for m_id in model_ids:
|
||||
if return_file_urls or delete_external_artifacts:
|
||||
event_urls.update(collect_debug_image_urls(task.company, m_id))
|
||||
event_urls.update(collect_plot_image_urls(task.company, m_id))
|
||||
try:
|
||||
event_bll.delete_task_events(
|
||||
task.company,
|
||||
m_id,
|
||||
allow_locked=True,
|
||||
model=True,
|
||||
async_delete=async_events_delete,
|
||||
)
|
||||
except errors.bad_request.InvalidModelId as ex:
|
||||
log.info(f"Error deleting events for the model {m_id}: {str(ex)}")
|
||||
|
||||
deleted_models += Model.objects(id__in=list(model_ids)).delete()
|
||||
if in_use_model_ids:
|
||||
Model.objects(id__in=list(in_use_model_ids)).update(unset__task=1)
|
||||
continue
|
||||
@@ -153,7 +251,20 @@ def cleanup_task(
|
||||
else:
|
||||
Model.objects(id__in=[m.id for m in models]).update(unset__task=1)
|
||||
|
||||
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
|
||||
event_bll.delete_task_events(
|
||||
task.company, task.id, allow_locked=force, async_delete=async_events_delete
|
||||
)
|
||||
|
||||
if delete_external_artifacts:
|
||||
scheduled = _schedule_for_delete(
|
||||
task_id=task.id,
|
||||
company=company,
|
||||
user=user,
|
||||
urls=event_urls | model_urls | artifact_urls,
|
||||
can_delete_folders=not in_use_model_ids and not published_models,
|
||||
)
|
||||
for urls in (event_urls, model_urls, artifact_urls):
|
||||
urls.difference_update(scheduled)
|
||||
|
||||
return CleanupResult(
|
||||
deleted_models=deleted_models,
|
||||
@@ -173,16 +284,15 @@ def verify_task_children_and_ouptuts(
|
||||
task, force: bool
|
||||
) -> Tuple[Sequence[Model], Sequence[Model], Set[str]]:
|
||||
if not force:
|
||||
with TimingContext("mongo", "count_published_children"):
|
||||
published_children_count = Task.objects(
|
||||
parent=task.id, status=TaskStatus.published
|
||||
).count()
|
||||
if published_children_count:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has children, use force=True",
|
||||
task=task.id,
|
||||
children=published_children_count,
|
||||
)
|
||||
published_children_count = Task.objects(
|
||||
parent=task.id, status=TaskStatus.published
|
||||
).count()
|
||||
if published_children_count:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has children, use force=True",
|
||||
task=task.id,
|
||||
children=published_children_count,
|
||||
)
|
||||
|
||||
model_fields = ["id", "ready", "uri"]
|
||||
published_models, draft_models = partition(
|
||||
|
||||
@@ -30,7 +30,11 @@ queue_bll = QueueBLL()
|
||||
|
||||
|
||||
def archive_task(
|
||||
task: Union[str, Task], company_id: str, status_message: str, status_reason: str,
|
||||
task: Union[str, Task],
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
) -> int:
|
||||
"""
|
||||
Deque and archive task
|
||||
@@ -52,7 +56,11 @@ def archive_task(
|
||||
)
|
||||
try:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task, company_id, status_message, status_reason,
|
||||
task,
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
@@ -63,11 +71,12 @@ def archive_task(
|
||||
status_reason=status_reason,
|
||||
add_to_set__system_tags=EntityVisibility.archived.value,
|
||||
last_change=datetime.utcnow(),
|
||||
last_changed_by=user_id,
|
||||
)
|
||||
|
||||
|
||||
def unarchive_task(
|
||||
task: str, company_id: str, status_message: str, status_reason: str,
|
||||
task: str, company_id: str, user_id: str, status_message: str, status_reason: str,
|
||||
) -> int:
|
||||
"""
|
||||
Unarchive task. Return 1 if successful
|
||||
@@ -80,11 +89,16 @@ def unarchive_task(
|
||||
status_reason=status_reason,
|
||||
pull__system_tags=EntityVisibility.archived.value,
|
||||
last_change=datetime.utcnow(),
|
||||
last_changed_by=user_id,
|
||||
)
|
||||
|
||||
|
||||
def dequeue_task(
|
||||
task_id: str, company_id: str, status_message: str, status_reason: str,
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
) -> Tuple[int, dict]:
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(**query)
|
||||
@@ -92,7 +106,11 @@ def dequeue_task(
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
res = TaskBLL.dequeue_and_change_status(
|
||||
task, company_id, status_message=status_message, status_reason=status_reason,
|
||||
task,
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
)
|
||||
return 1, res
|
||||
|
||||
@@ -100,6 +118,7 @@ def dequeue_task(
|
||||
def enqueue_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
queue_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
@@ -139,6 +158,7 @@ def enqueue_task(
|
||||
status_message=status_message,
|
||||
allow_same_state_transition=False,
|
||||
force=force,
|
||||
user_id=user_id,
|
||||
).execute(enqueue_status=task.status)
|
||||
|
||||
try:
|
||||
@@ -151,6 +171,7 @@ def enqueue_task(
|
||||
new_status=task.status,
|
||||
force=True,
|
||||
status_reason="failed enqueueing",
|
||||
user_id=user_id,
|
||||
).execute(enqueue_status=None)
|
||||
raise
|
||||
|
||||
@@ -191,12 +212,14 @@ def move_tasks_to_trash(tasks: Sequence[str]) -> int:
|
||||
def delete_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
move_to_trash: bool,
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
delete_output_models: bool,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[int, Task, CleanupResult]:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
@@ -218,6 +241,7 @@ def delete_task(
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
)
|
||||
@@ -226,10 +250,13 @@ def delete_task(
|
||||
pass
|
||||
|
||||
cleanup_res = cleanup_task(
|
||||
task,
|
||||
company=company_id,
|
||||
user=user_id,
|
||||
task=task,
|
||||
force=force,
|
||||
return_file_urls=return_file_urls,
|
||||
delete_output_models=delete_output_models,
|
||||
delete_external_artifacts=delete_external_artifacts,
|
||||
)
|
||||
|
||||
if move_to_trash:
|
||||
@@ -246,10 +273,12 @@ def delete_task(
|
||||
def reset_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
delete_output_models: bool,
|
||||
clear_all: bool,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[dict, CleanupResult, dict]:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
@@ -268,16 +297,20 @@ def reset_task(
|
||||
pass
|
||||
|
||||
cleaned_up = cleanup_task(
|
||||
task,
|
||||
company=company_id,
|
||||
user=user_id,
|
||||
task=task,
|
||||
force=force,
|
||||
update_children=False,
|
||||
return_file_urls=return_file_urls,
|
||||
delete_output_models=delete_output_models,
|
||||
delete_external_artifacts=delete_external_artifacts,
|
||||
)
|
||||
|
||||
updates.update(
|
||||
set__last_iteration=DEFAULT_LAST_ITERATION,
|
||||
set__last_metrics={},
|
||||
set__unique_metrics=[],
|
||||
set__metric_stats={},
|
||||
set__models__output=[],
|
||||
set__runtime={},
|
||||
@@ -308,6 +341,7 @@ def reset_task(
|
||||
force=force,
|
||||
status_reason="reset",
|
||||
status_message="reset",
|
||||
user_id=user_id,
|
||||
).execute(
|
||||
started=None,
|
||||
completed=None,
|
||||
@@ -323,8 +357,9 @@ def reset_task(
|
||||
def publish_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
force: bool,
|
||||
publish_model_func: Callable[[str, str], Any] = None,
|
||||
publish_model_func: Callable[[str, str, str], Any] = None,
|
||||
status_message: str = "",
|
||||
status_reason: str = "",
|
||||
) -> dict:
|
||||
@@ -352,7 +387,7 @@ def publish_task(
|
||||
.first()
|
||||
)
|
||||
if model and not model.ready:
|
||||
publish_model_func(model.id, company_id)
|
||||
publish_model_func(model.id, company_id, user_id)
|
||||
|
||||
# set task status to published, and update (or set) its new output (view and models)
|
||||
return ChangeStatusRequest(
|
||||
@@ -361,6 +396,7 @@ def publish_task(
|
||||
force=force,
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
user_id=user_id,
|
||||
).execute(published=datetime.utcnow(), output=output)
|
||||
|
||||
except Exception as ex:
|
||||
@@ -373,7 +409,12 @@ def publish_task(
|
||||
|
||||
|
||||
def stop_task(
|
||||
task_id: str, company_id: str, user_name: str, status_reason: str, force: bool,
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
status_reason: str,
|
||||
force: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Stop a running task. Requires task status 'in_progress' and
|
||||
@@ -435,4 +476,5 @@ def stop_task(
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
force=force,
|
||||
user_id=user_id,
|
||||
).execute()
|
||||
|
||||
@@ -9,7 +9,6 @@ from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.attrs import typed_attrs
|
||||
|
||||
valid_statuses = get_options(TaskStatus)
|
||||
@@ -27,6 +26,7 @@ class ChangeStatusRequest(object):
|
||||
force = attr.ib(type=bool, default=False)
|
||||
allow_same_state_transition = attr.ib(type=bool, default=True)
|
||||
current_status_override = attr.ib(default=None)
|
||||
user_id = attr.ib(type=str, default=None)
|
||||
|
||||
def execute(self, **kwargs):
|
||||
current_status = self.current_status_override or self.task.status
|
||||
@@ -45,6 +45,7 @@ class ChangeStatusRequest(object):
|
||||
status_changed=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by=self.user_id,
|
||||
)
|
||||
|
||||
if self.new_status == TaskStatus.queued:
|
||||
@@ -55,7 +56,7 @@ class ChangeStatusRequest(object):
|
||||
|
||||
fields.update({safe_mongoengine_key(k): v for k, v in kwargs.items()})
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "task_status"):
|
||||
with translate_errors_context():
|
||||
# atomic change of task status by querying the task with the EXPECTED status before modifying it
|
||||
params = fields.copy()
|
||||
params.update(control)
|
||||
@@ -166,7 +167,7 @@ def update_project_time(project_ids: Union[str, Sequence[str]]):
|
||||
|
||||
|
||||
def get_task_for_update(
|
||||
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
|
||||
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
|
||||
) -> Task:
|
||||
"""
|
||||
Loads only task id and return the task only if it is updatable (status == 'created')
|
||||
@@ -188,9 +189,9 @@ def get_task_for_update(
|
||||
return task
|
||||
|
||||
|
||||
def update_task(task: Task, update_cmds: dict, set_last_update: bool = True):
|
||||
def update_task(task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True):
|
||||
now = datetime.utcnow()
|
||||
last_updates = dict(last_change=now)
|
||||
last_updates = dict(last_change=now, last_changed_by=user_id)
|
||||
if set_last_update:
|
||||
last_updates.update(last_update=now)
|
||||
return task.update(**update_cmds, **last_updates)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from datetime import datetime
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.users import CreateRequest
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
@@ -12,7 +14,7 @@ class UserBLL:
|
||||
if user_id and User.objects(id=user_id).only("id"):
|
||||
raise errors.bad_request.UserIdExists(id=user_id)
|
||||
|
||||
user = User(**request.to_struct())
|
||||
user = User(**request.to_struct(), created=datetime.utcnow())
|
||||
user.save(force_insert=True)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import itertools
|
||||
from datetime import datetime, timedelta
|
||||
from time import time
|
||||
from typing import Sequence, Set, Optional
|
||||
|
||||
import attr
|
||||
import elasticsearch.helpers
|
||||
from boltons.iterutils import partition
|
||||
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import APIError
|
||||
@@ -25,7 +27,6 @@ from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
from .stats import WorkerStats
|
||||
|
||||
@@ -51,6 +52,7 @@ class WorkerBLL:
|
||||
queues: Sequence[str] = None,
|
||||
timeout: int = 0,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> WorkerEntry:
|
||||
"""
|
||||
Register a worker
|
||||
@@ -95,9 +97,10 @@ class WorkerBLL:
|
||||
register_timeout=timeout,
|
||||
last_activity_time=now,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
)
|
||||
|
||||
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
|
||||
self._save_worker_data(entry)
|
||||
|
||||
return entry
|
||||
|
||||
@@ -109,15 +112,20 @@ class WorkerBLL:
|
||||
:param worker: worker ID
|
||||
:raise bad_request.WorkerNotRegistered: the worker was not previously registered
|
||||
"""
|
||||
with TimingContext("redis", "workers_unregister"):
|
||||
res = self.redis.delete(
|
||||
company_id, self._get_worker_key(company_id, user_id, worker)
|
||||
)
|
||||
res = self.redis.delete(
|
||||
company_id, self._get_worker_key(company_id, user_id, worker)
|
||||
)
|
||||
if not res and not config.get("apiserver.workers.auto_unregister", False):
|
||||
raise bad_request.WorkerNotRegistered(worker=worker)
|
||||
|
||||
def status_report(
|
||||
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
|
||||
self,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
ip: str,
|
||||
report: StatusReportRequest,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Write worker status report
|
||||
@@ -138,6 +146,8 @@ class WorkerBLL:
|
||||
|
||||
if tags is not None:
|
||||
entry.tags = tags
|
||||
if system_tags is not None:
|
||||
entry.system_tags = system_tags
|
||||
|
||||
if report.machine_stats:
|
||||
self._log_stats_to_es(
|
||||
@@ -176,7 +186,9 @@ class WorkerBLL:
|
||||
if task.project:
|
||||
project = Project.objects(id=task.project).only("name").first()
|
||||
if project:
|
||||
entry.project = IdNameEntry(id=project.id, name=project.name)
|
||||
entry.project = IdNameEntry(
|
||||
id=project.id, name=project.name
|
||||
)
|
||||
|
||||
entry.last_report_time = now
|
||||
except APIError:
|
||||
@@ -193,6 +205,7 @@ class WorkerBLL:
|
||||
company_id: str,
|
||||
last_seen: Optional[int] = None,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerEntry]:
|
||||
"""
|
||||
Get all the company workers that were active during the last_seen period
|
||||
@@ -201,7 +214,7 @@ class WorkerBLL:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
workers = self._get(company_id)
|
||||
workers = self._get(company_id, user_tags=tags, system_tags=system_tags)
|
||||
except Exception as e:
|
||||
raise server_error.DataError("failed loading worker entries", err=e.args[0])
|
||||
|
||||
@@ -213,26 +226,25 @@ class WorkerBLL:
|
||||
if w.last_activity_time.replace(tzinfo=None) >= ref_time
|
||||
]
|
||||
|
||||
if tags:
|
||||
include = {t for t in tags if not t.startswith("-")}
|
||||
exclude = {t[1:] for t in tags if t.startswith("-")}
|
||||
workers = [
|
||||
w
|
||||
for w in workers
|
||||
if (not include or any(t in include for t in w.tags))
|
||||
and (not exclude or all(t not in exclude for t in w.tags))
|
||||
]
|
||||
|
||||
return workers
|
||||
|
||||
def get_all_with_projection(
|
||||
self, company_id: str, last_seen: int, tags: Sequence[str] = None
|
||||
self,
|
||||
company_id: str,
|
||||
last_seen: int,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerResponseEntry]:
|
||||
|
||||
helpers = list(
|
||||
map(
|
||||
WorkerConversionHelper.from_worker_entry,
|
||||
self.get_all(company_id=company_id, last_seen=last_seen, tags=tags),
|
||||
self.get_all(
|
||||
company_id=company_id,
|
||||
last_seen=last_seen,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -323,8 +335,7 @@ class WorkerBLL:
|
||||
"""
|
||||
key = self._get_worker_key(company_id, user_id, worker)
|
||||
|
||||
with TimingContext("redis", "get_worker"):
|
||||
data = self.redis.get(key)
|
||||
data = self.redis.get(key)
|
||||
|
||||
if data:
|
||||
try:
|
||||
@@ -351,27 +362,117 @@ class WorkerBLL:
|
||||
|
||||
raise bad_request.InvalidWorkerId(worker=worker)
|
||||
|
||||
@staticmethod
def _get_tagged_workers_key(company: str, tags_field: str, tag: str) -> str:
    """
    Build the redis key for the workers of a company that carry the given tag

    :param company: company id
    :param tags_field: name of the tags field that the tag belongs to, embedded as is into the key
    :param tag: the tag value
    :return: key in the form workers.<tags_field>_<company>_<tag>
    """
    return f"workers.{tags_field}_{company}_{tag}"
|
||||
|
||||
@staticmethod
def _get_all_workers_key(company: str) -> str:
    """Build the redis key that holds all the workers of the given company"""
    return f"workers_{company}"
|
||||
|
||||
def _save_worker_data(self, entry: WorkerEntry):
    """
    Write the worker entry to redis and register it in the workers sorted sets.

    The serialized entry is stored under entry.key and expires after
    entry.register_timeout seconds. The same key, scored by its expiration
    timestamp, is then added to the sorted set of all the company workers and
    to the sorted set of every tag and system tag found on the entry.
    """
    ttl = timedelta(seconds=entry.register_timeout)
    self.redis.setex(entry.key, ttl, entry.to_json())

    company = entry.company.id
    # the score is the unix time at which the worker registration expires
    member = {entry.key: int(time()) + entry.register_timeout}
    self.redis.zadd(self._get_all_workers_key(company), member)

    tagged = (("tags", entry.tags), ("systemtags", entry.system_tags))
    for field_name, field_tags in tagged:
        for tag in field_tags:
            self.redis.zadd(
                self._get_tagged_workers_key(company, field_name, tag), member
            )
|
||||
|
||||
def _save_worker(self, entry: WorkerEntry) -> None:
|
||||
"""Save worker entry in Redis"""
|
||||
try:
|
||||
self.redis.setex(
|
||||
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
|
||||
)
|
||||
self._save_worker_data(entry)
|
||||
except Exception:
|
||||
msg = "Failed saving worker entry"
|
||||
log.exception(msg)
|
||||
|
||||
def _get(
|
||||
self, company: str, user: str = "*", worker_id: str = "*"
|
||||
self,
|
||||
company: str,
|
||||
user: str = "*",
|
||||
worker_id: str = "*",
|
||||
user_tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerEntry]:
|
||||
"""Get worker entries matching the company and user, worker patterns"""
|
||||
|
||||
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
|
||||
if user == "*":
|
||||
return in_keys
|
||||
user_bytes = user.encode()
|
||||
return {k for k in in_keys if user_bytes in k}
|
||||
|
||||
if user_tags or system_tags:
|
||||
worker_keys = set()
|
||||
for tags, tags_field in (
|
||||
(user_tags, "tags"),
|
||||
(system_tags, "systemtags"),
|
||||
):
|
||||
if not tags:
|
||||
continue
|
||||
timestamp = int(time())
|
||||
include, exclude = partition(tags, key=lambda x: x[0] != "-")
|
||||
if include:
|
||||
tagged_workers = set()
|
||||
for tag in include:
|
||||
tagged_workers_key = self._get_tagged_workers_key(
|
||||
company, tags_field, tag
|
||||
)
|
||||
self.redis.zremrangebyscore(
|
||||
tagged_workers_key, min=0, max=timestamp
|
||||
)
|
||||
tagged_workers.update(
|
||||
self.redis.zrange(tagged_workers_key, 0, -1)
|
||||
)
|
||||
tagged_workers = filter_by_user(tagged_workers)
|
||||
worker_keys = (
|
||||
worker_keys.intersection(tagged_workers)
|
||||
if worker_keys
|
||||
else tagged_workers
|
||||
)
|
||||
if not worker_keys:
|
||||
return []
|
||||
if exclude:
|
||||
if not worker_keys:
|
||||
all_workers_key = self._get_all_workers_key(company)
|
||||
self.redis.zremrangebyscore(
|
||||
all_workers_key, min=0, max=timestamp
|
||||
)
|
||||
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
|
||||
worker_keys = filter_by_user(worker_keys)
|
||||
if not worker_keys:
|
||||
return []
|
||||
for tag in exclude:
|
||||
tagged_workers_key = self._get_tagged_workers_key(
|
||||
company, tags_field, tag[1:]
|
||||
)
|
||||
self.redis.zremrangebyscore(
|
||||
tagged_workers_key, min=0, max=timestamp
|
||||
)
|
||||
worker_keys.difference_update(
|
||||
self.redis.zrange(tagged_workers_key, 0, -1)
|
||||
)
|
||||
if not worker_keys:
|
||||
return []
|
||||
else:
|
||||
match = self._get_worker_key(company, user, "*")
|
||||
worker_keys = self.redis.scan_iter(match)
|
||||
|
||||
entries = []
|
||||
match = self._get_worker_key(company, user, worker_id)
|
||||
with TimingContext("redis", "workers_get_all"):
|
||||
for r in self.redis.scan_iter(match):
|
||||
data = self.redis.get(r)
|
||||
if data:
|
||||
entries.append(WorkerEntry.from_json(data))
|
||||
for key in worker_keys:
|
||||
data = self.redis.get(key)
|
||||
if data:
|
||||
entries.append(WorkerEntry.from_json(data))
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from apiserver.apimodels.workers import AggregationType, GetStatsRequest, StatIt
|
||||
from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -126,7 +125,7 @@ class WorkerStats:
|
||||
query_terms.append(QueryBuilder.terms("worker", request.worker_ids))
|
||||
es_req["query"] = {"bool": {"must": query_terms}}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_worker_stats"):
|
||||
with translate_errors_context():
|
||||
data = self._search_company_stats(company_id, es_req)
|
||||
|
||||
return self._extract_results(data, request.items, request.split_by_variant)
|
||||
@@ -223,9 +222,7 @@ class WorkerStats:
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_worker_activity_report"
|
||||
):
|
||||
with translate_errors_context():
|
||||
data = self._search_company_stats(company_id, es_req)
|
||||
|
||||
if "aggregations" not in data:
|
||||
|
||||
@@ -41,10 +41,6 @@
|
||||
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
|
||||
# but not declared in a data model
|
||||
strict: false
|
||||
|
||||
aggregate {
|
||||
allow_disk_use: true
|
||||
}
|
||||
}
|
||||
|
||||
elastic {
|
||||
@@ -117,6 +113,10 @@
|
||||
# Timeout in seconds on task status update. If exceeded
|
||||
# then task can be stopped without communicating to the worker
|
||||
task_update_timeout: 600
|
||||
|
||||
# Timeout in seconds for worker registration (or status report). If a worker did not report for this long,
|
||||
# it is discarded from the server's table
|
||||
default_timeout: 600
|
||||
}
|
||||
|
||||
check_for_updates {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
fileserver = "http://localhost:8081"
|
||||
|
||||
elastic {
|
||||
events {
|
||||
hosts: [{host: "127.0.0.1", port: 9200}]
|
||||
|
||||
@@ -2,3 +2,8 @@ max_page_size: 500
|
||||
|
||||
# expiration time in seconds for the redis scroll states in get_many family of apis
|
||||
scroll_state_expiration_seconds: 600
|
||||
|
||||
allow_disk_use {
|
||||
sort: true
|
||||
aggregate: true
|
||||
}
|
||||
12
apiserver/config/default/services/async_urls_delete.conf
Normal file
12
apiserver/config/default/services/async_urls_delete.conf
Normal file
@@ -0,0 +1,12 @@
|
||||
# if set to True then on task delete/reset external file urls for known storage types are scheduled for async delete
|
||||
# otherwise they are returned to a client for the client side delete
|
||||
enabled: true
|
||||
max_retries: 3
|
||||
retry_timeout_sec: 60
|
||||
|
||||
fileserver {
|
||||
# fileserver url prefixes. Evaluated in the order of priority
|
||||
# Can be in the form <schema>://host:port/path or /path
|
||||
url_prefixes: ["https://files.community-master.hosted.allegro.ai/"]
|
||||
timeout_sec: 300
|
||||
}
|
||||
@@ -39,4 +39,7 @@ events_retrieval {
|
||||
validate_plot_str: false
|
||||
|
||||
# If not 0 then plots whose size is equal to or greater than this threshold will be stored compressed in the DB
|
||||
plot_compression_threshold: 100000
|
||||
plot_compression_threshold: 100000
|
||||
|
||||
# async events delete threshold
|
||||
max_async_deleted_events_per_sec: 1000
|
||||
@@ -2,4 +2,7 @@
|
||||
metrics_before_from_date: 3600
|
||||
# interval in seconds to update queue metrics. Put 0 to disable
|
||||
metrics_refresh_interval_sec: 300
|
||||
# the queues with these tags will not be returned from get_all/get_all_ex unless id or name specified
|
||||
# or search_hidden is set
|
||||
hidden_tags: [k8s-glue]
|
||||
}
|
||||
53
apiserver/config/default/services/storage_credentials.conf
Normal file
53
apiserver/config/default/services/storage_credentials.conf
Normal file
@@ -0,0 +1,53 @@
|
||||
aws {
|
||||
s3 {
|
||||
# S3 credentials, used for read/write access by various SDK elements
|
||||
# default, used for any bucket not specified below
|
||||
key: ""
|
||||
secret: ""
|
||||
region: ""
|
||||
use_credentials_chain: false
|
||||
# Additional ExtraArgs passed to boto3 when uploading files. Can also be set per-bucket under "credentials".
|
||||
extra_args: {}
|
||||
credentials: [
|
||||
# specifies key/secret credentials to use when handling s3 urls (read or write)
|
||||
# {
|
||||
# bucket: "my-bucket-name"
|
||||
# key: "my-access-key"
|
||||
# secret: "my-secret-key"
|
||||
# },
|
||||
{
|
||||
# This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
|
||||
host: "localhost:9000"
|
||||
key: "evg_user"
|
||||
secret: "evg_pass"
|
||||
multipart: false
|
||||
secure: false
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
google.storage {
|
||||
# Default project and credentials file
|
||||
# Will be used when no bucket configuration is found
|
||||
// project: "clearml"
|
||||
// credentials_json: "/path/to/credentials.json"
|
||||
//
|
||||
// # Specific credentials per bucket and sub directory
|
||||
// credentials = [
|
||||
// {
|
||||
// bucket: "my-bucket"
|
||||
// subdir: "path/in/bucket" # Not required
|
||||
// project: "clearml"
|
||||
// credentials_json: "/path/to/credentials.json"
|
||||
// },
|
||||
// ]
|
||||
}
|
||||
azure.storage {
|
||||
# containers: [
|
||||
# {
|
||||
# account_name: "clearml"
|
||||
# account_key: "secret"
|
||||
# # container_name:
|
||||
# }
|
||||
# ]
|
||||
}
|
||||
@@ -19,4 +19,11 @@ hyperparam_values {
|
||||
|
||||
# cache ttl sec
|
||||
cache_ttl_sec: 86400
|
||||
}
|
||||
}
|
||||
|
||||
# the maximum amount of unique last metrics/variants combinations
|
||||
# for which the last values are stored in a task
|
||||
max_last_metrics: 2000
|
||||
|
||||
# if set then call to tasks.delete/cleanup does not wait for ES events deletion
|
||||
async_events_delete: false
|
||||
|
||||
@@ -29,6 +29,9 @@ OVERRIDE_PORT_ENV_KEY = (
|
||||
)
|
||||
|
||||
OVERRIDE_CONNECTION_STRING_ENV_KEY = "CLEARML_MONGODB_SERVICE_CONNECTION_STRING"
|
||||
OVERRIDE_USERNAME_ENV_KEY = "CLEARML_MONGODB_SERVICE_USERNAME"
|
||||
OVERRIDE_PASSWORD_ENV_KEY = "CLEARML_MONGODB_SERVICE_PASSWORD"
|
||||
OVERRIDE_QUERY_ENV_KEY = "CLEARML_MONGODB_SERVICE_QUERY"
|
||||
|
||||
|
||||
class DatabaseEntry(models.Base):
|
||||
@@ -52,14 +55,23 @@ class DatabaseFactory:
|
||||
override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY)
|
||||
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
|
||||
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
|
||||
override_username = getenv(OVERRIDE_USERNAME_ENV_KEY)
|
||||
override_password = getenv(OVERRIDE_PASSWORD_ENV_KEY)
|
||||
override_query = getenv(OVERRIDE_QUERY_ENV_KEY)
|
||||
|
||||
if override_connection_string:
|
||||
log.info(f"Using override mongodb connection string {override_connection_string}")
|
||||
log.info(f"Using override mongodb connection string template {override_connection_string}")
|
||||
else:
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
if override_username:
|
||||
log.info(f"Using override mongodb username {override_username}")
|
||||
if override_password:
|
||||
log.info(f"Using override mongodb password ******")
|
||||
if override_query:
|
||||
log.info(f"Using override mongodb query {override_query}")
|
||||
|
||||
for key, alias in get_items(Database).items():
|
||||
if key not in db_entries:
|
||||
@@ -69,12 +81,20 @@ class DatabaseFactory:
|
||||
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
|
||||
|
||||
if override_connection_string:
|
||||
entry.host = override_connection_string
|
||||
con_str = f"{override_connection_string.rstrip('/')}/{key}"
|
||||
log.info(f"Using override mongodb connection string for {alias}: {con_str}")
|
||||
entry.host = con_str
|
||||
else:
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
if override_username:
|
||||
entry.host = furl(entry.host).set(username=override_username).url
|
||||
if override_password:
|
||||
entry.host = furl(entry.host).set(password=override_password).url
|
||||
if override_query:
|
||||
entry.host = furl(entry.host).set(query=override_query).url
|
||||
|
||||
try:
|
||||
entry.validate()
|
||||
|
||||
@@ -17,16 +17,16 @@ from typing import (
|
||||
|
||||
from boltons.iterutils import first, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField, IntField
|
||||
from mongoengine import Q, Document, ListField, StringField, IntField, QuerySet
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apierrors import errors, APIError
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database import Database
|
||||
from apiserver.database.errors import MakeGetAllQueryError
|
||||
from apiserver.database.projection import project_dict, ProjectionHelper
|
||||
from apiserver.database.projection import ProjectionHelper
|
||||
from apiserver.database.props import PropsMixin
|
||||
from apiserver.database.query import RegexQ, RegexWrapper, RegexQCombination
|
||||
from apiserver.database.utils import (
|
||||
@@ -36,9 +36,10 @@ from apiserver.database.utils import (
|
||||
field_exists,
|
||||
)
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.dicts import project_dict, exclude_fields_from_dict
|
||||
|
||||
log = config.logger("dbmodel")
|
||||
|
||||
mongo_conf = config.get("services._mongo")
|
||||
ACCESS_REGEX = re.compile(r"^(?P<prefix>>=|>|<=|<)?(?P<value>.*)$")
|
||||
ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"}
|
||||
|
||||
@@ -55,17 +56,25 @@ class ProperDictMixin(object):
|
||||
strip_private=True,
|
||||
only=None,
|
||||
extra_dict=None,
|
||||
exclude=None,
|
||||
) -> dict:
|
||||
return self.properize_dict(
|
||||
self.to_mongo(use_db_field=False).to_dict(),
|
||||
strip_private=strip_private,
|
||||
only=only,
|
||||
extra_dict=extra_dict,
|
||||
exclude=exclude,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def properize_dict(
|
||||
cls, d, strip_private=True, only=None, extra_dict=None, normalize_id=True
|
||||
cls,
|
||||
d,
|
||||
strip_private=True,
|
||||
only=None,
|
||||
extra_dict=None,
|
||||
exclude=None,
|
||||
normalize_id=True,
|
||||
):
|
||||
res = d
|
||||
if normalize_id and "_id" in res:
|
||||
@@ -76,6 +85,9 @@ class ProperDictMixin(object):
|
||||
res = project_dict(res, only)
|
||||
if extra_dict:
|
||||
res.update(extra_dict)
|
||||
if exclude:
|
||||
exclude_fields_from_dict(res, exclude)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@@ -136,20 +148,30 @@ class GetMixin(PropsMixin):
|
||||
"or": (default_mongo_op, True),
|
||||
}
|
||||
|
||||
def __init__(self, legacy=False):
|
||||
def __init__(self, field, legacy=False):
|
||||
self._field = field
|
||||
self._current_op = None
|
||||
self._sticky = False
|
||||
self._support_legacy = legacy
|
||||
self.allow_empty = False
|
||||
|
||||
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
|
||||
op = (
|
||||
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
|
||||
)
|
||||
if translate:
|
||||
tup = self._ops.get(op, None)
|
||||
return tup[0] if tup else None
|
||||
return op
|
||||
try:
|
||||
op = (
|
||||
v[len(self.op_prefix) :]
|
||||
if v and v.startswith(self.op_prefix)
|
||||
else None
|
||||
)
|
||||
if translate:
|
||||
tup = self._ops.get(op, None)
|
||||
return tup[0] if tup else None
|
||||
return op
|
||||
except AttributeError:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"invalid value type, string expected",
|
||||
field=self._field,
|
||||
value=str(v),
|
||||
)
|
||||
|
||||
def _key(self, v) -> Optional[Union[str, bool]]:
|
||||
if v is None:
|
||||
@@ -215,8 +237,8 @@ class GetMixin(PropsMixin):
|
||||
cls._cache_manager = RedisCacheManager(
|
||||
state_class=cls.GetManyScrollState,
|
||||
redis=redman.connection("apiserver"),
|
||||
expiration_interval=config.get(
|
||||
"services._mongo.scroll_state_expiration_seconds", 600
|
||||
expiration_interval=mongo_conf.get(
|
||||
"scroll_state_expiration_seconds", 600
|
||||
),
|
||||
)
|
||||
|
||||
@@ -335,84 +357,108 @@ class GetMixin(PropsMixin):
|
||||
parameters_options = parameters_options or cls.get_all_query_options
|
||||
dict_query = {}
|
||||
query = RegexQ()
|
||||
if parameters:
|
||||
parameters = {
|
||||
k: cls._get_fixed_field_value(k, v) for k, v in parameters.items()
|
||||
}
|
||||
opts = parameters_options
|
||||
for field in opts.pattern_fields:
|
||||
pattern = parameters.pop(field, None)
|
||||
if pattern:
|
||||
dict_query[field] = RegexWrapper(pattern)
|
||||
field = None
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if parameters:
|
||||
parameters = {
|
||||
k: cls._get_fixed_field_value(k, v) for k, v in parameters.items()
|
||||
}
|
||||
opts = parameters_options
|
||||
for field in opts.pattern_fields:
|
||||
pattern = parameters.pop(field, None)
|
||||
if pattern:
|
||||
dict_query[field] = RegexWrapper(pattern)
|
||||
|
||||
for field, data in cls._pop_matching_params(
|
||||
patterns=opts.list_fields, parameters=parameters
|
||||
).items():
|
||||
query &= cls.get_list_field_query(field, data)
|
||||
for field, data in cls._pop_matching_params(
|
||||
patterns=opts.list_fields, parameters=parameters
|
||||
).items():
|
||||
query &= cls.get_list_field_query(field, data)
|
||||
|
||||
for field, data in cls._pop_matching_params(
|
||||
patterns=opts.range_fields, parameters=parameters
|
||||
).items():
|
||||
query &= cls.get_range_field_query(field, data)
|
||||
for field, data in cls._pop_matching_params(
|
||||
patterns=opts.range_fields, parameters=parameters
|
||||
).items():
|
||||
query &= cls.get_range_field_query(field, data)
|
||||
|
||||
for field, data in cls._pop_matching_params(
|
||||
patterns=opts.fields or [], parameters=parameters
|
||||
).items():
|
||||
if "._" in field or "_." in field:
|
||||
query &= RegexQ(__raw__={field: data})
|
||||
else:
|
||||
dict_query[field.replace(".", "__")] = data
|
||||
for field, data in cls._pop_matching_params(
|
||||
patterns=opts.fields or [], parameters=parameters
|
||||
).items():
|
||||
if "._" in field or "_." in field:
|
||||
query &= RegexQ(__raw__={field: data})
|
||||
else:
|
||||
dict_query[field.replace(".", "__")] = data
|
||||
|
||||
for field in opts.datetime_fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
if data is not None:
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
for d in data: # type: str
|
||||
m = ACCESS_REGEX.match(d)
|
||||
if not m:
|
||||
for field in opts.datetime_fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
if data is not None:
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
# date time fields also support simplified range queries. Check if this is the case
|
||||
if len(data) == 2 and not any(
|
||||
d.startswith(mod)
|
||||
for d in data
|
||||
if d is not None
|
||||
for mod in ACCESS_MODIFIER
|
||||
):
|
||||
query &= cls.get_range_field_query(field, data)
|
||||
else:
|
||||
for d in data: # type: str
|
||||
m = ACCESS_REGEX.match(d)
|
||||
if not m:
|
||||
continue
|
||||
try:
|
||||
value = parse_datetime(m.group("value"))
|
||||
prefix = m.group("prefix")
|
||||
modifier = ACCESS_MODIFIER.get(prefix)
|
||||
f = (
|
||||
field
|
||||
if not modifier
|
||||
else "__".join((field, modifier))
|
||||
)
|
||||
dict_query[f] = value
|
||||
except (ValueError, OverflowError):
|
||||
pass
|
||||
|
||||
for field, value in parameters.items():
|
||||
for keys, func in cls._multi_field_param_prefix.items():
|
||||
if field not in keys:
|
||||
continue
|
||||
try:
|
||||
value = parse_datetime(m.group("value"))
|
||||
prefix = m.group("prefix")
|
||||
modifier = ACCESS_MODIFIER.get(prefix)
|
||||
f = field if not modifier else "__".join((field, modifier))
|
||||
dict_query[f] = value
|
||||
except (ValueError, OverflowError):
|
||||
pass
|
||||
|
||||
for field, value in parameters.items():
|
||||
for keys, func in cls._multi_field_param_prefix.items():
|
||||
if field not in keys:
|
||||
continue
|
||||
try:
|
||||
data = cls.MultiFieldParameters(**value)
|
||||
except Exception:
|
||||
raise MakeGetAllQueryError("incorrect field format", field)
|
||||
if not data.fields:
|
||||
break
|
||||
if any("._" in f for f in data.fields):
|
||||
q = reduce(
|
||||
lambda a, x: func(
|
||||
a,
|
||||
RegexQ(
|
||||
__raw__={
|
||||
x: {"$regex": data.pattern, "$options": "i"}
|
||||
}
|
||||
data = cls.MultiFieldParameters(**value)
|
||||
except Exception:
|
||||
raise MakeGetAllQueryError("incorrect field format", field)
|
||||
if not data.fields:
|
||||
break
|
||||
if any("._" in f for f in data.fields):
|
||||
q = reduce(
|
||||
lambda a, x: func(
|
||||
a,
|
||||
RegexQ(
|
||||
__raw__={
|
||||
x: {"$regex": data.pattern, "$options": "i"}
|
||||
}
|
||||
),
|
||||
),
|
||||
),
|
||||
data.fields,
|
||||
RegexQ(),
|
||||
)
|
||||
else:
|
||||
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
|
||||
sep_fields = [f.replace(".", "__") for f in data.fields]
|
||||
q = reduce(
|
||||
lambda a, x: func(a, RegexQ(**{x: regex})),
|
||||
sep_fields,
|
||||
RegexQ(),
|
||||
)
|
||||
query = query & q
|
||||
data.fields,
|
||||
RegexQ(),
|
||||
)
|
||||
else:
|
||||
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
|
||||
sep_fields = [f.replace(".", "__") for f in data.fields]
|
||||
q = reduce(
|
||||
lambda a, x: func(a, RegexQ(**{x: regex})),
|
||||
sep_fields,
|
||||
RegexQ(),
|
||||
)
|
||||
query = query & q
|
||||
except APIError:
|
||||
raise
|
||||
except Exception as ex:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"failed parsing query field",
|
||||
error=str(ex),
|
||||
**({"field": field} if field else {}),
|
||||
)
|
||||
|
||||
return query & RegexQ(**dict_query)
|
||||
|
||||
@@ -461,7 +507,7 @@ class GetMixin(PropsMixin):
|
||||
if not isinstance(data, (list, tuple)):
|
||||
data = [data]
|
||||
|
||||
helper = cls.ListFieldBucketHelper(legacy=True)
|
||||
helper = cls.ListFieldBucketHelper(field, legacy=True)
|
||||
global_op = helper.get_global_op(data)
|
||||
actions = helper.get_actions(data)
|
||||
|
||||
@@ -497,6 +543,8 @@ class GetMixin(PropsMixin):
|
||||
def validate_order_by(cls, parameters, search_text) -> Sequence:
|
||||
"""
|
||||
Validate and extract order_by params as a list
|
||||
If ordering is specified then make sure that id field is part of it
|
||||
This guarantees the unique order when paging
|
||||
"""
|
||||
order_by = parameters.get(cls._ordering_key)
|
||||
if not order_by:
|
||||
@@ -509,6 +557,9 @@ class GetMixin(PropsMixin):
|
||||
"text score cannot be used in order_by when search text is not used"
|
||||
)
|
||||
|
||||
if not any(id_field in order_by for id_field in ("id", "-id")):
|
||||
order_by.append("id")
|
||||
|
||||
return order_by
|
||||
|
||||
@classmethod
|
||||
@@ -525,7 +576,7 @@ class GetMixin(PropsMixin):
|
||||
if start is not None:
|
||||
return start, cls.validate_scroll_size(parameters)
|
||||
|
||||
max_page_size = config.get("services._mongo.max_page_size", 500)
|
||||
max_page_size = mongo_conf.get("max_page_size", 500)
|
||||
page = parameters.get("page", default_page)
|
||||
if page is not None and page < 0:
|
||||
raise errors.bad_request.ValidationError("page must be >=0", field="page")
|
||||
@@ -835,6 +886,13 @@ class GetMixin(PropsMixin):
|
||||
|
||||
return cls._get_many_no_company(query=_query, override_projection=projection)
|
||||
|
||||
@staticmethod
def _get_qs_with_ordering(qs: QuerySet, order_by: Sequence):
    """
    Return the queryset ordered by the given fields

    If the mongo configuration defines "allow_disk_use.sort" then its value is passed
    to the queryset allow_disk_use() before the ordering is applied.
    When the setting is absent the queryset is ordered as is
    """
    allow_disk_use = mongo_conf.get("allow_disk_use.sort", None)
    # None means "not configured": do not touch the queryset disk use setting at all
    sortable = qs if allow_disk_use is None else qs.allow_disk_use(allow_disk_use)
    return sortable.order_by(*order_by)
|
||||
|
||||
@classmethod
|
||||
def _get_many_no_company(
|
||||
cls: Union["GetMixin", Document],
|
||||
@@ -876,7 +934,7 @@ class GetMixin(PropsMixin):
|
||||
qs = qs.search_text(search_text)
|
||||
if order_by:
|
||||
# add ordering
|
||||
qs = qs.order_by(*order_by)
|
||||
qs = cls._get_qs_with_ordering(qs, order_by)
|
||||
|
||||
if include:
|
||||
# add projection
|
||||
@@ -968,7 +1026,7 @@ class GetMixin(PropsMixin):
|
||||
res = cls._get_queries_for_order_field(query, order_field)
|
||||
if res:
|
||||
query_sets = [cls.objects(q) for q in res]
|
||||
query_sets = [qs.order_by(*order_by) for qs in query_sets]
|
||||
query_sets = [cls._get_qs_with_ordering(qs, order_by) for qs in query_sets]
|
||||
if order_field and not override_collation:
|
||||
override_collation = cls._get_collation_override(order_field)
|
||||
|
||||
@@ -991,7 +1049,11 @@ class GetMixin(PropsMixin):
|
||||
query_sets = [qs.fields(**projection_fields) for qs in query_sets]
|
||||
|
||||
if start is None or not size:
|
||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||
return [
|
||||
obj.to_proper_dict(only=include, exclude=exclude)
|
||||
for qs in query_sets
|
||||
for obj in qs
|
||||
]
|
||||
|
||||
# add paging
|
||||
ret = []
|
||||
@@ -999,7 +1061,7 @@ class GetMixin(PropsMixin):
|
||||
for i, qs in enumerate(query_sets):
|
||||
last_size = len(ret)
|
||||
ret.extend(
|
||||
obj.to_proper_dict(only=include)
|
||||
obj.to_proper_dict(only=include, exclude=exclude)
|
||||
for obj in (qs.skip(start) if start else qs).limit(size)
|
||||
)
|
||||
added = len(ret) - last_size
|
||||
@@ -1124,7 +1186,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
kwargs.update(
|
||||
allowDiskUse=allow_disk_use
|
||||
if allow_disk_use is not None
|
||||
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
|
||||
else mongo_conf.get("allow_disk_use.aggregate", True)
|
||||
)
|
||||
return cls.objects.aggregate(pipeline, **kwargs)
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ class Model(AttributedDocument):
|
||||
("company", "framework"),
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
("company", "uri"),
|
||||
{
|
||||
"name": "%s.model.main_text_index" % Database.backend,
|
||||
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
|
||||
@@ -91,3 +92,6 @@ class Model(AttributedDocument):
|
||||
metadata = SafeMapField(
|
||||
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||
)
|
||||
|
||||
def get_index_company(self) -> str:
    """
    Return the first non-empty value out of company and company_origin.
    If both are empty then an empty string is returned
    """
    company = self.company
    if company:
        return company

    origin = self.company_origin
    return origin if origin else ""
|
||||
|
||||
@@ -4,6 +4,7 @@ from mongoengine import (
|
||||
DynamicField,
|
||||
LongField,
|
||||
EmbeddedDocumentField,
|
||||
IntField,
|
||||
)
|
||||
|
||||
from apiserver.database.fields import SafeMapField
|
||||
@@ -19,7 +20,9 @@ class MetricEvent(EmbeddedDocument):
|
||||
variant = StringField(required=True)
|
||||
value = DynamicField(required=True)
|
||||
min_value = DynamicField() # for backwards compatibility reasons
|
||||
min_value_iteration = IntField()
|
||||
max_value = DynamicField() # for backwards compatibility reasons
|
||||
max_value_iteration = IntField()
|
||||
|
||||
|
||||
class EventStats(EmbeddedDocument):
|
||||
|
||||
@@ -19,6 +19,7 @@ from apiserver.database.fields import (
|
||||
SafeSortedListField,
|
||||
EmbeddedDocumentListField,
|
||||
NullableStringField,
|
||||
NoneType,
|
||||
)
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
@@ -89,7 +90,9 @@ class Artifact(EmbeddedDocument):
|
||||
content_size = LongField()
|
||||
timestamp = LongField()
|
||||
type_data = EmbeddedDocumentField(ArtifactTypeData)
|
||||
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
|
||||
display_data = SafeSortedListField(
|
||||
ListField(UnionField((int, float, str, NoneType)))
|
||||
)
|
||||
|
||||
|
||||
class ParamsItem(EmbeddedDocument, ProperDictMixin):
|
||||
@@ -149,6 +152,7 @@ class TaskType(object):
|
||||
application = "application"
|
||||
monitor = "monitor"
|
||||
controller = "controller"
|
||||
report = "report"
|
||||
optimizer = "optimizer"
|
||||
service = "service"
|
||||
qc = "qc"
|
||||
@@ -195,6 +199,7 @@ class Task(AttributedDocument):
|
||||
"$name",
|
||||
"$id",
|
||||
"$comment",
|
||||
"$report",
|
||||
"$models.input.model",
|
||||
"$models.output.model",
|
||||
"$script.repository",
|
||||
@@ -205,6 +210,7 @@ class Task(AttributedDocument):
|
||||
"name": 10,
|
||||
"id": 10,
|
||||
"comment": 10,
|
||||
"report": 10,
|
||||
"models.output.model": 2,
|
||||
"models.input.model": 2,
|
||||
"script.repository": 1,
|
||||
@@ -227,7 +233,8 @@ class Task(AttributedDocument):
|
||||
),
|
||||
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
|
||||
datetime_fields=("status_changed", "last_update"),
|
||||
pattern_fields=("name", "comment"),
|
||||
pattern_fields=("name", "comment", "report"),
|
||||
fields=("execution.queue", "runtime.*", "models.input.model"),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
@@ -241,6 +248,8 @@ class Task(AttributedDocument):
|
||||
status_message = StringField(user_set_allowed=True)
|
||||
status_changed = DateTimeField()
|
||||
comment = StringField(user_set_allowed=True)
|
||||
report = StringField()
|
||||
report_assets = ListField(StringField())
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
started = DateTimeField()
|
||||
completed = DateTimeField()
|
||||
@@ -259,6 +268,7 @@ class Task(AttributedDocument):
|
||||
last_change = DateTimeField()
|
||||
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
|
||||
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
duration = IntField() # task duration in seconds
|
||||
@@ -270,6 +280,7 @@ class Task(AttributedDocument):
|
||||
enqueue_status = StringField(
|
||||
choices=get_options(TaskStatus), exclude_by_default=True
|
||||
)
|
||||
last_changed_by = StringField()
|
||||
|
||||
def get_index_company(self) -> str:
|
||||
"""
|
||||
|
||||
52
apiserver/database/model/url_to_delete.py
Normal file
52
apiserver/database/model/url_to_delete.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from enum import Enum
|
||||
|
||||
from mongoengine import StringField, DateTimeField, IntField, EnumField
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.model import AttributedDocument
|
||||
|
||||
|
||||
class StorageType(str, Enum):
    """Storage kinds recorded on a url that is pending deletion (string valued)"""

    fileserver = "fileserver"
    s3 = "s3"
    azure = "azure"
    gs = "gs"
    # default for UrlToDelete.storage_type
    unknown = "unknown"
|
||||
|
||||
|
||||
class FileType(str, Enum):
    """Kind of the object that a url points to (string valued)"""

    # default for UrlToDelete.type
    file = "file"
    folder = "folder"
|
||||
|
||||
|
||||
class DeletionStatus(str, Enum):
    """Processing state of a url deletion request (string valued)"""

    # default for UrlToDelete.status
    created = "created"
    retrying = "retrying"
    failed = "failed"
|
||||
|
||||
|
||||
class UrlToDelete(AttributedDocument):
    """
    Document that records a url scheduled for deletion together with
    the bookkeeping of the deletion attempts (retry count, last failure, status)
    """

    _field_collation_overrides = {"url": AttributedDocument._numeric_locale}

    meta = {
        "db_alias": Database.backend,
        "strict": strict,
        "indexes": [
            ("company", "user", "task"),
            ("company", "storage_type", "url"),
            ("status", "retry_count", "storage_type"),
        ],
    }

    id = StringField(primary_key=True)
    # the same url may be registered only once per company
    url = StringField(required=True, unique_with="company")
    task = StringField(required=True)
    created = DateTimeField(required=True)
    storage_type = EnumField(StorageType, default=StorageType.unknown)
    type = EnumField(FileType, default=FileType.file)

    # deletion attempts bookkeeping
    retry_count = IntField(default=0)
    last_failure_time = DateTimeField()
    last_failure_reason = StringField()
    status = EnumField(DeletionStatus, default=DeletionStatus.created)
|
||||
@@ -1,4 +1,4 @@
|
||||
from mongoengine import Document, StringField, DynamicField
|
||||
from mongoengine import Document, StringField, DynamicField, DateTimeField
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.model import DbModelMixin
|
||||
@@ -20,3 +20,4 @@ class User(DbModelMixin, Document):
|
||||
given_name = StringField(user_set_allowed=True)
|
||||
avatar = StringField()
|
||||
preferences = DynamicField(default="", exclude_by_default=True)
|
||||
created = DateTimeField()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import groupby, chain
|
||||
from typing import Sequence, Dict, Callable, Tuple, Any, Type
|
||||
from typing import Sequence, Dict, Callable
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.props import PropsMixin
|
||||
@@ -9,65 +9,6 @@ from apiserver.database.props import PropsMixin
|
||||
SEP = "."
|
||||
|
||||
|
||||
def project_dict(data, projection, separator=SEP):
|
||||
"""
|
||||
Project partial data from a dictionary into a new dictionary
|
||||
:param data: Input dictionary
|
||||
:param projection: List of dictionary paths (each a string with field names separated using a separator)
|
||||
:param separator: Separator (default is '.')
|
||||
:return: A new dictionary containing only the projected parts from the original dictionary
|
||||
"""
|
||||
assert isinstance(data, dict)
|
||||
result = {}
|
||||
|
||||
def copy_path(path_parts, source, destination):
|
||||
src, dst = source, destination
|
||||
try:
|
||||
for depth, path_part in enumerate(path_parts[:-1]):
|
||||
src_part = src[path_part]
|
||||
if isinstance(src_part, dict):
|
||||
src = src_part
|
||||
dst = dst.setdefault(path_part, {})
|
||||
elif isinstance(src_part, (list, tuple)):
|
||||
if path_part not in dst:
|
||||
dst[path_part] = [{} for _ in range(len(src_part))]
|
||||
elif not isinstance(dst[path_part], (list, tuple)):
|
||||
raise TypeError(
|
||||
"Incompatible destination type %s for %s (list expected)"
|
||||
% (type(dst), separator.join(path_parts[: depth + 1]))
|
||||
)
|
||||
elif not len(dst[path_part]) == len(src_part):
|
||||
raise ValueError(
|
||||
"Destination list length differs from source length for %s"
|
||||
% separator.join(path_parts[: depth + 1])
|
||||
)
|
||||
|
||||
dst[path_part] = [
|
||||
copy_path(path_parts[depth + 1 :], s, d)
|
||||
for s, d in zip(src_part, dst[path_part])
|
||||
]
|
||||
|
||||
return destination
|
||||
else:
|
||||
raise TypeError(
|
||||
"Unsupported projection type %s for %s"
|
||||
% (type(src), separator.join(path_parts[: depth + 1]))
|
||||
)
|
||||
|
||||
last_part = path_parts[-1]
|
||||
dst[last_part] = src[last_part]
|
||||
except KeyError:
|
||||
# Projection field not in source, no biggie.
|
||||
pass
|
||||
return destination
|
||||
|
||||
for projection_path in sorted(projection):
|
||||
copy_path(
|
||||
path_parts=projection_path.split(separator), source=data, destination=result
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class _ReferenceProxy(dict):
|
||||
def __init__(self, id):
|
||||
super(_ReferenceProxy, self).__init__(**({"id": id} if id else {}))
|
||||
@@ -108,9 +49,6 @@ class ProjectionHelper(object):
|
||||
self._ref_projection = None
|
||||
self._proxy_manager = _ProxyManager()
|
||||
|
||||
# Cached dpath paths for each of the result documents
|
||||
self._cached_results_paths: Dict[int, Sequence[Tuple[Any, Type]]] = {}
|
||||
|
||||
self._parse_projection(projection)
|
||||
|
||||
def _collect_projection_fields(self, doc_cls, projection):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import hashlib
|
||||
from inspect import ismethod, getmembers
|
||||
from typing import Sequence, Tuple, Set, Optional, Callable, Any
|
||||
from typing import Sequence, Tuple, Set, Optional, Callable, Any, Mapping
|
||||
from uuid import uuid4
|
||||
|
||||
from mongoengine import EmbeddedDocumentField, ListField, Document, Q
|
||||
@@ -203,18 +203,22 @@ def _names_set(*names: str) -> Set[str]:
|
||||
return set(names) | set(f"-{name}" for name in names)
|
||||
|
||||
|
||||
system_tag_names = {
|
||||
_system_tag_names = {
|
||||
"model": _names_set("active", "archived"),
|
||||
"project": _names_set("archived", "public", "default"),
|
||||
"task": _names_set("active", "archived", "development"),
|
||||
"queue": _names_set("default"),
|
||||
}
|
||||
|
||||
system_tag_prefixes = {"task": _names_set("annotat")}
|
||||
_system_tag_prefixes = {"task": _names_set("annotat")}
|
||||
|
||||
|
||||
def partition_tags(
|
||||
entity: str, tags: Sequence[str], system_tags: Optional[Sequence[str]] = ()
|
||||
entity: str,
|
||||
tags: Sequence[str],
|
||||
system_tags: Optional[Sequence[str]] = (),
|
||||
system_tag_names: Mapping = _system_tag_names,
|
||||
system_tag_prefixes: Mapping = _system_tag_prefixes,
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
"""
|
||||
Partition the given tags sequence into system and user-defined tags
|
||||
|
||||
@@ -35,6 +35,12 @@
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
},
|
||||
"company_id": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"model_event": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ class MissingPasswordForElasticUser(Exception):
|
||||
|
||||
class ESFactory:
|
||||
@classmethod
|
||||
def connect(cls, cluster_name):
|
||||
def connect(cls, cluster_name) -> Elasticsearch:
|
||||
"""
|
||||
Returns the es client for the cluster.
|
||||
Connects to the cluster if did not connect previously
|
||||
|
||||
611
apiserver/jobs/async_urls_delete.py
Normal file
611
apiserver/jobs/async_urls_delete.py
Normal file
@@ -0,0 +1,611 @@
|
||||
import os
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from argparse import ArgumentParser
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from typing import Sequence, Optional, Tuple, Mapping, TypeVar, Hashable, Generic
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
import requests
|
||||
from azure.storage.blob import ContainerClient, PartialBatchErrorException
|
||||
from boltons.iterutils import bucketize, chunked_iter
|
||||
from furl import furl
|
||||
from google.cloud import storage as google_storage
|
||||
from mongoengine import Q
|
||||
from mypy_boto3_s3.service_resource import Bucket as AWSBucket
|
||||
|
||||
from apiserver.bll.storage import StorageBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database import db
|
||||
from apiserver.database.model.url_to_delete import UrlToDelete, StorageType, DeletionStatus
|
||||
|
||||
# Job level logger and settings, resolved once at import time
log = config.logger(f"JOB-{Path(__file__).name}")
conf = config.get("services.async_urls_delete")

# an entry whose retry_count reaches this value is marked as failed (see mark_retry_failed)
max_retries = conf.get("max_retries", 3)
retry_timeout = timedelta(seconds=conf.get("retry_timeout_sec", 60))

storage_bll = StorageBLL()
|
||||
|
||||
|
||||
def mark_retry_failed(ids: Sequence[str], reason: str):
    """
    Register one more failed deletion attempt for the passed UrlToDelete ids

    The failure time and reason are stored and retry_count is incremented by 1.
    Afterwards every passed entry whose retry_count is at least max_retries
    gets the failed status

    :param ids: ids of the UrlToDelete documents whose deletion attempt failed
    :param reason: the failure description to store on the documents
    """
    failure_update = dict(
        last_failure_time=datetime.utcnow(),
        last_failure_reason=reason,
        inc__retry_count=1,
    )
    UrlToDelete.objects(id__in=ids).update(**failure_update)

    # second pass: the entries that ran out of retries will not be retried anymore
    out_of_retries = UrlToDelete.objects(id__in=ids, retry_count__gte=max_retries)
    out_of_retries.update(status=DeletionStatus.failed)
|
||||
|
||||
|
||||
def mark_failed(query: Q, reason: str):
    """
    Set the failed status on all the UrlToDelete documents matching the query,
    storing the current utc time and the passed reason as the last failure details.
    The retry count of the documents is not changed

    :param query: query selecting the documents to mark
    :param reason: the failure description to store on the documents
    """
    matched = UrlToDelete.objects(query)
    matched.update(
        status=DeletionStatus.failed,
        last_failure_time=datetime.utcnow(),
        last_failure_reason=reason,
    )
|
||||
|
||||
|
||||
def scheme_prefix(scheme: str) -> str:
    """
    Return the string form of a furl url that is built from
    the passed scheme and an empty netloc
    """
    prefix_url = furl(scheme=scheme, netloc="")
    return str(prefix_url)
|
||||
|
||||
|
||||
# The type of the key under which a storage groups its urls (see Storage.group_urls)
T = TypeVar("T", bound=Hashable)


class Storage(Generic[T], metaclass=ABCMeta):
    """
    Base interface of a storage from which the urls can be deleted.
    A storage groups the urls by a hashable base and provides a client per such base
    """

    class Client(ABC):
        """
        Base interface of the object that performs the actual deletion
        for one group of urls
        """

        @property
        @abstractmethod
        def chunk_size(self) -> int:
            """The maximal amount of urls that is passed to a single delete_many call"""

        # NOTE(review): get_path and delete_many are not marked @abstractmethod
        # (unlike chunk_size). The base versions return None - confirm that this is intended
        def get_path(self, url: UrlToDelete) -> str:
            """
            Return the path that should be passed to delete_many for the url.
            The caller treats an exception raised from here as a failure of this url only
            """

        def delete_many(
            self, paths: Sequence[str]
        ) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
            """
            Delete the passed paths

            :return: the paths that were deleted and, per failure reason,
                the paths that failed with this reason
            """

    @property
    @abstractmethod
    def name(self) -> str:
        """The storage name as it is used in the log and failure messages"""

    # NOTE(review): group_urls and get_client are not marked @abstractmethod
    # (unlike name). The base versions return None - confirm that this is intended
    def group_urls(
        self, urls: Sequence[UrlToDelete]
    ) -> Mapping[T, Sequence[UrlToDelete]]:
        """
        Split the urls into groups that can be served by the same client.
        The urls grouped under a falsy key are marked as failed by the caller
        """

    def get_client(self, base: T, urls: Sequence[UrlToDelete]) -> Client:
        """Return the client for deleting the urls that were grouped under the base"""
|
||||
|
||||
|
||||
def delete_urls(urls_query: Q, storage: Storage):
    """
    Delete from the storage the files of the UrlToDelete documents matching the query

    At most 10000 documents (ordered by url) are processed per call. The documents are
    grouped using storage.group_urls and each group is deleted by its own client
    in chunks of client.chunk_size urls:
    - documents whose paths were reported as deleted are removed from the collection
    - documents grouped under a falsy base or for which the client cannot
      resolve a path are marked as failed
    - all the other failures are registered as failed retries (see mark_retry_failed)

    :param urls_query: query selecting the UrlToDelete documents to process
    :param storage: the storage that groups the urls and provides the deleting clients
    """
    to_delete = list(UrlToDelete.objects(urls_query).order_by("url").limit(10000))
    if not to_delete:
        return

    grouped_urls = storage.group_urls(to_delete)
    for base, urls in grouped_urls.items():
        if not base:
            # the storage could not resolve these urls, retrying will not help
            msg = f"Invalid {storage.name} url or missing {storage.name} configuration for account"
            mark_failed(
                Q(id__in=[url.id for url in urls]), msg,
            )
            log.warning(
                f"Failed to delete {len(urls)} files from {storage.name} due to: {msg}"
            )
            continue

        try:
            client = storage.get_client(base, urls)
        except Exception as ex:
            # no client means that nothing from this group was attempted
            failed = [url.id for url in urls]
            mark_retry_failed(failed, reason=str(ex))
            log.warning(
                f"Failed to delete {len(failed)} files from {storage.name} due to: {str(ex)}"
            )
            continue

        for chunk in chunked_iter(urls, client.chunk_size):
            paths = []
            # several documents may resolve to the same path
            path_to_id_mapping = defaultdict(list)
            ids_to_delete = set()
            for url in chunk:
                try:
                    path = client.get_path(url)
                except Exception as ex:
                    err = str(ex)
                    mark_failed(Q(id=url.id), err)
                    log.warning(f"Error getting path for {url.url}: {err}")
                    continue

                paths.append(path)
                path_to_id_mapping[path].append(url.id)
                ids_to_delete.add(url.id)

            if not paths:
                continue

            try:
                deleted_paths, errors = client.delete_many(paths)
            except Exception as ex:
                # Only the urls of the current chunk took part in the failed call.
                # Counting a retry for the whole group would also penalize the urls
                # from the other chunks and those already marked as failed above
                mark_retry_failed(list(ids_to_delete), str(ex))
                log.warning(
                    f"Error deleting {len(paths)} files from {storage.name}: {str(ex)}"
                )
                continue

            failed_ids = set()
            for reason, err_paths in errors.items():
                error_ids = set(
                    chain.from_iterable(
                        path_to_id_mapping.get(p, []) for p in err_paths
                    )
                )
                mark_retry_failed(list(error_ids), reason)
                log.warning(
                    f"Failed to delete {len(error_ids)} files from {storage.name} storage due to: {reason}"
                )
                failed_ids.update(error_ids)

            deleted_ids = set(
                chain.from_iterable(
                    path_to_id_mapping.get(p, []) for p in deleted_paths
                )
            )
            if deleted_ids:
                UrlToDelete.objects(id__in=list(deleted_ids)).delete()
                log.info(
                    f"{len(deleted_ids)} files deleted from {storage.name} storage"
                )

            # the urls that were reported neither as deleted nor as failed
            missing_ids = ids_to_delete - deleted_ids - failed_ids
            if missing_ids:
                mark_retry_failed(list(missing_ids), "Not succeeded")
|
||||
|
||||
|
||||
class FileserverStorage(Storage):
    """
    Storage helper for files that are served by the fileserver.

    Urls are matched against the fileserver host and the prefixes listed in the
    "fileserver.url_prefixes" setting and are deleted in bulk by posting their
    paths to the fileserver "delete_many" endpoint.
    """

    class Client(Storage.Client):
        """Posts batches of file paths to the fileserver "delete_many" endpoint"""

        # Timeout (seconds) passed to the requests calls made against the fileserver
        timeout = conf.get("fileserver.timeout_sec", 300)

        def __init__(self, session: requests.Session, host: str):
            self.session = session
            self.delete_url = furl(host).add(path="delete_many").url

        @property
        def chunk_size(self) -> int:
            return 10000

        def get_path(self, url: UrlToDelete) -> str:
            """
            Return the url stripped of leading and trailing slashes.
            Raises ValueError if nothing is left after stripping
            """
            path = url.url.strip("/")
            if not path:
                raise ValueError("Empty path")

            return path

        def delete_many(
            self, paths: Sequence[str]
        ) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
            """
            Post the paths for deletion.
            Returns the "deleted" and the "errors" entries of the json response.
            Raises on a non-success HTTP status
            """
            res = self.session.post(
                url=self.delete_url, json={"files": list(paths)}, timeout=self.timeout
            )
            res.raise_for_status()
            res_data = res.json()
            return list(res_data.get("deleted", {})), res_data.get("errors", {})

    def __init__(self, company: str, fileserver_host: str = None):
        """
        :param company: company id, stored on the instance
        :param fileserver_host: fileserver address. If not passed then
            the "hosts.fileserver" configuration value is used
        """
        fileserver_host = fileserver_host or config.get("hosts.fileserver", None)
        # The host may be missing both from the arguments and from the configuration.
        # Fall back to an empty string so that the warning below is issued
        # instead of failing with AttributeError on None
        self.host = (fileserver_host or "").rstrip("/")
        if not self.host:
            log.warning("Fileserver host not configured")

        def _parse_url_prefix(prefix) -> Tuple[str, str]:
            # Split a configured prefix into (scheme://netloc or None, path without trailing slash)
            url = furl(prefix)
            host = f"{url.scheme}://{url.netloc}" if url.scheme else None
            return host, str(url.path).rstrip("/")

        url_prefixes = [
            _parse_url_prefix(p) for p in conf.get("fileserver.url_prefixes", [])
        ]
        # Make sure that urls pointing directly to the fileserver host are always matched
        if not any(self.host == host for host, _ in url_prefixes):
            url_prefixes.append((self.host, ""))
        self.url_prefixes = url_prefixes

        self.company = company

    # @classmethod
    # def validate_fileserver_access(cls, fileserver_host: str):
    #     res = requests.get(
    #         url=fileserver_host
    #     )
    #     res.raise_for_status()

    @property
    def name(self) -> str:
        return "Fileserver"

    def _resolve_base_url(self, url: UrlToDelete) -> Optional[str]:
        """
        If the url matches one of the known fileserver prefixes then replace url.url
        (in place) with its path following the matched prefix and return the fileserver host.
        Otherwise return None
        """
        if not url.url:
            return None

        try:
            parsed = furl(url.url)
            url_host = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme else None
            url_path = str(parsed.path)
        except Exception:
            return None

        for host, path_prefix in self.url_prefixes:
            # A prefix without a host matches urls on any host
            if host and url_host != host:
                continue
            if path_prefix and not url_path.startswith(path_prefix + "/"):
                continue
            url.url = url_path[len(path_prefix or "") :]
            return self.host

        return None

    def group_urls(
        self, urls: Sequence[UrlToDelete]
    ) -> Mapping[str, Sequence[UrlToDelete]]:
        """Group the urls by the value that _resolve_base_url returns for each of them"""
        return bucketize(urls, key=self._resolve_base_url)

    def get_client(self, base: str, urls: Sequence[UrlToDelete]) -> Client:
        """
        Create a client for the fileserver at the "base" address.
        A GET request is issued to the address first and any failure is raised to the caller
        """
        host = base
        session = requests.session()
        try:
            res = session.get(url=host, timeout=self.Client.timeout)
            res.raise_for_status()
        except Exception:
            # Do not leave the session open when no client is going to own it
            session.close()
            raise

        return self.Client(session, host)
|
||||
|
||||
|
||||
class AzureStorage(Storage):
    """
    Storage helper for blobs kept in Azure Blob Storage ("azure" scheme urls).
    The settings are retrieved per company from storage_bll
    """

    class Client(Storage.Client):
        """Deletes blobs from a single container"""

        def __init__(self, container: ContainerClient):
            self.container = container

        @property
        def chunk_size(self) -> int:
            return 256

        def get_path(self, url: UrlToDelete) -> str:
            """
            Return the blob name: all the url path segments following the first one
            (the first segment is treated as the container name).
            Raises ValueError if the url has no such segments
            """
            # Blob names are always "/" separated so the join
            # should not depend on the path rules of the hosting OS
            import posixpath

            parsed = furl(url.url)
            if (
                not parsed.path
                or not parsed.path.segments
                or len(parsed.path.segments) <= 1
            ):
                raise ValueError("No path found following container name")

            return posixpath.join(*parsed.path.segments[1:])

        @staticmethod
        def _path_from_request_url(request_url: str) -> str:
            # NOTE(review): only the last path segment is returned while get_path may return
            # several "/" separated segments. For blobs located in "sub folders" the two
            # will presumably not match - confirm against the actual batch request urls
            try:
                return furl(request_url).path.segments[-1]
            except Exception:
                return request_url

        def delete_many(
            self, paths: Sequence[str]
        ) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
            """
            Delete the blobs in one batch call.
            Returns the paths of the deleted blobs and the failed paths grouped by failure reason
            """
            try:
                res = self.container.delete_blobs(*paths)
            except PartialBatchErrorException as pex:
                deleted = []
                errors = defaultdict(list)
                for part in pex.parts:
                    # Only 2xx statuses designate success
                    if 200 <= part.status_code < 300:
                        deleted.append(self._path_from_request_url(part.request.url))
                    else:
                        errors[part.reason].append(
                            self._path_from_request_url(part.request.url)
                        )
                return deleted, errors

            return [self._path_from_request_url(part.request.url) for part in res], {}

    def __init__(self, company: str):
        self.configs = storage_bll.get_azure_settings_for_company(company)
        self.scheme = "azure"

    @property
    def name(self) -> str:
        return "Azure"

    def _resolve_base_url(self, url: UrlToDelete) -> Optional[Tuple]:
        """
        For the url return the (url netloc, configured container name) tuple.
        None is returned if the url scheme is not "azure", if no configuration
        is found for the url or if the resolving failed
        """
        try:
            parsed = urlparse(url.url)
            if parsed.scheme != self.scheme:
                return None

            azure_conf = self.configs.get_config_by_uri(url.url)
            if azure_conf is None:
                return None

            account_url = parsed.netloc
            return account_url, azure_conf.container_name
        except Exception as ex:
            log.warning(f"Error resolving base url for {url.url}: " + str(ex))
            return None

    def group_urls(
        self, urls: Sequence[UrlToDelete]
    ) -> Mapping[Tuple, Sequence[UrlToDelete]]:
        """Group the urls by the value that _resolve_base_url returns for each of them"""
        return bucketize(urls, key=self._resolve_base_url)

    def get_client(self, base: Tuple, urls: Sequence[UrlToDelete]) -> Client:
        """
        Create a container client for the (account url, container name) group.
        The credentials are taken from the configuration matching the first url of the group.
        Raises ValueError if the account name or the account key are not configured
        """
        account_url, container_name = base
        sample_url = urls[0].url
        cfg = self.configs.get_config_by_uri(sample_url)
        if not cfg or not cfg.account_name or not cfg.account_key:
            raise ValueError(
                "Missing account name or key for Azure Blob Storage "
                f"account: {account_url}, container: {container_name}"
            )

        return self.Client(
            ContainerClient(
                account_url=account_url,
                container_name=cfg.container_name,
                credential={
                    "account_name": cfg.account_name,
                    "account_key": cfg.account_key,
                },
            )
        )
|
||||
|
||||
|
||||
class AWSStorage(Storage):
    """
    Storage helper for objects kept in S3 buckets ("s3" scheme urls).
    The settings are retrieved per company from storage_bll
    """

    class Client(Storage.Client):
        """Deletes keys from a single bucket"""

        def __init__(self, base_url: str, container: AWSBucket):
            self.base_url = base_url
            self.container = container

        @property
        def chunk_size(self) -> int:
            return 1000

        def get_path(self, url: UrlToDelete) -> str:
            """Return the url without the client base url prefix and without leading slashes"""
            full_url = url.url
            if full_url.startswith(self.base_url):
                full_url = full_url[len(self.base_url) :]
            return full_url.lstrip("/")

        @staticmethod
        def _path_from_request_url(request_url: str) -> str:
            # Last path segment of the url or the url itself if it cannot be parsed
            try:
                segments = furl(request_url).path.segments
                return segments[-1]
            except Exception:
                return request_url

        def delete_many(
            self, paths: Sequence[str]
        ) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
            """
            Delete the keys in one delete_objects call.
            Returns the deleted keys and the failed keys grouped by the error message
            """
            objects = [{"Key": p} for p in paths]
            response = self.container.delete_objects(Delete={"Objects": objects})

            failed_by_message = defaultdict(list)
            for failure in response.get("Errors", []):
                failed_by_message[failure.get("Message", "")].append(failure.get("Key"))

            deleted_keys = [item.get("Key") for item in response.get("Deleted", [])]
            return deleted_keys, failed_by_message

    def __init__(self, company: str):
        self.scheme = "s3"
        self.configs = storage_bll.get_aws_settings_for_company(company)

    @property
    def name(self) -> str:
        return "AWS"

    def _resolve_base_url(self, url: UrlToDelete) -> Optional[str]:
        """
        For the url return "s3://[host/]bucket" built from the matching configuration.
        If the configuration has no bucket then the first url path part is used.
        None is returned if the url scheme is not "s3", if no configuration
        is found for the url or if the resolving failed
        """
        try:
            parsed = urlparse(url.url)
            if parsed.scheme != self.scheme:
                return None

            s3_conf = self.configs.get_config_by_uri(url.url)
            if s3_conf is None:
                return None

            bucket = s3_conf.bucket
            if not bucket:
                path_parts = Path(parsed.path.strip("/")).parts
                if path_parts:
                    bucket = path_parts[0]

            components = ("s3:/", s3_conf.host, bucket)
            return "/".join(filter(None, components))
        except Exception as ex:
            log.warning(f"Error resolving base url for {url.url}: " + str(ex))
            return None

    def group_urls(
        self, urls: Sequence[UrlToDelete]
    ) -> Mapping[str, Sequence[UrlToDelete]]:
        """Group the urls by the value that _resolve_base_url returns for each of them"""
        return bucketize(urls, key=self._resolve_base_url)

    def get_client(self, base: str, urls: Sequence[UrlToDelete]) -> Client:
        """
        Create a bucket client for the urls grouped under "base".
        The connection settings are taken from the configuration matching the first url.
        Raises ValueError if the credentials chain is not used and key or secret are missing
        """
        cfg = self.configs.get_config_by_uri(urls[0].url)

        if cfg.host:
            protocol = "https://" if cfg.secure else "http://"
            endpoint_url = protocol + cfg.host
        else:
            endpoint_url = None
        boto_kwargs = {
            "endpoint_url": endpoint_url,
            "use_ssl": cfg.secure,
            "verify": cfg.verify,
        }

        # Strip the scheme and the optional host from the base, the remainder is the bucket name
        name = base[len(scheme_prefix(self.scheme)) :]
        if cfg.host:
            bucket_name = name[len(cfg.host) + 1 :]
        else:
            bucket_name = name

        if not cfg.use_credentials_chain:
            if not (cfg.key and cfg.secret):
                raise ValueError(
                    f"Missing key or secret for AWS S3 host: {cfg.host}, bucket: {str(bucket_name)}"
                )

            boto_kwargs["aws_access_key_id"] = cfg.key
            boto_kwargs["aws_secret_access_key"] = cfg.secret
            if cfg.token:
                boto_kwargs["aws_session_token"] = cfg.token

        bucket = boto3.resource("s3", **boto_kwargs).Bucket(bucket_name)
        return self.Client(base, bucket)
|
||||
|
||||
|
||||
class GoogleCloudStorage(Storage):
    """
    Storage helper for blobs kept in Google Cloud Storage ("gs" scheme urls).
    The settings are retrieved per company from storage_bll
    """

    class Client(Storage.Client):
        """Deletes blobs from a single bucket"""

        def __init__(self, base_url: str, container: google_storage.Bucket):
            self.base_url = base_url
            self.container = container

        @property
        def chunk_size(self) -> int:
            return 100

        def get_path(self, url: UrlToDelete) -> str:
            """Return the url without the client base url prefix and without leading slashes"""
            full_url = url.url
            if full_url.startswith(self.base_url):
                full_url = full_url[len(self.base_url) :]
            return full_url.lstrip("/")

        def delete_many(
            self, paths: Sequence[str]
        ) -> Tuple[Sequence[str], Mapping[str, Sequence[str]]]:
            """
            Delete the blobs in one delete_blobs call.
            The blobs passed to the error callback are reported under the "Not found" reason,
            all the rest of the paths are returned as deleted
            """
            missing = set()
            blobs = [self.container.blob(p) for p in paths]
            self.container.delete_blobs(
                blobs, on_error=lambda blob: missing.add(blob.name),
            )

            if missing:
                errors = {"Not found": list(missing)}
            else:
                errors = {}
            return list(set(paths) - missing), errors

    def __init__(self, company: str):
        self.scheme = "gs"
        self.configs = storage_bll.get_gs_settings_for_company(company)

    @property
    def name(self) -> str:
        return "Google Storage"

    def _resolve_base_url(self, url: UrlToDelete) -> Optional[str]:
        """
        For the url return the base url made of the url scheme and the configured bucket.
        None is returned if the url scheme is not "gs", if no configuration
        is found for the url or if the resolving failed
        """
        try:
            parsed = urlparse(url.url)
            if parsed.scheme != self.scheme:
                return None

            gs_conf = self.configs.get_config_by_uri(url.url)
            if gs_conf is None:
                return None

            base_url = furl(scheme=parsed.scheme, netloc=gs_conf.bucket)
            return str(base_url)
        except Exception as ex:
            log.warning(f"Error resolving base url for {url.url}: " + str(ex))
            return None

    def group_urls(
        self, urls: Sequence[UrlToDelete]
    ) -> Mapping[str, Sequence[UrlToDelete]]:
        """Group the urls by the value that _resolve_base_url returns for each of them"""
        return bucketize(urls, key=self._resolve_base_url)

    def get_client(self, base: str, urls: Sequence[UrlToDelete]) -> Client:
        """
        Create a bucket client for the urls grouped under "base".
        The project and the optional service account file are taken
        from the configuration matching the first url
        """
        cfg = self.configs.get_config_by_uri(urls[0].url)

        credentials = None
        if cfg.credentials_json:
            from google.oauth2 import service_account

            credentials = service_account.Credentials.from_service_account_file(
                cfg.credentials_json
            )

        # The part of the base following the scheme prefix is the bucket name
        bucket_name = base[len(scheme_prefix(self.scheme)) :]
        gs_client = google_storage.Client(project=cfg.project, credentials=credentials)
        return self.Client(base, gs_client.bucket(bucket_name))
|
||||
|
||||
|
||||
def run_delete_loop(fileserver_host: str):
    """
    Endlessly pick a pending url (lowest retry count first) and delete all the pending
    urls that share its company and storage type. Sleeps for 10 seconds when nothing is pending.

    :param fileserver_host: fileserver address passed to the fileserver storage helper
    """
    storage_helpers = {
        StorageType.fileserver: partial(
            FileserverStorage, fileserver_host=fileserver_host
        ),
        StorageType.s3: AWSStorage,
        StorageType.azure: AzureStorage,
        StorageType.gs: GoogleCloudStorage,
    }
    supported_types = list(storage_helpers)

    while True:
        now = datetime.utcnow()
        # Not failed, retries not exhausted and either never failed
        # or the last failure is older than the retry timeout
        not_failed = Q(status__ne=DeletionStatus.failed)
        retries_left = Q(retry_count__lt=max_retries)
        ready_for_retry = Q(last_failure_time__exists=False) | Q(
            last_failure_time__lt=now - retry_timeout
        )
        urls_query = not_failed & retries_left & ready_for_retry

        candidates = UrlToDelete.objects(
            urls_query & Q(storage_type__in=supported_types)
        )
        url_to_delete: UrlToDelete = candidates.order_by("retry_count").limit(1).first()
        if not url_to_delete:
            sleep(10)
            continue

        company = url_to_delete.company
        storage_type = url_to_delete.storage_type
        user = url_to_delete.user
        log.info(
            f"Deleting {storage_type} objects for company: {company}, user: {user}"
        )

        storage = storage_helpers[storage_type](company=company)
        delete_urls(
            urls_query=urls_query & Q(company=company, storage_type=storage_type),
            storage=storage,
        )
|
||||
|
||||
|
||||
def main():
    """Parse the command line, initialize the db and start the delete loop"""
    arg_parser = ArgumentParser(description=__doc__)
    arg_parser.add_argument(
        "--fileserver-host",
        "-fh",
        help="Fileserver host address",
        type=str,
    )
    fileserver_host = arg_parser.parse_args().fileserver_host

    db.initialize()
    run_delete_loop(fileserver_host)


if __name__ == "__main__":
    main()
|
||||
@@ -1,8 +1,10 @@
|
||||
import importlib.util
|
||||
from datetime import datetime
|
||||
from inspect import signature
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
import pymongo.database
|
||||
from mongoengine.connection import get_db
|
||||
from packaging.version import Version, parse
|
||||
|
||||
@@ -80,8 +82,15 @@ def _apply_migrations(log: Logger):
|
||||
if not func:
|
||||
continue
|
||||
try:
|
||||
sig = signature(func)
|
||||
kwargs = {}
|
||||
if len(sig.parameters) == 2:
|
||||
name, param = list(sig.parameters.items())[-1]
|
||||
key = name.replace("_", "-")
|
||||
if issubclass(param.annotation, pymongo.database.Database) and key in dbs:
|
||||
kwargs[name] = get_db(key)
|
||||
log.info(f"Applying {script.stem}/{func_name}()")
|
||||
func(get_db(alias))
|
||||
func(get_db(alias), **kwargs)
|
||||
except Exception:
|
||||
log.exception(f"Failed applying {script}:{func_name}()")
|
||||
raise ValueError(
|
||||
|
||||
@@ -24,7 +24,7 @@ from typing import (
|
||||
Callable,
|
||||
)
|
||||
from urllib.parse import unquote, urlparse
|
||||
from uuid import uuid4
|
||||
from uuid import uuid4, UUID, uuid5
|
||||
from zipfile import ZipFile, ZIP_BZIP2
|
||||
|
||||
import mongoengine
|
||||
@@ -72,6 +72,8 @@ class PrePopulate:
|
||||
r"['\"]source['\"]:\s?['\"](https?://(?:localhost:8081|files.*?)/.*?)['\"]",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
_name_guid_ns = UUID("bda3acc1-e612-506c-bade-80071b6cf039")
|
||||
_example_id_prefix = "e-"
|
||||
task_cls: Type[Task]
|
||||
project_cls: Type[Project]
|
||||
model_cls: Type[Model]
|
||||
@@ -691,17 +693,56 @@ class PrePopulate:
|
||||
continue
|
||||
yield clean
|
||||
|
||||
@staticmethod
|
||||
def _new_id(_):
|
||||
return str(uuid4()).replace("-", "")
|
||||
|
||||
@classmethod
|
||||
def _hash_id(cls, name: str):
|
||||
return str(uuid5(cls._name_guid_ns, name)).replace("-", "")
|
||||
|
||||
@classmethod
|
||||
def _example_id(cls, orig_id: str):
|
||||
if not orig_id or orig_id.startswith(cls._example_id_prefix):
|
||||
return orig_id
|
||||
|
||||
return cls._example_id_prefix + orig_id
|
||||
|
||||
@classmethod
|
||||
def _private_id(cls, orig_id: str):
|
||||
if not orig_id or not orig_id.startswith(cls._example_id_prefix):
|
||||
return orig_id
|
||||
|
||||
return orig_id[len(cls._example_id_prefix) :]
|
||||
|
||||
@classmethod
|
||||
def _generate_new_ids(
|
||||
cls, reader: ZipFile, entity_files: Sequence
|
||||
cls, reader: ZipFile, entity_files: Sequence, metadata: Mapping[str, Any],
|
||||
) -> Mapping[str, str]:
|
||||
if not metadata or not any(
|
||||
metadata.get(key) for key in ("new_ids", "example_ids", "private_ids")
|
||||
):
|
||||
return {}
|
||||
|
||||
ids = {}
|
||||
for entity_file in entity_files:
|
||||
with reader.open(entity_file) as f:
|
||||
is_project = splitext(entity_file.orig_filename)[0].endswith(".Project")
|
||||
if metadata.get("new_ids"):
|
||||
id_func = cls._new_id
|
||||
elif metadata.get("example_ids"):
|
||||
id_func = cls._example_id if not is_project else cls._hash_id
|
||||
elif metadata.get("private_ids"):
|
||||
id_func = cls._private_id if not is_project else cls._new_id
|
||||
for item in cls.json_lines(f):
|
||||
orig_id = json.loads(item).get("_id")
|
||||
doc = json.loads(item)
|
||||
orig_id = doc.get("_id")
|
||||
if orig_id:
|
||||
ids[orig_id] = str(uuid4()).replace("-", "")
|
||||
ids[orig_id] = (
|
||||
id_func(orig_id)
|
||||
if id_func != cls._hash_id
|
||||
else id_func(doc.get("name"))
|
||||
)
|
||||
return ids
|
||||
|
||||
@classmethod
|
||||
@@ -725,11 +766,7 @@ class PrePopulate:
|
||||
and fi.orig_filename != cls.metadata_filename
|
||||
]
|
||||
metadata = metadata or {}
|
||||
old_to_new_ids = (
|
||||
cls._generate_new_ids(reader, entity_files)
|
||||
if metadata.get("new_ids")
|
||||
else {}
|
||||
)
|
||||
old_to_new_ids = cls._generate_new_ids(reader, entity_files, metadata)
|
||||
tasks = []
|
||||
for entity_file in entity_files:
|
||||
with reader.open(entity_file) as f:
|
||||
@@ -923,9 +960,7 @@ class PrePopulate:
|
||||
return tasks
|
||||
|
||||
@classmethod
|
||||
def _import_events(
|
||||
cls, f: IO[bytes], company_id: str, _, task_id: str
|
||||
):
|
||||
def _import_events(cls, f: IO[bytes], company_id: str, _, task_id: str):
|
||||
print(f"Writing events for task {task_id} into database")
|
||||
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
|
||||
events = [json.loads(item) for item in events_chunk]
|
||||
@@ -933,5 +968,5 @@ class PrePopulate:
|
||||
ev["task"] = task_id
|
||||
ev["company_id"] = company_id
|
||||
cls.event_bll.add_events(
|
||||
company_id, events=events, worker="", allow_locked_tasks=True
|
||||
company_id, events=events, worker="", allow_locked=True
|
||||
)
|
||||
|
||||
@@ -53,6 +53,7 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
|
||||
name=user_name,
|
||||
given_name=given_name,
|
||||
family_name=family_name,
|
||||
created=datetime.utcnow(),
|
||||
).save()
|
||||
|
||||
return user_id
|
||||
|
||||
15
apiserver/mongo/migrations/1_7_0.py
Normal file
15
apiserver/mongo/migrations/1_7_0.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
|
||||
|
||||
def migrate_backend(db: Database, auth_db: Database):
    """
    For each "user" document in db that has no "created" field copy the field value
    from the "user" document with the same _id in auth_db. Users that are missing
    from auth_db or that have no "created" value there are left untouched
    """
    created_field = "created"
    backend_users: Collection = db["user"]
    auth_users: Collection = auth_db["user"]

    for user in backend_users.find({created_field: {"$exists": False}}):
        user_id = user["_id"]
        auth_user = auth_users.find_one({"_id": user_id}, projection=[created_field])
        if auth_user and created_field in auth_user:
            backend_users.update_one(
                {"_id": user_id}, {"$set": {created_field: auth_user[created_field]}}
            )
|
||||
17
apiserver/mongo/migrations/1_9_0.py
Normal file
17
apiserver/mongo/migrations/1_9_0.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import logging as log
|
||||
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
from pymongo.errors import OperationFailure
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
    """
    Drop the task text index so that the new one that includes the reports field is created.
    Failure to drop the index is logged as a warning and does not fail the migration
    """
    task_collection: Collection = db["task"]
    try:
        task_collection.drop_index("backend-db.task.main_text_index")
    except OperationFailure as ex:
        log.warning(f"Could not delete task text index due to: {str(ex)}")
|
||||
@@ -1,7 +1,10 @@
|
||||
attrs>=19.1.0
|
||||
attrs>=22.1.0
|
||||
azure-storage-blob>=12.13.1
|
||||
bcrypt>=3.1.4
|
||||
boltons>=19.1.0
|
||||
boto3==1.14.13
|
||||
boto3-stubs[s3]>=1.24.35
|
||||
clearml>=1.6.0,<1.7.0
|
||||
dpath>=1.4.2,<2.0
|
||||
elasticsearch==7.13.3
|
||||
fastjsonschema>=2.8
|
||||
@@ -10,13 +13,15 @@ flask-cors>=3.0.5
|
||||
flask>=0.12.2
|
||||
funcsigs==1.0.2
|
||||
furl>=2.0.0
|
||||
google-cloud-storage==2.0.0
|
||||
protobuf==3.19.5
|
||||
gunicorn>=19.7.1
|
||||
humanfriendly==4.18
|
||||
jinja2==2.11.3
|
||||
jsonmodels>=2.3
|
||||
jsonschema>=2.6.0
|
||||
luqum>=0.10.0
|
||||
mongoengine==0.23.1
|
||||
mongoengine==0.24.2
|
||||
nested_dict>=1.61
|
||||
packaging==20.3
|
||||
psutil>=5.6.5
|
||||
@@ -24,9 +29,8 @@ pyhocon>=0.3.35
|
||||
pyjwt>=2.4.0
|
||||
pymongo[srv]==3.12.0
|
||||
python-rapidjson>=0.6.3
|
||||
redis==3.5.3
|
||||
redis==4.4.4
|
||||
redis-py-cluster>=2.1.3
|
||||
related>=0.7.2
|
||||
requests>=2.13.0
|
||||
semantic_version>=2.8.3,<3
|
||||
six
|
||||
|
||||
@@ -15,6 +15,35 @@ metadata_item {
|
||||
}
|
||||
}
|
||||
}
|
||||
task_status_enum {
|
||||
type: string
|
||||
enum: [
|
||||
created
|
||||
queued
|
||||
in_progress
|
||||
stopped
|
||||
published
|
||||
publishing
|
||||
closed
|
||||
failed
|
||||
completed
|
||||
unknown
|
||||
]
|
||||
}
|
||||
multi_field_pattern_data {
|
||||
type: object
|
||||
properties {
|
||||
pattern {
|
||||
description: "Pattern string (regex)"
|
||||
type: string
|
||||
}
|
||||
fields {
|
||||
description: "List of field names"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
credentials {
|
||||
type: object
|
||||
properties {
|
||||
|
||||
106
apiserver/schema/services/_events_common.conf
Normal file
106
apiserver/schema/services/_events_common.conf
Normal file
@@ -0,0 +1,106 @@
|
||||
scalar_key_enum {
|
||||
type: string
|
||||
enum: [
|
||||
iter
|
||||
timestamp
|
||||
iso_time
|
||||
]
|
||||
}
|
||||
metric_variants {
|
||||
type: object
|
||||
properties {
|
||||
metric {
|
||||
description: The metric name
|
||||
type: string
|
||||
}
|
||||
variants {
|
||||
type: array
|
||||
description: The names of the metric variants
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
debug_images_response_task_metrics {
|
||||
type: object
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: Task ID
|
||||
}
|
||||
iterations {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
iter {
|
||||
type: integer
|
||||
description: Iteration number
|
||||
}
|
||||
events {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
description: Debug image event
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
debug_images_response {
|
||||
type: object
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID for getting more results"
|
||||
}
|
||||
metrics {
|
||||
type: array
|
||||
description: "Debug image events grouped by tasks and iterations"
|
||||
items {"$ref": "#/definitions/debug_images_response_task_metrics"}
|
||||
}
|
||||
}
|
||||
}
|
||||
plots_response_task_metrics {
|
||||
type: object
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: Task ID
|
||||
}
|
||||
iterations {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
iter {
|
||||
type: integer
|
||||
description: Iteration number
|
||||
}
|
||||
events {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
description: Plot event
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
plots_response {
|
||||
type: object
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID for getting more results"
|
||||
}
|
||||
metrics {
|
||||
type: array
|
||||
description: "Plot events grouped by tasks and iterations"
|
||||
items {"$ref": "#/definitions/plots_response_task_metrics"}
|
||||
}
|
||||
}
|
||||
}
|
||||
506
apiserver/schema/services/_tasks_common.conf
Normal file
506
apiserver/schema/services/_tasks_common.conf
Normal file
@@ -0,0 +1,506 @@
|
||||
include "_common.conf"
|
||||
task_type_enum {
|
||||
type: string
|
||||
enum: [
|
||||
training
|
||||
testing
|
||||
inference
|
||||
data_processing
|
||||
application
|
||||
monitor
|
||||
controller
|
||||
optimizer
|
||||
service
|
||||
qc
|
||||
custom
|
||||
]
|
||||
}
|
||||
script {
|
||||
type: object
|
||||
properties {
|
||||
binary {
|
||||
description: "Binary to use when running the script"
|
||||
type: string
|
||||
default: python
|
||||
}
|
||||
repository {
|
||||
description: "Name of the repository where the script is located"
|
||||
type: string
|
||||
}
|
||||
tag {
|
||||
description: "Repository tag"
|
||||
type: string
|
||||
}
|
||||
branch {
|
||||
description: "Repository branch ID. If neither branch nor tag is provided, the default repository branch is used."
|
||||
type: string
|
||||
}
|
||||
version_num {
|
||||
description: "Version (changeset) number. Optional (default is head version). Unused if tag is provided."
|
||||
type: string
|
||||
}
|
||||
entry_point {
|
||||
description: "Path to execute within the repository"
|
||||
type: string
|
||||
}
|
||||
working_dir {
|
||||
description: "Path to the folder from which to run the script. Default - root folder of repository"
|
||||
type: string
|
||||
}
|
||||
requirements {
|
||||
description: "A JSON object containing requirements strings by key"
|
||||
type: object
|
||||
}
|
||||
diff {
|
||||
description: "Uncommitted changes found in the repository when task was run"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
model_type_enum {
|
||||
type: string
|
||||
enum: ["input", "output"]
|
||||
}
|
||||
task_model_item {
|
||||
type: object
|
||||
required: [ name, model]
|
||||
properties {
|
||||
name {
|
||||
description: "The task model name"
|
||||
type: string
|
||||
}
|
||||
model {
|
||||
description: "The model ID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
type: object
|
||||
properties {
|
||||
destination {
|
||||
description: "Storage id. This is where output files will be stored."
|
||||
type: string
|
||||
}
|
||||
model {
|
||||
description: "Model id."
|
||||
type: string
|
||||
}
|
||||
result {
|
||||
description: "Task result. Values: 'success', 'failure'"
|
||||
type: string
|
||||
}
|
||||
error {
|
||||
description: "Last error text"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
task_execution_progress_enum {
|
||||
type: string
|
||||
enum: [
|
||||
unknown
|
||||
running
|
||||
stopping
|
||||
stopped
|
||||
]
|
||||
}
|
||||
artifact_type_data {
|
||||
type: object
|
||||
properties {
|
||||
preview {
|
||||
description: "Description or textual data"
|
||||
type: string
|
||||
}
|
||||
content_type {
|
||||
description: "System defined raw data content type"
|
||||
type: string
|
||||
}
|
||||
data_hash {
|
||||
description: "Hash of raw data, without any headers or descriptive parts"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
artifact_mode_enum {
|
||||
type: string
|
||||
enum: [
|
||||
input
|
||||
output
|
||||
]
|
||||
default: output
|
||||
}
|
||||
artifact {
|
||||
type: object
|
||||
required: [key, type]
|
||||
properties {
|
||||
key {
|
||||
description: "Entry key"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "System defined type"
|
||||
type: string
|
||||
}
|
||||
mode {
|
||||
description: "System defined input/output indication"
|
||||
"$ref": "#/definitions/artifact_mode_enum"
|
||||
}
|
||||
uri {
|
||||
description: "Raw data location"
|
||||
type: string
|
||||
}
|
||||
content_size {
|
||||
description: "Raw data length in bytes"
|
||||
type: integer
|
||||
}
|
||||
hash {
|
||||
description: "Hash of entire raw data"
|
||||
type: string
|
||||
}
|
||||
timestamp {
|
||||
description: "Epoch time when artifact was created"
|
||||
type: integer
|
||||
}
|
||||
type_data {
|
||||
description: "Additional fields defined by the system"
|
||||
"$ref": "#/definitions/artifact_type_data"
|
||||
}
|
||||
display_data {
|
||||
description: "User-defined list of key/value pairs, sorted"
|
||||
type: array
|
||||
items {
|
||||
type: array
|
||||
items {
|
||||
type: string # can also be a number... TODO: upgrade the generator
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
artifact_id {
|
||||
type: object
|
||||
required: [key]
|
||||
properties {
|
||||
key {
|
||||
description: "Entry key"
|
||||
type: string
|
||||
}
|
||||
mode {
|
||||
description: "System defined input/output indication"
|
||||
"$ref": "#/definitions/artifact_mode_enum"
|
||||
}
|
||||
}
|
||||
}
|
||||
task_models {
|
||||
type: object
|
||||
properties {
|
||||
input {
|
||||
description: "The list of task input models"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/task_model_item"}
|
||||
|
||||
}
|
||||
output {
|
||||
description: "The list of task output models"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/task_model_item"}
|
||||
}
|
||||
}
|
||||
}
|
||||
execution {
|
||||
type: object
|
||||
properties {
|
||||
queue {
|
||||
description: "Queue ID where task was queued."
|
||||
type: string
|
||||
}
|
||||
parameters {
|
||||
description: "Json object containing the Task parameters"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
model {
|
||||
description: "Execution input model ID. Not applicable for Register (Import) tasks"
|
||||
type: string
|
||||
}
|
||||
model_desc {
|
||||
description: "Json object representing the Model descriptors"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
model_labels {
|
||||
description: """Json object representing the ids of the labels in the model.
|
||||
The keys are the layers' names and the values are the IDs.
|
||||
Not applicable for Register (Import) tasks.
|
||||
Mandatory for Training tasks"""
|
||||
type: object
|
||||
additionalProperties: { type: integer }
|
||||
}
|
||||
framework {
|
||||
description: """Framework related to the task. Case insensitive. Mandatory for Training tasks. """
|
||||
type: string
|
||||
}
|
||||
docker_cmd {
|
||||
description: "Command for running docker script for the execution of the task"
|
||||
type: string
|
||||
}
|
||||
artifacts {
|
||||
description: "Task artifacts"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/artifact" }
|
||||
}
|
||||
}
|
||||
}
|
||||
last_metrics_event {
|
||||
type: object
|
||||
properties {
|
||||
metric {
|
||||
description: "Metric name"
|
||||
type: string
|
||||
}
|
||||
variant {
|
||||
description: "Variant name"
|
||||
type: string
|
||||
}
|
||||
value {
|
||||
description: "Last value reported"
|
||||
type: number
|
||||
}
|
||||
min_value {
|
||||
description: "Minimum value reported"
|
||||
type: number
|
||||
}
|
||||
min_value_iteration {
|
||||
description: "The iteration at which the minimum value was reported"
|
||||
type: integer
|
||||
}
|
||||
max_value {
|
||||
description: "Maximum value reported"
|
||||
type: number
|
||||
}
|
||||
max_value_iteration {
|
||||
description: "The iteration at which the maximum value was reported"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
last_metrics_variants {
|
||||
type: object
|
||||
description: "Last metric events, one for each variant hash"
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/last_metrics_event"
|
||||
}
|
||||
}
|
||||
params_item {
|
||||
type: object
|
||||
properties {
|
||||
section {
|
||||
description: "Section that the parameter belongs to"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Name of the parameter. The combination of section and name should be unique"
|
||||
type: string
|
||||
}
|
||||
value {
|
||||
description: "Value of the parameter"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "Type of the parameter. Optional"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "The parameter description. Optional"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
configuration_item {
|
||||
type: object
|
||||
properties {
|
||||
name {
|
||||
description: "Name of the parameter. Should be unique"
|
||||
type: string
|
||||
}
|
||||
value {
|
||||
description: "Value of the parameter"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "Type of the parameter. Optional"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "The parameter description. Optional"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
section_params {
|
||||
description: "Task section params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/params_item"
|
||||
}
|
||||
}
|
||||
task {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Task id"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Task Name"
|
||||
type: string
|
||||
}
|
||||
user {
|
||||
description: "Associated user id"
|
||||
type: string
|
||||
}
|
||||
company {
|
||||
description: "Company ID"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "Type of task. Values: 'training', 'testing'"
|
||||
"$ref": "#/definitions/task_type_enum"
|
||||
}
|
||||
status {
|
||||
description: ""
|
||||
"$ref": "#/definitions/task_status_enum"
|
||||
}
|
||||
comment {
|
||||
description: "Free text comment"
|
||||
type: string
|
||||
}
|
||||
created {
|
||||
description: "Task creation time (UTC) "
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
started {
|
||||
description: "Task start time (UTC)"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
completed {
|
||||
description: "Task end time (UTC)"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
active_duration {
|
||||
description: "Task duration time (seconds)"
|
||||
type: integer
|
||||
}
|
||||
parent {
|
||||
description: "Parent task id"
|
||||
type: string
|
||||
}
|
||||
project {
|
||||
description: "Project ID of the project to which this task is assigned"
|
||||
type: string
|
||||
}
|
||||
output {
|
||||
description: "Task output params"
|
||||
"$ref": "#/definitions/output"
|
||||
}
|
||||
execution {
|
||||
description: "Task execution params"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
models {
|
||||
description: "Task models"
|
||||
"$ref": "#/definitions/task_models"
|
||||
}
|
||||
// TODO: will be removed
|
||||
script {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags list"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
status_changed {
|
||||
description: "Last status change time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
status_message {
|
||||
description: "free text string representing info about the status"
|
||||
type: string
|
||||
}
|
||||
status_reason {
|
||||
description: "Reason for last status change"
|
||||
type: string
|
||||
}
|
||||
published {
|
||||
description: "Task publish time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_worker {
|
||||
description: "ID of last worker that handled the task"
|
||||
type: string
|
||||
}
|
||||
last_worker_report {
|
||||
description: "Last time a worker reported while working on this task"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_update {
|
||||
description: "Last time this task was created, edited, changed or events for this task were reported"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_change {
|
||||
description: "Last time any update was done to the task"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_iteration {
|
||||
description: "Last iteration reported for this task"
|
||||
type: integer
|
||||
}
|
||||
last_metrics {
|
||||
description: "Last metric variants (hash to events), one for each metric hash"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/last_metrics_variants"
|
||||
}
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
runtime {
|
||||
description: "Task runtime mapping"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,6 @@
|
||||
_description : "Provides an API for running tasks to report events collected by the system."
|
||||
_definitions {
|
||||
metric_variants {
|
||||
type: object
|
||||
metric {
|
||||
description: The metric name
|
||||
type: string
|
||||
}
|
||||
variants {
|
||||
type: array
|
||||
description: The names of the metric variants
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
include "_events_common.conf"
|
||||
metrics_scalar_event {
|
||||
description: "Used for reporting scalar metrics during training task"
|
||||
type: object
|
||||
@@ -164,14 +153,6 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
scalar_key_enum {
|
||||
type: string
|
||||
enum: [
|
||||
iter
|
||||
timestamp
|
||||
iso_time
|
||||
]
|
||||
}
|
||||
log_level_enum {
|
||||
type: string
|
||||
enum: [
|
||||
@@ -260,90 +241,6 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
debug_images_response_task_metrics {
|
||||
type: object
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: Task ID
|
||||
}
|
||||
iterations {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
iter {
|
||||
type: integer
|
||||
description: Iteration number
|
||||
}
|
||||
events {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
description: Debug image event
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
debug_images_response {
|
||||
type: object
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID for getting more results"
|
||||
}
|
||||
metrics {
|
||||
type: array
|
||||
description: "Debug image events grouped by tasks and iterations"
|
||||
items {"$ref": "#/definitions/debug_images_response_task_metrics"}
|
||||
}
|
||||
}
|
||||
}
|
||||
plots_response_task_metrics {
|
||||
type: object
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: Task ID
|
||||
}
|
||||
iterations {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
iter {
|
||||
type: integer
|
||||
description: Iteration number
|
||||
}
|
||||
events {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
description: Plot event
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
plots_response {
|
||||
type: object
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID for getting more results"
|
||||
}
|
||||
metrics {
|
||||
type: array
|
||||
description: "Plot events grouped by tasks and iterations"
|
||||
items {"$ref": "#/definitions/plots_response_task_metrics"}
|
||||
}
|
||||
}
|
||||
}
|
||||
debug_image_sample_response {
|
||||
type: object
|
||||
properties {
|
||||
@@ -372,17 +269,18 @@ _definitions {
|
||||
type: string
|
||||
description: "Scroll ID to pass to the next calls to get_plot_sample or next_plot_sample"
|
||||
}
|
||||
event {
|
||||
type: object
|
||||
description: "Plot event"
|
||||
events {
|
||||
description: "Plot events"
|
||||
type: array
|
||||
items { type: object}
|
||||
}
|
||||
min_iteration {
|
||||
type: integer
|
||||
description: "minimal valid iteration for the variant"
|
||||
description: "minimal valid iteration for the metric"
|
||||
}
|
||||
max_iteration {
|
||||
type: integer
|
||||
description: "maximal valid iteration for the variant"
|
||||
description: "maximal valid iteration for the metric"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -405,13 +303,27 @@ add {
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
"2.22": ${add."2.1"} {
|
||||
request.properties {
|
||||
model_event {
|
||||
type: boolean
|
||||
description: If set then the event is for a model. Otherwise for a task. Cannot be used with task log events. If used in batch then all the events should be marked the same
|
||||
default: false
|
||||
}
|
||||
allow_locked {
|
||||
type: boolean
|
||||
description: Allow adding events to published tasks or models
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
add_batch {
|
||||
"2.1" {
|
||||
description: "Adds a batch of events in a single call (json-lines format, stream-friendly)"
|
||||
batch_request: {
|
||||
action: add
|
||||
version: 1.5
|
||||
version: 2.1
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
@@ -422,10 +334,16 @@ add_batch {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.22": ${add_batch."2.1"} {
|
||||
batch_request: {
|
||||
action: add
|
||||
version: 2.22
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_for_task {
|
||||
"2.1" {
|
||||
description: "Delete all task event. *This cannot be undone!*"
|
||||
description: "Delete all task events. *This cannot be undone!*"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
@@ -454,6 +372,37 @@ delete_for_task {
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_for_model {
|
||||
"2.22" {
|
||||
description: "Delete all model events. *This cannot be undone!*"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
model
|
||||
]
|
||||
properties {
|
||||
model {
|
||||
type: string
|
||||
description: "Model ID"
|
||||
}
|
||||
allow_locked {
|
||||
type: boolean
|
||||
description: "Allow deleting events even if the model is locked"
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
deleted {
|
||||
type: boolean
|
||||
description: "Number of deleted events"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
debug_images {
|
||||
"2.1" {
|
||||
description: "Get all debug images of a task"
|
||||
@@ -495,7 +444,7 @@ debug_images {
|
||||
}
|
||||
total {
|
||||
type: number
|
||||
description: "Total number of results available for this query"
|
||||
description: "Total number of results available for this query. In case there are more than 10000 results it is set to 10000"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
@@ -548,6 +497,13 @@ debug_images {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.22": ${debug_images."2.14"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model events. Otherwise retrieve task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
plots {
|
||||
"2.20" {
|
||||
@@ -565,7 +521,7 @@ plots {
|
||||
}
|
||||
iters {
|
||||
type: integer
|
||||
description: "Max number of latest iterations for which to return debug images"
|
||||
description: "Max number of latest iterations for which to return plots"
|
||||
}
|
||||
navigate_earlier {
|
||||
type: boolean
|
||||
@@ -583,6 +539,13 @@ plots {
|
||||
}
|
||||
response {"$ref": "#/definitions/plots_response"}
|
||||
}
|
||||
"2.22": ${plots."2.20"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model plots. Otherwise retrieve task plots
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_debug_image_sample {
|
||||
"2.12": {
|
||||
@@ -626,6 +589,13 @@ get_debug_image_sample {
|
||||
default: true
|
||||
}
|
||||
}
|
||||
"2.22": ${get_debug_image_sample."2.20"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model debug images. Otherwise retrieve task debug images
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
next_debug_image_sample {
|
||||
"2.12": {
|
||||
@@ -651,13 +621,25 @@ next_debug_image_sample {
|
||||
}
|
||||
response {"$ref": "#/definitions/debug_image_sample_response"}
|
||||
}
|
||||
"2.22": ${next_debug_image_sample."2.12"} {
|
||||
request.properties.next_iteration {
|
||||
type: boolean
|
||||
default: false
|
||||
description: If set then navigate to the next/previous iteration
|
||||
}
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model debug images. Otherwise retrieve task debug images
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_plot_sample {
|
||||
"2.20": {
|
||||
description: "Return the plot per metric and variant for the provided iteration"
|
||||
description: "Return plots for the provided iteration"
|
||||
request {
|
||||
type: object
|
||||
required: [task, metric, variant]
|
||||
required: [task, metric]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
@@ -667,10 +649,6 @@ get_plot_sample {
|
||||
description: "Metric name"
|
||||
type: string
|
||||
}
|
||||
variant {
|
||||
description: "Metric variant"
|
||||
type: string
|
||||
}
|
||||
iteration {
|
||||
description: "The iteration to bring plot from. If not specified then the latest reported plot is retrieved"
|
||||
type: integer
|
||||
@@ -692,10 +670,17 @@ get_plot_sample {
|
||||
}
|
||||
response {"$ref": "#/definitions/plot_sample_response"}
|
||||
}
|
||||
"2.22": ${get_plot_sample."2.20"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model plots. Otherwise retrieve task plots
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
next_plot_sample {
|
||||
"2.20": {
|
||||
description: "Get the plot for the next variant for the same iteration or for the next iteration"
|
||||
description: "Get the plot for the next metric for the same iteration or for the next iteration"
|
||||
request {
|
||||
type: object
|
||||
required: [task, scroll_id]
|
||||
@@ -710,13 +695,25 @@ next_plot_sample {
|
||||
}
|
||||
navigate_earlier {
|
||||
type: boolean
|
||||
description: """If set then get the either previous variant event from the current iteration or (if does not exist) the last variant event from the previous iteration.
|
||||
Otherwise next variant event from the current iteration or first variant event from the next iteration"""
|
||||
description: """If set then get either the previous metric events from the current iteration or (if they do not exist) the last metric events from the previous iteration.
|
||||
Otherwise next metric events from the current iteration or first metric events from the next iteration"""
|
||||
}
|
||||
}
|
||||
}
|
||||
response {"$ref": "#/definitions/plot_sample_response"}
|
||||
}
|
||||
"2.22": ${next_plot_sample."2.20"} {
|
||||
request.properties.next_iteration {
|
||||
type: boolean
|
||||
default: false
|
||||
description: If set then navigate to the next/previous iteration
|
||||
}
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model plots. Otherwise retrieve task plots
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_metrics{
|
||||
"2.7": {
|
||||
@@ -749,6 +746,13 @@ get_task_metrics{
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.22": ${get_task_metrics."2.7"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then get metrics from model events. Otherwise from task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_log {
|
||||
"1.5" {
|
||||
@@ -848,7 +852,7 @@ get_task_log {
|
||||
}
|
||||
total {
|
||||
type: number
|
||||
description: "Total number of results available for this query"
|
||||
description: "Total number of results available for this query. In case there are more than 10000 results it is set to 10000"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
@@ -902,7 +906,7 @@ get_task_log {
|
||||
}
|
||||
total {
|
||||
type: number
|
||||
description: "Total number of log events available for this query"
|
||||
description: "Total number of log events available for this query. In case there are more than 10000 events it is set to 10000"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -957,7 +961,7 @@ get_task_events {
|
||||
}
|
||||
total {
|
||||
type: number
|
||||
description: "Total number of results available for this query"
|
||||
description: "Total number of results available for this query. In case there are more than 10000 results it is set to 10000"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
@@ -966,6 +970,13 @@ get_task_events {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.22": ${get_task_events."2.1"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model events. Otherwise retrieve task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
download_task_log {
|
||||
@@ -1016,7 +1027,7 @@ get_task_plots {
|
||||
}
|
||||
iters {
|
||||
type: integer
|
||||
description: "Max number of latest iterations for which to return debug images"
|
||||
description: "Max number of latest iterations for which to return plots"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
@@ -1040,7 +1051,7 @@ get_task_plots {
|
||||
}
|
||||
total {
|
||||
type: number
|
||||
description: "Total number of results available for this query"
|
||||
description: "Total number of results available for this query. In case there are more than 10000 results it is set to 10000"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
@@ -1067,6 +1078,13 @@ get_task_plots {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.22": ${get_task_plots."2.16"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model events. Otherwise retrieve task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_multi_task_plots {
|
||||
"2.1" {
|
||||
@@ -1087,7 +1105,7 @@ get_multi_task_plots {
|
||||
}
|
||||
iters {
|
||||
type: integer
|
||||
description: "Max number of latest iterations for which to return debug images"
|
||||
description: "Max number of latest iterations for which to return plots"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
@@ -1108,7 +1126,7 @@ get_multi_task_plots {
|
||||
}
|
||||
total {
|
||||
type: number
|
||||
description: "Total number of results available for this query"
|
||||
description: "Total number of results available for this query. In case there are more than 10000 results it is set to 10000"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
@@ -1124,6 +1142,13 @@ get_multi_task_plots {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.22": ${get_multi_task_plots."2.16"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model events. Otherwise retrieve task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_vector_metrics_and_variants {
|
||||
"2.1" {
|
||||
@@ -1152,6 +1177,13 @@ get_vector_metrics_and_variants {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.22": ${get_vector_metrics_and_variants."2.1"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model events. Otherwise retrieve task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
vector_metrics_iter_histogram {
|
||||
"2.1" {
|
||||
@@ -1190,6 +1222,13 @@ vector_metrics_iter_histogram {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.22": ${vector_metrics_iter_histogram."2.1"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model events. Otherwise retrieve task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
scalar_metrics_iter_histogram {
|
||||
"2.1" {
|
||||
@@ -1243,6 +1282,13 @@ scalar_metrics_iter_histogram {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.22": ${scalar_metrics_iter_histogram."2.14"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model events. Otherwise retrieve task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
multi_task_scalar_metrics_iter_histogram {
|
||||
"2.1" {
|
||||
@@ -1254,7 +1300,7 @@ multi_task_scalar_metrics_iter_histogram {
|
||||
]
|
||||
properties {
|
||||
tasks {
|
||||
description: "List of task Task IDs. Maximum amount of tasks is 10"
|
||||
description: "List of task IDs. Maximum amount of tasks is 100"
|
||||
type: array
|
||||
items {
|
||||
type: string
|
||||
@@ -1282,6 +1328,13 @@ multi_task_scalar_metrics_iter_histogram {
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
"2.22": ${multi_task_scalar_metrics_iter_histogram."2.1"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model events. Otherwise retrieve task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_single_value_metrics {
|
||||
"2.20" {
|
||||
@@ -1331,6 +1384,13 @@ get_task_single_value_metrics {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.22": ${get_task_single_value_metrics."2.20"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model events. Otherwise retrieve task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_latest_scalar_values {
|
||||
"2.1" {
|
||||
@@ -1410,6 +1470,13 @@ get_scalar_metrics_and_variants {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.22": ${get_scalar_metrics_and_variants."2.1"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model events. Otherwise retrieve task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_scalar_metric_data {
|
||||
"2.1" {
|
||||
@@ -1459,6 +1526,13 @@ get_scalar_metric_data {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.22": ${get_scalar_metric_data."2.16"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model events. Otherwise retrieve task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
scalar_metrics_iter_raw {
|
||||
"2.16" {
|
||||
@@ -1523,6 +1597,13 @@ scalar_metrics_iter_raw {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.22": ${scalar_metrics_iter_raw."2.16"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model events. Otherwise retrieve task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
clear_scroll {
|
||||
"2.18" {
|
||||
@@ -1578,4 +1659,4 @@ clear_task_log {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -241,6 +241,15 @@ get_all_ex {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.23": ${get_all_ex."2.20"} {
|
||||
request.properties {
|
||||
allow_public {
|
||||
description: "Allow public models to be returned in the results"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
@@ -320,6 +329,14 @@ get_all {
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
last_update {
|
||||
description: "List of last_update constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
|
||||
type: array
|
||||
items {
|
||||
type: string
|
||||
pattern: "^(>=|>|<=|<)?.*$"
|
||||
}
|
||||
}
|
||||
_all_ {
|
||||
description: "Multi-field pattern condition (all fields match pattern)"
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
@@ -968,7 +985,7 @@ move {
|
||||
items { type: string }
|
||||
}
|
||||
project {
|
||||
description: "Target project ID. If not provided, `project_name` must be provided."
|
||||
description: "Target project ID. If not provided, `project_name` must be provided. Use null for the root project"
|
||||
type: string
|
||||
}
|
||||
project_name {
|
||||
|
||||
@@ -162,4 +162,38 @@ get_entities_count {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.22": ${get_entities_count."2.20"} {
|
||||
request.properties {
|
||||
search_hidden {
|
||||
description: "If set to 'true' then hidden projects and tasks are included in the search results"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
active_users {
|
||||
description: "The list of users that were active in the project. If passed then the resulting projects are filtered to the ones that have tasks created by these users"
|
||||
type: array
|
||||
items: {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.23": ${get_entities_count."2.22"} {
|
||||
request.properties {
|
||||
reports {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
description: Search criteria for reports
|
||||
}
|
||||
allow_public {
|
||||
description: "Allow public entities to be counted in the results"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
response.properties {
|
||||
reports {
|
||||
type: integer
|
||||
description: The number of reports matching the criteria
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,11 +46,6 @@ _definitions {
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_update {
|
||||
description: "Last update time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags"
|
||||
type: array
|
||||
@@ -66,12 +61,26 @@ _definitions {
|
||||
type: string
|
||||
}
|
||||
last_update {
|
||||
description: """Last project update time. Reflects the last time the project metadata was changed or a task in this project has changed status"""
|
||||
description: "Last project update time. Reflects the last time the project metadata was changed or a task in this project has changed status"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
}
|
||||
}
|
||||
stats_datasets {
|
||||
type: object
|
||||
properties {
|
||||
count {
|
||||
description: Number of datasets
|
||||
type: integer
|
||||
}
|
||||
tags {
|
||||
description: Dataset tags
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
stats_status_count {
|
||||
type: object
|
||||
properties {
|
||||
@@ -146,6 +155,10 @@ _definitions {
|
||||
description: "Stats for archived tasks"
|
||||
"$ref": "#/definitions/stats_status_count"
|
||||
}
|
||||
datasets {
|
||||
description: "Stats for child datasets"
|
||||
"$ref": "#/definitions/stats_datasets"
|
||||
}
|
||||
}
|
||||
}
|
||||
projects_get_all_response_single {
|
||||
@@ -181,11 +194,6 @@ _definitions {
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_update {
|
||||
description: "Last update time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags"
|
||||
type: array
|
||||
@@ -200,7 +208,16 @@ _definitions {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
type: string
|
||||
}
|
||||
last_update {
|
||||
description: "Last project update time. Reflects the last time the project metadata was changed or a task in this project has changed status"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
// extra properties
|
||||
hidden {
|
||||
description: "Returned if the search_hidden flag was specified in the get_all_ex call and the project is hidden"
|
||||
type: boolean
|
||||
}
|
||||
stats {
|
||||
description: "Additional project stats"
|
||||
"$ref": "#/definitions/stats"
|
||||
@@ -222,6 +239,10 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
own_datasets {
|
||||
description: "The amount of datasets/hyperdatasets under this project (without children projects). Returned if 'check_own_contents' flag is set in the request and children_type is set to 'dataset' or 'hyperdataset'"
|
||||
type: integer
|
||||
}
|
||||
own_tasks {
|
||||
description: "The amount of tasks under this project (without children projects). Returned if 'check_own_contents' flag is set in the request"
|
||||
type: integer
|
||||
@@ -616,6 +637,22 @@ get_all_ex {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.23": ${get_all_ex."2.20"} {
|
||||
request.properties {
|
||||
allow_public {
|
||||
description: "Allow public projects to be returned in the results"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.24": ${get_all_ex."2.23"} {
|
||||
request.properties.children_type {
|
||||
description: If specified then only the projects under which the entities of this type can be found will be returned
|
||||
type: string
|
||||
enum: [pipeline, report, dataset]
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
|
||||
@@ -152,6 +152,13 @@ get_all_ex {
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
"2.21": ${get_all_ex."2.20"} {
|
||||
request.properties.search_hidden {
|
||||
description: "If set to 'true' then hidden queues are included in the search results"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.4" {
|
||||
@@ -244,6 +251,13 @@ get_all {
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
"2.21": ${get_all."2.20"} {
|
||||
request.properties.search_hidden {
|
||||
description: "If set to 'true' then hidden queues are included in the search results"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_default {
|
||||
"2.4" {
|
||||
@@ -449,6 +463,33 @@ get_next_task {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.14": ${get_next_task."2.4"} {
|
||||
request.properties.get_task_info {
|
||||
description: "If set then additional task info is returned"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
response.properties.task_info {
|
||||
description: "Info about the returned task. Returned only if get_task_info is set to True"
|
||||
type: object
|
||||
properties {
|
||||
company {
|
||||
description: Task company ID
|
||||
type: string
|
||||
}
|
||||
user {
|
||||
description: ID of the user who created the task
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.21": ${get_next_task."2.14"} {
|
||||
request.properties.task {
|
||||
description: Task ID
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
remove_task {
|
||||
"2.4" {
|
||||
|
||||
709
apiserver/schema/services/reports.conf
Normal file
709
apiserver/schema/services/reports.conf
Normal file
@@ -0,0 +1,709 @@
|
||||
_description: "Provides a management API for reports in the system."
|
||||
_definitions {
|
||||
include "_tasks_common.conf"
|
||||
include "_events_common.conf"
|
||||
update_response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of reports updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
report_status_enum {
|
||||
type: string
|
||||
enum: [
|
||||
created
|
||||
published
|
||||
]
|
||||
}
|
||||
report {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Report id"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Report Name"
|
||||
type: string
|
||||
}
|
||||
user {
|
||||
description: "Associated user id"
|
||||
type: string
|
||||
}
|
||||
company {
|
||||
description: "Company ID"
|
||||
type: string
|
||||
}
|
||||
status {
|
||||
description: "Report status"
|
||||
"$ref": "#/definitions/report_status_enum"
|
||||
}
|
||||
comment {
|
||||
description: "Free text comment"
|
||||
type: string
|
||||
}
|
||||
report {
|
||||
description: "Report template"
|
||||
type: string
|
||||
}
|
||||
report_assets {
|
||||
description: "List of the external report assets"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
created {
|
||||
description: "Report creation time (UTC) "
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
project {
|
||||
description: "Project ID of the project to which this report is assigned"
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags list"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
status_changed {
|
||||
description: "Last status change time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
status_message {
|
||||
description: "free text string representing info about the status"
|
||||
type: string
|
||||
}
|
||||
status_reason {
|
||||
description: "Reason for last status change"
|
||||
type: string
|
||||
}
|
||||
published {
|
||||
description: "Report publish time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_update {
|
||||
description: "Last time this report was created, edited, changed"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
create {
|
||||
"2.23" {
|
||||
description: "Create a new report"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
name
|
||||
]
|
||||
properties {
|
||||
name {
|
||||
description: "Report name. Unique within the company."
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags list"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
comment {
|
||||
description: "Free text comment "
|
||||
type: string
|
||||
}
|
||||
report {
|
||||
description: "Report template"
|
||||
type: string
|
||||
}
|
||||
project {
|
||||
description: "Project ID of the project to which this report is assigned. Must exist."
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "ID of the report"
|
||||
type: string
|
||||
}
|
||||
project_id {
|
||||
description: "ID of the project that the report belongs to"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.24": ${create."2.23"} {
|
||||
request.properties.report_assets {
|
||||
description: "List of the external report assets"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.23" {
|
||||
description: "Update an existing report"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
description: "The ID of the report task to update"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Report name. Unique within the company."
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags list"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
comment {
|
||||
description: "Free text comment "
|
||||
type: string
|
||||
}
|
||||
report {
|
||||
description: "Report template"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response: ${_definitions.update_response}
|
||||
}
|
||||
"2.24": ${update."2.23"} {
|
||||
request.properties.report_assets {
|
||||
description: "List of the external report assets"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
move {
|
||||
"2.23" {
|
||||
description: "Move reports to a project"
|
||||
request {
|
||||
type: object
|
||||
required: [task]
|
||||
properties {
|
||||
task {
|
||||
description: "ID of the report to move"
|
||||
type: string
|
||||
}
|
||||
project {
|
||||
description: "Target project ID. If not provided, `project_name` must be provided. Use null for the root project"
|
||||
type: string
|
||||
}
|
||||
project_name {
|
||||
description: "Target project name. If provided and a project with this name does not exist, a new project will be created. If not provided, `project` must be provided."
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
project_id: {
|
||||
description: The ID of the target project
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
publish {
|
||||
"2.23" {
|
||||
description: "Publish report"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
description: "The ID of the report task to publish"
|
||||
type: string
|
||||
}
|
||||
comment {
|
||||
description: "The client message"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response: ${_definitions.update_response}
|
||||
}
|
||||
}
|
||||
archive {
|
||||
"2.23" {
|
||||
description: "Archive report"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
description: "The ID of the report task to archive"
|
||||
type: string
|
||||
}
|
||||
comment {
|
||||
description: "The client message"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
archived {
|
||||
description: "Number of reports archived (0 or 1)"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
unarchive {
|
||||
"2.23" {
|
||||
description: "Unarchive report"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
description: "The ID of the report task to unarchive"
|
||||
type: string
|
||||
}
|
||||
comment {
|
||||
description: "The client message"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
unarchived {
|
||||
description: "Number of reports unarchived (0 or 1)"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
//share {
|
||||
// "999.0" {
|
||||
// description: "Share or unshare report"
|
||||
// request {
|
||||
// type: object
|
||||
// required: [
|
||||
// task
|
||||
// ]
|
||||
// properties {
|
||||
// task {
|
||||
// description: "The ID of the report task to share/unshare"
|
||||
// type: string
|
||||
// }
|
||||
// share {
|
||||
// description: "If set to 'true' then the report will be shared. Otherwise unshared."
|
||||
// type: boolean
|
||||
// default: true
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// response {
|
||||
// type: object
|
||||
// properties {
|
||||
// changed {
|
||||
// description: "Number of changed reports (0 or 1)"
|
||||
// type: integer
|
||||
// enum: [0, 1]
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
delete {
|
||||
"2.23" {
|
||||
description: "Delete report"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
description: "The ID of the report task to delete"
|
||||
type: string
|
||||
}
|
||||
force {
|
||||
description: "If not set then published or unarchived reports cannot be deleted"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
deleted {
|
||||
description: "Number of deleted reports (0 or 1)"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_data {
|
||||
"2.23" {
|
||||
description: "Get the tasks data according to the passed search criteria, plus the requested events"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "List of IDs to filter by"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
name {
|
||||
description: "Get only tasks whose name matches this pattern (python regular expression syntax)"
|
||||
type: string
|
||||
}
|
||||
user {
|
||||
description: "List of user IDs used to filter results by the task's creating user"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of tasks to retrieve"
|
||||
}
|
||||
order_by {
|
||||
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
type {
|
||||
description: "List of task types. One or more of: 'import', 'annotation', 'training' or 'testing' (case insensitive)"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
tags {
|
||||
description: "List of task user-defined tags. Use '-' prefix to exclude tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "List of task system tags. Use '-' prefix to exclude system tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
status {
|
||||
description: "List of task statuses."
|
||||
type: array
|
||||
items { "$ref": "#/definitions/task_status_enum" }
|
||||
}
|
||||
project {
|
||||
description: "List of project IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
only_fields {
|
||||
description: "List of task field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
parent {
|
||||
description: "Parent ID"
|
||||
type: string
|
||||
}
|
||||
status_changed {
|
||||
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
|
||||
type: array
|
||||
items {
|
||||
type: string
|
||||
pattern: "^(>=|>|<=|<)?.*$"
|
||||
}
|
||||
}
|
||||
search_text {
|
||||
description: "Free text search query"
|
||||
type: string
|
||||
}
|
||||
allow_public {
|
||||
description: "Allow public tasks to be returned in the results"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
_all_ {
|
||||
description: "Multi-field pattern condition (all fields match pattern)"
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
}
|
||||
_any_ {
|
||||
description: "Multi-field pattern condition (any field matches pattern)"
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
}
|
||||
"input.view.entries.dataset" {
|
||||
description: "List of input dataset IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
"input.view.entries.version" {
|
||||
description: "List of input dataset version IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
search_hidden {
|
||||
description: "If set to 'true' then hidden tasks are included in the search results"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and project field is set then tasks from the subprojects are searched too"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
plots {
|
||||
type: object
|
||||
properties {
|
||||
iters {
|
||||
type: integer
|
||||
description: "Max number of latest iterations for which to return plots"
|
||||
}
|
||||
metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
debug_images {
|
||||
type: object
|
||||
properties {
|
||||
iters {
|
||||
type: integer
|
||||
description: "Max number of latest iterations for which to return debug images"
|
||||
}
|
||||
metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
scalar_metrics_iter_histogram {
|
||||
type: object
|
||||
properties {
|
||||
samples {
|
||||
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 6000."
|
||||
type: integer
|
||||
}
|
||||
key {
|
||||
description: """
|
||||
Histogram x axis to use:
|
||||
iter - iteration number
|
||||
iso_time - event time as ISO formatted string
|
||||
timestamp - event timestamp as milliseconds since epoch
|
||||
"""
|
||||
"$ref": "#/definitions/scalar_key_enum"
|
||||
}
|
||||
metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
tasks {
|
||||
description: "List of tasks"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/task" }
|
||||
}
|
||||
plots {
|
||||
type: object
|
||||
description: "Plots mapped by metric, variant, task and iteration"
|
||||
additionalProperties: true
|
||||
}
|
||||
debug_images {
|
||||
type: array
|
||||
description: "Debug image events grouped by tasks and iterations"
|
||||
items {"$ref": "#/definitions/debug_images_response_task_metrics"}
|
||||
}
|
||||
scalar_metrics_iter_histogram {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all_ex {
|
||||
"2.23" {
|
||||
description: "Get all the company's and public report tasks"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "List of IDs to filter by"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
name {
|
||||
description: "Get only reports whose name matches this pattern (python regular expression syntax)"
|
||||
type: string
|
||||
}
|
||||
user {
|
||||
description: "List of user IDs used to filter results by the report's creating user"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
page {
|
||||
description: "Page number, returns a specific page out of the resulting list of reports"
|
||||
type: integer
|
||||
minimum: 0
|
||||
}
|
||||
page_size {
|
||||
description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)"
|
||||
type: integer
|
||||
minimum: 1
|
||||
}
|
||||
order_by {
|
||||
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
tags {
|
||||
description: "List of report user-defined tags. Use '-' prefix to exclude tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "List of report system tags. Use '-' prefix to exclude system tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
status {
|
||||
description: "List of report statuses."
|
||||
type: array
|
||||
items { "$ref": "#/definitions/report_status_enum" }
|
||||
}
|
||||
project {
|
||||
description: "List of project IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
only_fields {
|
||||
description: "List of report field names (nesting is supported using '.'). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
status_changed {
|
||||
description: "List of status changed constraint strings (utcformat, epoch) with an optional prefix modifier (>, >=, <, <=)"
|
||||
type: array
|
||||
items {
|
||||
type: string
|
||||
pattern: "^(>=|>|<=|<)?.*$"
|
||||
}
|
||||
}
|
||||
search_text {
|
||||
description: "Free text search query"
|
||||
type: string
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previous calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of tasks to retrieve"
|
||||
}
|
||||
allow_public {
|
||||
description: "Allow public reports to be returned in the results"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
_all_ {
|
||||
description: "Multi-field pattern condition (all fields match pattern)"
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
}
|
||||
_any_ {
|
||||
description: "Multi-field pattern condition (any field matches pattern)"
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
}
|
||||
}
|
||||
dependencies {
|
||||
page: [ page_size ]
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
tasks {
|
||||
description: "List of report tasks"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/report" }
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_tags {
|
||||
"2.23" {
|
||||
description: "Get all the user tags used for the company reports"
|
||||
request {
|
||||
type: object
|
||||
additionalProperties: false
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of unique tag values"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -25,7 +25,7 @@ _references {
|
||||
}
|
||||
}
|
||||
_definitions {
|
||||
include "_common.conf"
|
||||
include "_tasks_common.conf"
|
||||
change_many_request: ${_definitions.batch_operation} {
|
||||
request {
|
||||
properties {
|
||||
@@ -69,374 +69,6 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
multi_field_pattern_data {
|
||||
type: object
|
||||
properties {
|
||||
pattern {
|
||||
description: "Pattern string (regex)"
|
||||
type: string
|
||||
}
|
||||
fields {
|
||||
description: "List of field names"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
model_type_enum {
|
||||
type: string
|
||||
enum: ["input", "output"]
|
||||
}
|
||||
task_model_item {
|
||||
type: object
|
||||
required: [ name, model]
|
||||
properties {
|
||||
name {
|
||||
description: "The task model name"
|
||||
type: string
|
||||
}
|
||||
model {
|
||||
description: "The model ID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
script {
|
||||
type: object
|
||||
properties {
|
||||
binary {
|
||||
description: "Binary to use when running the script"
|
||||
type: string
|
||||
default: python
|
||||
}
|
||||
repository {
|
||||
description: "Name of the repository where the script is located"
|
||||
type: string
|
||||
}
|
||||
tag {
|
||||
description: "Repository tag"
|
||||
type: string
|
||||
}
|
||||
branch {
|
||||
description: "Repository branch ID. If not provided and tag not provided, default repository branch is used."
|
||||
type: string
|
||||
}
|
||||
version_num {
|
||||
description: "Version (changeset) number. Optional (default is head version). Unused if tag is provided."
|
||||
type: string
|
||||
}
|
||||
entry_point {
|
||||
description: "Path to execute within the repository"
|
||||
type: string
|
||||
}
|
||||
working_dir {
|
||||
description: "Path to the folder from which to run the script. Default - root folder of repository"
|
||||
type: string
|
||||
}
|
||||
requirements {
|
||||
description: "A JSON object containing requirements strings by key"
|
||||
type: object
|
||||
}
|
||||
diff {
|
||||
description: "Uncommitted changes found in the repository when task was run"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
type: object
|
||||
properties {
|
||||
destination {
|
||||
description: "Storage id. This is where output files will be stored."
|
||||
type: string
|
||||
}
|
||||
model {
|
||||
description: "Model id."
|
||||
type: string
|
||||
}
|
||||
result {
|
||||
description: "Task result. Values: 'success', 'failure'"
|
||||
type: string
|
||||
}
|
||||
error {
|
||||
description: "Last error text"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
task_execution_progress_enum {
|
||||
type: string
|
||||
enum: [
|
||||
unknown
|
||||
running
|
||||
stopping
|
||||
stopped
|
||||
]
|
||||
}
|
||||
output_rois_enum {
|
||||
type: string
|
||||
enum: [
|
||||
all_in_frame
|
||||
only_filtered
|
||||
frame_per_roi
|
||||
]
|
||||
}
|
||||
artifact_type_data {
|
||||
type: object
|
||||
properties {
|
||||
preview {
|
||||
description: "Description or textual data"
|
||||
type: string
|
||||
}
|
||||
content_type {
|
||||
description: "System defined raw data content type"
|
||||
type: string
|
||||
}
|
||||
data_hash {
|
||||
description: "Hash of raw data, without any headers or descriptive parts"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
artifact_mode_enum {
|
||||
type: string
|
||||
enum: [
|
||||
input
|
||||
output
|
||||
]
|
||||
default: output
|
||||
}
|
||||
artifact {
|
||||
type: object
|
||||
required: [key, type]
|
||||
properties {
|
||||
key {
|
||||
description: "Entry key"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "System defined type"
|
||||
type: string
|
||||
}
|
||||
mode {
|
||||
description: "System defined input/output indication"
|
||||
"$ref": "#/definitions/artifact_mode_enum"
|
||||
}
|
||||
uri {
|
||||
description: "Raw data location"
|
||||
type: string
|
||||
}
|
||||
content_size {
|
||||
description: "Raw data length in bytes"
|
||||
type: integer
|
||||
}
|
||||
hash {
|
||||
description: "Hash of entire raw data"
|
||||
type: string
|
||||
}
|
||||
timestamp {
|
||||
description: "Epoch time when artifact was created"
|
||||
type: integer
|
||||
}
|
||||
type_data {
|
||||
description: "Additional fields defined by the system"
|
||||
"$ref": "#/definitions/artifact_type_data"
|
||||
}
|
||||
display_data {
|
||||
description: "User-defined list of key/value pairs, sorted"
|
||||
type: array
|
||||
items {
|
||||
type: array
|
||||
items {
|
||||
type: string # can also be a number... TODO: upgrade the generator
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
artifact_id {
|
||||
type: object
|
||||
required: [key]
|
||||
properties {
|
||||
key {
|
||||
description: "Entry key"
|
||||
type: string
|
||||
}
|
||||
mode {
|
||||
description: "System defined input/output indication"
|
||||
"$ref": "#/definitions/artifact_mode_enum"
|
||||
}
|
||||
}
|
||||
}
|
||||
task_models {
|
||||
type: object
|
||||
properties {
|
||||
input {
|
||||
description: "The list of task input models"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/task_model_item"}
|
||||
|
||||
}
|
||||
output {
|
||||
description: "The list of task output models"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/task_model_item"}
|
||||
}
|
||||
}
|
||||
}
|
||||
execution {
|
||||
type: object
|
||||
properties {
|
||||
queue {
|
||||
description: "Queue ID where task was queued."
|
||||
type: string
|
||||
}
|
||||
parameters {
|
||||
description: "Json object containing the Task parameters"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
model {
|
||||
description: "Execution input model ID. Not applicable for Register (Import) tasks"
|
||||
type: string
|
||||
}
|
||||
model_desc {
|
||||
description: "Json object representing the Model descriptors"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
model_labels {
|
||||
description: """Json object representing the ids of the labels in the model.
|
||||
The keys are the layers' names and the values are the IDs.
|
||||
Not applicable for Register (Import) tasks.
|
||||
Mandatory for Training tasks"""
|
||||
type: object
|
||||
additionalProperties: { type: integer }
|
||||
}
|
||||
framework {
|
||||
description: """Framework related to the task. Case insensitive. Mandatory for Training tasks. """
|
||||
type: string
|
||||
}
|
||||
docker_cmd {
|
||||
description: "Command for running docker script for the execution of the task"
|
||||
type: string
|
||||
}
|
||||
artifacts {
|
||||
description: "Task artifacts"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/artifact" }
|
||||
}
|
||||
}
|
||||
}
|
||||
task_status_enum {
|
||||
type: string
|
||||
enum: [
|
||||
created
|
||||
queued
|
||||
in_progress
|
||||
stopped
|
||||
published
|
||||
publishing
|
||||
closed
|
||||
failed
|
||||
completed
|
||||
unknown
|
||||
]
|
||||
}
|
||||
task_type_enum {
|
||||
type: string
|
||||
enum: [
|
||||
training
|
||||
testing
|
||||
inference
|
||||
data_processing
|
||||
application
|
||||
monitor
|
||||
controller
|
||||
optimizer
|
||||
service
|
||||
qc
|
||||
custom
|
||||
]
|
||||
}
|
||||
last_metrics_event {
|
||||
type: object
|
||||
properties {
|
||||
metric {
|
||||
description: "Metric name"
|
||||
type: string
|
||||
}
|
||||
variant {
|
||||
description: "Variant name"
|
||||
type: string
|
||||
}
|
||||
value {
|
||||
description: "Last value reported"
|
||||
type: number
|
||||
}
|
||||
min_value {
|
||||
description: "Minimum value reported"
|
||||
type: number
|
||||
}
|
||||
max_value {
|
||||
description: "Maximum value reported"
|
||||
type: number
|
||||
}
|
||||
}
|
||||
}
|
||||
last_metrics_variants {
|
||||
type: object
|
||||
description: "Last metric events, one for each variant hash"
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/last_metrics_event"
|
||||
}
|
||||
}
|
||||
params_item {
|
||||
type: object
|
||||
properties {
|
||||
section {
|
||||
description: "Section that the parameter belongs to"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Name of the parameter. The combination of section and name should be unique"
|
||||
type: string
|
||||
}
|
||||
value {
|
||||
description: "Value of the parameter"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "Type of the parameter. Optional"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "The parameter description. Optional"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
configuration_item {
|
||||
type: object
|
||||
properties {
|
||||
name {
|
||||
description: "Name of the parameter. Should be unique"
|
||||
type: string
|
||||
}
|
||||
value {
|
||||
description: "Value of the parameter"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "Type of the parameter. Optional"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "The parameter description. Optional"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
param_key {
|
||||
type: object
|
||||
properties {
|
||||
@@ -450,13 +82,6 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
section_params {
|
||||
description: "Task section params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/params_item"
|
||||
}
|
||||
}
|
||||
replace_hyperparams_enum {
|
||||
type: string
|
||||
enum: [
|
||||
@@ -465,165 +90,6 @@ _definitions {
|
||||
all
|
||||
]
|
||||
}
|
||||
task {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Task id"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Task Name"
|
||||
type: string
|
||||
}
|
||||
user {
|
||||
description: "Associated user id"
|
||||
type: string
|
||||
}
|
||||
company {
|
||||
description: "Company ID"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "Type of task. Values: 'training', 'testing'"
|
||||
"$ref": "#/definitions/task_type_enum"
|
||||
}
|
||||
status {
|
||||
description: ""
|
||||
"$ref": "#/definitions/task_status_enum"
|
||||
}
|
||||
comment {
|
||||
description: "Free text comment"
|
||||
type: string
|
||||
}
|
||||
created {
|
||||
description: "Task creation time (UTC) "
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
started {
|
||||
description: "Task start time (UTC)"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
completed {
|
||||
description: "Task end time (UTC)"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
active_duration {
|
||||
description: "Task duration time (seconds)"
|
||||
type: integer
|
||||
}
|
||||
parent {
|
||||
description: "Parent task id"
|
||||
type: string
|
||||
}
|
||||
project {
|
||||
description: "Project ID of the project to which this task is assigned"
|
||||
type: string
|
||||
}
|
||||
output {
|
||||
description: "Task output params"
|
||||
"$ref": "#/definitions/output"
|
||||
}
|
||||
execution {
|
||||
description: "Task execution params"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
models {
|
||||
description: "Task models"
|
||||
"$ref": "#/definitions/task_models"
|
||||
}
|
||||
// TODO: will be removed
|
||||
script {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags list"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
status_changed {
|
||||
description: "Last status change time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
status_message {
|
||||
description: "free text string representing info about the status"
|
||||
type: string
|
||||
}
|
||||
status_reason {
|
||||
description: "Reason for last status change"
|
||||
type: string
|
||||
}
|
||||
published {
|
||||
description: "Task publish time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_worker {
|
||||
description: "ID of last worker that handled the task"
|
||||
type: string
|
||||
}
|
||||
last_worker_report {
|
||||
description: "Last time a worker reported while working on this task"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_update {
|
||||
description: "Last time this task was created, updated, changed or events for this task were reported"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_change {
|
||||
description: "Last time any update was done to the task"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_iteration {
|
||||
description: "Last iteration reported for this task"
|
||||
type: integer
|
||||
}
|
||||
last_metrics {
|
||||
description: "Last metric variants (hash to events), one for each metric hash"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/last_metrics_variants"
|
||||
}
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/section_params"
|
||||
}
|
||||
}
|
||||
configuration {
|
||||
description: "Task configuration params"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
runtime {
|
||||
description: "Task runtime mapping"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
task_urls {
|
||||
type: object
|
||||
properties {
|
||||
@@ -715,6 +181,15 @@ get_all_ex {
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
"2.23": ${get_all_ex."2.15"} {
|
||||
request.properties {
|
||||
allow_public {
|
||||
description: "Allow public tasks to be returned in the results"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
@@ -1227,6 +702,10 @@ validate {
|
||||
description: "Task execution params"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
script {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
@@ -1241,10 +720,6 @@ validate {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
script {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -1392,6 +867,10 @@ edit {
|
||||
description: "Task execution params"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
script {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
}
|
||||
hyperparams {
|
||||
description: "Task hyper params per section"
|
||||
type: object
|
||||
@@ -1406,10 +885,6 @@ edit {
|
||||
"$ref": "#/definitions/configuration_item"
|
||||
}
|
||||
}
|
||||
script {
|
||||
description: "Script info"
|
||||
"$ref": "#/definitions/script"
|
||||
}
|
||||
}
|
||||
}
|
||||
response: ${_definitions.update_response}
|
||||
@@ -1478,6 +953,10 @@ reset {
|
||||
description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'"
|
||||
type: boolean
|
||||
}
|
||||
delete_output_models {
|
||||
description: "If set to 'true' then delete output models of this task that are not referenced by other tasks. Default value is 'true'"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -1489,6 +968,13 @@ reset {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.21": ${reset."2.13"} {
|
||||
request.properties.delete_external_artifacts {
|
||||
description: "If set to 'true' then BE will try to delete the external artifacts associated with the task from the fileserver (if configured to do so)"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
reset_many {
|
||||
"2.13": ${_definitions.batch_operation} {
|
||||
@@ -1541,6 +1027,13 @@ reset_many {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.21": ${reset_many."2.13"} {
|
||||
request.properties.delete_external_artifacts {
|
||||
description: "If set to 'true' then BE will try to delete the external artifacts associated with the tasks from the fileserver (if configured to do so)"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_many {
|
||||
"2.13": ${_definitions.batch_operation} {
|
||||
@@ -1591,6 +1084,13 @@ delete_many {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.21": ${delete_many."2.13"} {
|
||||
request.properties.delete_external_artifacts {
|
||||
description: "If set to 'true' then BE will try to delete the external artifacts associated with the tasks from the fileserver (if configured to do so)"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
delete {
|
||||
"2.1" {
|
||||
@@ -1644,6 +1144,10 @@ delete {
|
||||
description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'"
|
||||
type: boolean
|
||||
}
|
||||
delete_output_models {
|
||||
description: "If set to 'true' then delete output models of this task that are not referenced by other tasks. Default value is 'true'"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -1655,6 +1159,13 @@ delete {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.21": ${delete."2.13"} {
|
||||
request.properties.delete_external_artifacts {
|
||||
description: "If set to 'true' then BE will try to delete the external artifacts associated with the task from the fileserver (if configured to do so)"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
archive {
|
||||
"2.12" {
|
||||
@@ -1708,12 +1219,12 @@ archive_many {
|
||||
type: string
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
succeeded.items.properties.archived {
|
||||
description: "Indicates whether the task was archived"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
succeeded.items.properties.archived {
|
||||
description: "Indicates whether the task was archived"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1920,6 +1431,17 @@ Fails if the following parameters in the task were not filled:
|
||||
type: string
|
||||
}
|
||||
}
|
||||
"2.22": ${enqueue."2.19"} {
|
||||
request.properties.verify_watched_queue {
|
||||
description: If passed then check whether there are any workers watching the queue
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
response.properties.queue_watched {
|
||||
description: Returns true if there are workers or autoscalers working with the queue
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
enqueue_many {
|
||||
"2.13": ${_definitions.change_many_request} {
|
||||
@@ -1953,6 +1475,17 @@ enqueue_many {
|
||||
type: string
|
||||
}
|
||||
}
|
||||
"2.22": ${enqueue_many."2.19"} {
|
||||
request.properties.verify_watched_queue {
|
||||
description: If passed then check whether there are any workers watching the queue
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
response.properties.queue_watched {
|
||||
description: Returns true if there are workers or autoscalers working with the queue
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
dequeue {
|
||||
"1.5" {
|
||||
@@ -2465,7 +1998,7 @@ move {
|
||||
items { type: string }
|
||||
}
|
||||
project {
|
||||
description: "Target project ID. If not provided, `project_name` must be provided."
|
||||
description: "Target project ID. If not provided, `project_name` must be provided. Use null for the root project"
|
||||
type: string
|
||||
}
|
||||
project_name {
|
||||
|
||||
@@ -152,6 +152,15 @@ _definitions {
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
key {
|
||||
description: "Worker entry key"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,11 +168,11 @@ _definitions {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "ID"
|
||||
description: "Worker ID"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Name"
|
||||
description: "Worker name"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
@@ -294,6 +303,13 @@ get_all {
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
"2.22": ${get_all."2.20"} {
|
||||
request.properties.system_tags {
|
||||
description: The list of allowed worker system tags. Prepend tag value with '-' in order to exclude
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
register {
|
||||
"2.4" {
|
||||
@@ -328,6 +344,13 @@ register {
|
||||
properties {}
|
||||
}
|
||||
}
|
||||
"2.22": ${register."2.4"} {
|
||||
request.properties.system_tags {
|
||||
description: "System tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
unregister {
|
||||
"2.4" {
|
||||
@@ -395,6 +418,13 @@ status_report {
|
||||
properties {}
|
||||
}
|
||||
}
|
||||
"2.22": ${status_report."2.4"} {
|
||||
request.properties.system_tags {
|
||||
description: "New system tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
get_metric_keys {
|
||||
"2.4" {
|
||||
|
||||
@@ -6,6 +6,8 @@ from flask_compress import Compress
|
||||
from flask_cors import CORS
|
||||
from packaging.version import Version
|
||||
|
||||
from apiserver.bll.queue.queue_metrics import MetricsRefresher
|
||||
from apiserver.bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
|
||||
from apiserver.database import db
|
||||
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
|
||||
from apiserver.config import info
|
||||
@@ -26,7 +28,6 @@ from apiserver.service_repo import ServiceRepo
|
||||
from apiserver.sync import distributed_lock
|
||||
from apiserver.updates import check_updates_thread
|
||||
from apiserver.utilities.env import get_bool
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -119,6 +120,8 @@ class AppSequence:
|
||||
def _start_worker(self):
|
||||
check_updates_thread.start()
|
||||
StatisticsReporter.start()
|
||||
MetricsRefresher.start()
|
||||
NonResponsiveTasksWatchdog.start()
|
||||
|
||||
def _on_worker_stop(self):
|
||||
ThreadsManager.terminating = True
|
||||
pass
|
||||
|
||||
@@ -11,7 +11,6 @@ from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.auth import User, Entities, Credentials
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.timing_context import TimingContext
|
||||
from .fixed_user import FixedUser
|
||||
from .identity import Identity
|
||||
from .payload import Payload, Token, Basic, AuthType
|
||||
@@ -88,9 +87,7 @@ def authorize_credentials(auth_data, service, action, call):
|
||||
|
||||
query = Q(id=fixed_user.user_id)
|
||||
|
||||
with TimingContext("mongo", "user_by_cred"), translate_errors_context(
|
||||
"authorizing request"
|
||||
):
|
||||
with translate_errors_context("authorizing request"):
|
||||
user = User.objects(query).first()
|
||||
if not user:
|
||||
raise errors.unauthorized.InvalidCredentials(
|
||||
@@ -108,8 +105,7 @@ def authorize_credentials(auth_data, service, action, call):
|
||||
}
|
||||
)
|
||||
|
||||
with TimingContext("mongo", "company_by_id"):
|
||||
company = Company.objects(id=user.company).only("id", "name").first()
|
||||
company = Company.objects(id=user.company).only("id", "name").first()
|
||||
|
||||
if not company:
|
||||
raise errors.unauthorized.InvalidCredentials("invalid user company")
|
||||
|
||||
@@ -39,7 +39,7 @@ class ServiceRepo(object):
|
||||
"""If the check is set, parsing will fail for endpoint request with the version that is greater than the current
|
||||
maximum """
|
||||
|
||||
_max_version = PartialVersion("2.20")
|
||||
_max_version = PartialVersion("2.24")
|
||||
""" Maximum version number (the highest min_version value across all endpoints) """
|
||||
|
||||
_endpoint_exp = (
|
||||
|
||||
@@ -2,7 +2,7 @@ import itertools
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Optional
|
||||
from typing import Sequence, Optional, Union, Tuple, Mapping
|
||||
|
||||
import attr
|
||||
import jsonmodels.fields
|
||||
@@ -19,7 +19,6 @@ from apiserver.apimodels.events import (
|
||||
TaskMetricsRequest,
|
||||
LogEventsRequest,
|
||||
LogOrderEnum,
|
||||
GetHistorySampleRequest,
|
||||
NextHistorySampleRequest,
|
||||
MetricVariants as ApiMetrics,
|
||||
TaskPlotsRequest,
|
||||
@@ -28,18 +27,44 @@ from apiserver.apimodels.events import (
|
||||
ClearScrollRequest,
|
||||
ClearTaskLogRequest,
|
||||
SingleValueMetricsRequest,
|
||||
GetVariantSampleRequest,
|
||||
GetMetricSamplesRequest,
|
||||
TaskMetric,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
|
||||
from apiserver.bll.event.events_iterator import Scroll
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.utilities import json, extract_properties_to_lists
|
||||
|
||||
task_bll = TaskBLL()
|
||||
event_bll = EventBLL()
|
||||
model_bll = ModelBLL()
|
||||
|
||||
|
||||
def _assert_task_or_model_exists(
|
||||
company_id: str, task_ids: Union[str, Sequence[str]], model_events: bool
|
||||
) -> Union[Sequence[Model], Sequence[Task]]:
|
||||
if model_events:
|
||||
return model_bll.assert_exists(
|
||||
company_id,
|
||||
task_ids,
|
||||
allow_public=True,
|
||||
only=("id", "name", "company", "company_origin"),
|
||||
)
|
||||
|
||||
return task_bll.assert_exists(
|
||||
company_id,
|
||||
task_ids,
|
||||
allow_public=True,
|
||||
only=("id", "name", "company", "company_origin"),
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.add")
|
||||
@@ -47,7 +72,7 @@ def add(call: APICall, company_id, _):
|
||||
data = call.data.copy()
|
||||
allow_locked = data.pop("allow_locked", False)
|
||||
added, err_count, err_info = event_bll.add_events(
|
||||
company_id, [data], call.worker, allow_locked_tasks=allow_locked
|
||||
company_id, [data], call.worker, allow_locked=allow_locked
|
||||
)
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
|
||||
@@ -58,7 +83,12 @@ def add_batch(call: APICall, company_id, _):
|
||||
if events is None or len(events) == 0:
|
||||
raise errors.bad_request.BatchContainsNoItems()
|
||||
|
||||
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
|
||||
added, err_count, err_info = event_bll.add_events(
|
||||
company_id,
|
||||
events,
|
||||
call.worker,
|
||||
allow_locked=events[0].get("allow_locked", False),
|
||||
)
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
|
||||
|
||||
@@ -225,12 +255,13 @@ def download_task_log(call, company_id, _):
|
||||
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
|
||||
def get_vector_metrics_and_variants(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
model_events = call.data["model_events"]
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=model_events,
|
||||
)[0]
|
||||
call.result.data = dict(
|
||||
metrics=event_bll.get_metrics_and_variants(
|
||||
task.get_index_company(), task_id, EventType.metrics_vector
|
||||
task_or_model.get_index_company(), task_id, EventType.metrics_vector
|
||||
)
|
||||
)
|
||||
|
||||
@@ -238,12 +269,13 @@ def get_vector_metrics_and_variants(call, company_id, _):
|
||||
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
|
||||
def get_scalar_metrics_and_variants(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
model_events = call.data["model_events"]
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=model_events,
|
||||
)[0]
|
||||
call.result.data = dict(
|
||||
metrics=event_bll.get_metrics_and_variants(
|
||||
task.get_index_company(), task_id, EventType.metrics_scalar
|
||||
task_or_model.get_index_company(), task_id, EventType.metrics_scalar
|
||||
)
|
||||
)
|
||||
|
||||
@@ -255,13 +287,14 @@ def get_scalar_metrics_and_variants(call, company_id, _):
|
||||
)
|
||||
def vector_metrics_iter_histogram(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
model_events = call.data["model_events"]
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=model_events,
|
||||
)[0]
|
||||
metric = call.data["metric"]
|
||||
variant = call.data["variant"]
|
||||
iterations, vectors = event_bll.get_vector_metrics_per_iter(
|
||||
task.get_index_company(), task_id, metric, variant
|
||||
task_or_model.get_index_company(), task_id, metric, variant
|
||||
)
|
||||
call.result.data = dict(
|
||||
metric=metric, variant=variant, vectors=vectors, iterations=iterations
|
||||
@@ -286,11 +319,10 @@ def make_response(
|
||||
|
||||
|
||||
@endpoint("events.get_task_events", request_data_model=TaskEventsRequest)
|
||||
def get_task_events(call, company_id, request: TaskEventsRequest):
|
||||
def get_task_events(_, company_id, request: TaskEventsRequest):
|
||||
task_id = request.task
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",),
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=request.model_events,
|
||||
)[0]
|
||||
|
||||
key = ScalarKeyEnum.iter
|
||||
@@ -322,7 +354,7 @@ def get_task_events(call, company_id, request: TaskEventsRequest):
|
||||
if request.count_total and total is None:
|
||||
total = event_bll.events_iterator.count_task_events(
|
||||
event_type=request.event_type,
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task_id=task_id,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
@@ -336,7 +368,7 @@ def get_task_events(call, company_id, request: TaskEventsRequest):
|
||||
|
||||
res = event_bll.events_iterator.get_task_events(
|
||||
event_type=request.event_type,
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
key=ScalarKeyEnum.iter,
|
||||
@@ -365,16 +397,17 @@ def get_scalar_metric_data(call, company_id, _):
|
||||
metric = call.data["metric"]
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
no_scroll = call.data.get("no_scroll", False)
|
||||
model_events = call.data.get("model_events", False)
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=model_events,
|
||||
)[0]
|
||||
result = event_bll.get_task_events(
|
||||
task.get_index_company(),
|
||||
task_or_model.get_index_company(),
|
||||
task_id,
|
||||
event_type=EventType.metrics_scalar,
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
metric=metric,
|
||||
metrics={metric: []},
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
)
|
||||
@@ -398,7 +431,7 @@ def get_task_latest_scalar_values(call, company_id, _):
|
||||
index_company, task_id
|
||||
)
|
||||
last_iters = event_bll.get_last_iters(
|
||||
company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
|
||||
company_id=index_company, event_type=EventType.all, task_id=task_id, iters=1
|
||||
).get(task_id)
|
||||
call.result.data = dict(
|
||||
metrics=metrics,
|
||||
@@ -417,36 +450,60 @@ def get_task_latest_scalar_values(call, company_id, _):
|
||||
def scalar_metrics_iter_histogram(
|
||||
call, company_id, request: ScalarMetricsIterHistogramRequest
|
||||
):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, request.task, allow_public=True, only=("company", "company_origin")
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, request.task, model_events=request.model_events
|
||||
)[0]
|
||||
metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
|
||||
task.get_index_company(),
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task_id=request.task,
|
||||
samples=request.samples,
|
||||
key=request.key,
|
||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||
)
|
||||
call.result.data = metrics
|
||||
|
||||
|
||||
def _get_task_or_model_index_companies(
|
||||
company_id: str, task_ids: Sequence[str], model_events=False,
|
||||
) -> TaskCompanies:
|
||||
"""
|
||||
Returns lists of tasks grouped by company
|
||||
"""
|
||||
tasks_or_models = _assert_task_or_model_exists(
|
||||
company_id, task_ids, model_events=model_events,
|
||||
)
|
||||
|
||||
unique_ids = set(task_ids)
|
||||
if len(tasks_or_models) < len(unique_ids):
|
||||
invalid = tuple(unique_ids - {t.id for t in tasks_or_models})
|
||||
error_cls = (
|
||||
errors.bad_request.InvalidModelId
|
||||
if model_events
|
||||
else errors.bad_request.InvalidTaskId
|
||||
)
|
||||
raise error_cls(company=company_id, ids=invalid)
|
||||
|
||||
return bucketize(tasks_or_models, key=lambda t: t.get_index_company())
|
||||
|
||||
|
||||
@endpoint(
|
||||
"events.multi_task_scalar_metrics_iter_histogram",
|
||||
request_data_model=MultiTaskScalarMetricsIterHistogramRequest,
|
||||
)
|
||||
def multi_task_scalar_metrics_iter_histogram(
|
||||
call, company_id, req_model: MultiTaskScalarMetricsIterHistogramRequest
|
||||
call, company_id, request: MultiTaskScalarMetricsIterHistogramRequest
|
||||
):
|
||||
task_ids = req_model.tasks
|
||||
task_ids = request.tasks
|
||||
if isinstance(task_ids, str):
|
||||
task_ids = [s.strip() for s in task_ids.split(",")]
|
||||
# Note, bll already validates task ids as it needs their names
|
||||
|
||||
call.result.data = dict(
|
||||
metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter(
|
||||
company_id,
|
||||
task_ids=task_ids,
|
||||
samples=req_model.samples,
|
||||
allow_public=True,
|
||||
key=req_model.key,
|
||||
companies=_get_task_or_model_index_companies(
|
||||
company_id, task_ids, request.model_events
|
||||
),
|
||||
samples=request.samples,
|
||||
key=request.key,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -455,21 +512,11 @@ def multi_task_scalar_metrics_iter_histogram(
|
||||
def get_task_single_value_metrics(
|
||||
call, company_id: str, request: SingleValueMetricsRequest
|
||||
):
|
||||
task_ids = call.data["tasks"]
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=call.identity.company,
|
||||
only=("id", "name", "company", "company_origin"),
|
||||
task_ids=task_ids,
|
||||
allow_public=True,
|
||||
res = event_bll.metrics.get_task_single_value_metrics(
|
||||
companies=_get_task_or_model_index_companies(
|
||||
company_id, request.tasks, request.model_events
|
||||
),
|
||||
)
|
||||
|
||||
companies = {t.get_index_company() for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
res = event_bll.metrics.get_task_single_value_metrics(company_id, task_ids)
|
||||
call.result.data = dict(
|
||||
tasks=[{"task": task, "values": values} for task, values in res.items()]
|
||||
)
|
||||
@@ -481,22 +528,11 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id,
|
||||
only=("id", "name", "company", "company_origin"),
|
||||
task_ids=task_ids,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
companies = {t.get_index_company() for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
companies = _get_task_or_model_index_companies(company_id, task_ids)
|
||||
|
||||
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
||||
result = event_bll.get_task_events(
|
||||
next(iter(companies)),
|
||||
list(companies),
|
||||
task_ids,
|
||||
event_type=EventType.metrics_plot,
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
@@ -504,10 +540,11 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
scroll_id=scroll_id,
|
||||
)
|
||||
|
||||
tasks = {t.id: t.name for t in tasks}
|
||||
|
||||
task_names = {
|
||||
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
|
||||
}
|
||||
return_events = _get_top_iter_unique_events_per_task(
|
||||
result.events, max_iters=iters, tasks=tasks
|
||||
result.events, max_iters=iters, task_names=task_names
|
||||
)
|
||||
|
||||
call.result.data = dict(
|
||||
@@ -518,47 +555,56 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
def _get_multitask_plots(
|
||||
companies: TaskCompanies,
|
||||
last_iters: int,
|
||||
metrics: MetricVariants = None,
|
||||
scroll_id=None,
|
||||
no_scroll=True,
|
||||
model_events=False,
|
||||
) -> Tuple[dict, int, str]:
|
||||
task_names = {
|
||||
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
|
||||
}
|
||||
result = event_bll.get_task_events(
|
||||
company_id=list(companies),
|
||||
task_id=list(task_names),
|
||||
event_type=EventType.metrics_plot,
|
||||
metrics=metrics,
|
||||
last_iter_count=last_iters,
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
)
|
||||
return_events = _get_top_iter_unique_events_per_task(
|
||||
result.events, max_iters=last_iters, task_names=task_names
|
||||
)
|
||||
return return_events, result.total_events, result.next_scroll_id
|
||||
|
||||
|
||||
@endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"])
|
||||
def get_multi_task_plots(call, company_id, req_model):
|
||||
def get_multi_task_plots(call, company_id, _):
|
||||
task_ids = call.data["tasks"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
no_scroll = call.data.get("no_scroll", False)
|
||||
model_events = call.data.get("model_events", False)
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=call.identity.company,
|
||||
only=("id", "name", "company", "company_origin"),
|
||||
task_ids=task_ids,
|
||||
allow_public=True,
|
||||
companies = _get_task_or_model_index_companies(
|
||||
company_id, task_ids, model_events=model_events
|
||||
)
|
||||
|
||||
companies = {t.get_index_company() for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
result = event_bll.get_task_events(
|
||||
next(iter(companies)),
|
||||
task_ids,
|
||||
event_type=EventType.metrics_plot,
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iter_count=iters,
|
||||
return_events, total_events, next_scroll_id = _get_multitask_plots(
|
||||
companies=companies,
|
||||
last_iters=iters,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
model_events=model_events,
|
||||
)
|
||||
|
||||
tasks = {t.id: t.name for t in tasks}
|
||||
|
||||
return_events = _get_top_iter_unique_events_per_task(
|
||||
result.events, max_iters=iters, tasks=tasks
|
||||
)
|
||||
|
||||
call.result.data = dict(
|
||||
plots=return_events,
|
||||
returned=len(return_events),
|
||||
total=result.total_events,
|
||||
scroll_id=result.next_scroll_id,
|
||||
total=total_events,
|
||||
scroll_id=next_scroll_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -613,18 +659,14 @@ def _get_metric_variants_from_request(
|
||||
def get_task_plots(call, company_id, request: TaskPlotsRequest):
|
||||
task_id = request.task
|
||||
iters = request.iters
|
||||
scroll_id = request.scroll_id
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=request.model_events
|
||||
)[0]
|
||||
result = event_bll.get_task_plots(
|
||||
task.get_index_company(),
|
||||
tasks=[task_id],
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
task_or_model.get_index_company(),
|
||||
task_id=task_id,
|
||||
last_iterations_per_plot=iters,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=request.no_scroll,
|
||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||
)
|
||||
|
||||
@@ -638,34 +680,43 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
|
||||
)
|
||||
|
||||
|
||||
def _task_metrics_dict_from_request(req_metrics: Sequence[TaskMetric]) -> dict:
|
||||
task_metrics = defaultdict(dict)
|
||||
for tm in req_metrics:
|
||||
task_metrics[tm.task][tm.metric] = tm.variants
|
||||
for metrics in task_metrics.values():
|
||||
if None in metrics:
|
||||
metrics.clear()
|
||||
|
||||
return task_metrics
|
||||
|
||||
|
||||
def _get_metrics_response(metric_events: Sequence[tuple]) -> Sequence[MetricEvents]:
|
||||
return [
|
||||
MetricEvents(
|
||||
task=task,
|
||||
iterations=[
|
||||
IterationEvents(iter=iteration["iter"], events=iteration["events"])
|
||||
for iteration in iterations
|
||||
],
|
||||
)
|
||||
for (task, iterations) in metric_events
|
||||
]
|
||||
|
||||
|
||||
@endpoint(
|
||||
"events.plots",
|
||||
request_data_model=MetricEventsRequest,
|
||||
response_data_model=MetricEventsResponse,
|
||||
)
|
||||
def task_plots(call, company_id, request: MetricEventsRequest):
|
||||
task_metrics = defaultdict(dict)
|
||||
for tm in request.metrics:
|
||||
task_metrics[tm.task][tm.metric] = tm.variants
|
||||
for metrics in task_metrics.values():
|
||||
if None in metrics:
|
||||
metrics.clear()
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id,
|
||||
task_ids=list(task_metrics),
|
||||
allow_public=True,
|
||||
only=("company", "company_origin"),
|
||||
task_metrics = _task_metrics_dict_from_request(request.metrics)
|
||||
task_ids = list(task_metrics)
|
||||
task_or_models = _assert_task_or_model_exists(
|
||||
company_id, task_ids=task_ids, model_events=request.model_events
|
||||
)
|
||||
|
||||
companies = {t.get_index_company() for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
result = event_bll.plots_iterator.get_task_events(
|
||||
company_id=next(iter(companies)),
|
||||
companies={t.id: t.get_index_company() for t in task_or_models},
|
||||
task_metrics=task_metrics,
|
||||
iter_count=request.iters,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
@@ -675,16 +726,7 @@ def task_plots(call, company_id, request: MetricEventsRequest):
|
||||
|
||||
call.result.data_model = MetricEventsResponse(
|
||||
scroll_id=result.next_scroll_id,
|
||||
metrics=[
|
||||
MetricEvents(
|
||||
task=task,
|
||||
iterations=[
|
||||
IterationEvents(iter=iteration["iter"], events=iteration["events"])
|
||||
for iteration in iterations
|
||||
],
|
||||
)
|
||||
for (task, iterations) in result.metric_events
|
||||
],
|
||||
metrics=_get_metrics_response(result.metric_events),
|
||||
)
|
||||
|
||||
|
||||
@@ -730,12 +772,13 @@ def get_debug_images_v1_8(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters") or 1
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
model_events = call.data.get("model_events", False)
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
tasks_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=model_events,
|
||||
)[0]
|
||||
result = event_bll.get_task_events(
|
||||
task.get_index_company(),
|
||||
tasks_or_model.get_index_company(),
|
||||
task_id,
|
||||
event_type=EventType.metrics_image,
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
@@ -761,28 +804,13 @@ def get_debug_images_v1_8(call, company_id, _):
|
||||
response_data_model=MetricEventsResponse,
|
||||
)
|
||||
def get_debug_images(call, company_id, request: MetricEventsRequest):
|
||||
task_metrics = defaultdict(dict)
|
||||
for tm in request.metrics:
|
||||
task_metrics[tm.task][tm.metric] = tm.variants
|
||||
for metrics in task_metrics.values():
|
||||
if None in metrics:
|
||||
metrics.clear()
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id,
|
||||
task_ids=list(task_metrics),
|
||||
allow_public=True,
|
||||
only=("company", "company_origin"),
|
||||
task_metrics = _task_metrics_dict_from_request(request.metrics)
|
||||
task_ids = list(task_metrics)
|
||||
task_or_models = _assert_task_or_model_exists(
|
||||
company_id, task_ids=task_ids, model_events=request.model_events
|
||||
)
|
||||
|
||||
companies = {t.get_index_company() for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
result = event_bll.debug_images_iterator.get_task_events(
|
||||
company_id=next(iter(companies)),
|
||||
companies={t.id: t.get_index_company() for t in task_or_models},
|
||||
task_metrics=task_metrics,
|
||||
iter_count=request.iters,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
@@ -792,30 +820,21 @@ def get_debug_images(call, company_id, request: MetricEventsRequest):
|
||||
|
||||
call.result.data_model = MetricEventsResponse(
|
||||
scroll_id=result.next_scroll_id,
|
||||
metrics=[
|
||||
MetricEvents(
|
||||
task=task,
|
||||
iterations=[
|
||||
IterationEvents(iter=iteration["iter"], events=iteration["events"])
|
||||
for iteration in iterations
|
||||
],
|
||||
)
|
||||
for (task, iterations) in result.metric_events
|
||||
],
|
||||
metrics=_get_metrics_response(result.metric_events),
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"events.get_debug_image_sample",
|
||||
min_version="2.12",
|
||||
request_data_model=GetHistorySampleRequest,
|
||||
request_data_model=GetVariantSampleRequest,
|
||||
)
|
||||
def get_debug_image_sample(call, company_id, request: GetHistorySampleRequest):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_ids=[request.task], allow_public=True, only=("company",)
|
||||
def get_debug_image_sample(call, company_id, request: GetVariantSampleRequest):
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, request.task, model_events=request.model_events,
|
||||
)[0]
|
||||
res = event_bll.debug_image_sample_history.get_sample_for_variant(
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task=request.task,
|
||||
metric=request.metric,
|
||||
variant=request.variant,
|
||||
@@ -833,30 +852,30 @@ def get_debug_image_sample(call, company_id, request: GetHistorySampleRequest):
|
||||
request_data_model=NextHistorySampleRequest,
|
||||
)
|
||||
def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_ids=[request.task], allow_public=True, only=("company",)
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, request.task, model_events=request.model_events,
|
||||
)[0]
|
||||
res = event_bll.debug_image_sample_history.get_next_sample(
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task=request.task,
|
||||
state_id=request.scroll_id,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
next_iteration=request.next_iteration,
|
||||
)
|
||||
call.result.data = attr.asdict(res, recurse=False)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"events.get_plot_sample", request_data_model=GetHistorySampleRequest,
|
||||
"events.get_plot_sample", request_data_model=GetMetricSamplesRequest,
|
||||
)
|
||||
def get_plot_sample(call, company_id, request: GetHistorySampleRequest):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_ids=[request.task], allow_public=True, only=("company",)
|
||||
def get_plot_sample(call, company_id, request: GetMetricSamplesRequest):
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, request.task, model_events=request.model_events,
|
||||
)[0]
|
||||
res = event_bll.plot_sample_history.get_sample_for_variant(
|
||||
company_id=task.company,
|
||||
res = event_bll.plot_sample_history.get_samples_for_metric(
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task=request.task,
|
||||
metric=request.metric,
|
||||
variant=request.variant,
|
||||
iteration=request.iteration,
|
||||
refresh=request.refresh,
|
||||
state_id=request.scroll_id,
|
||||
@@ -869,28 +888,28 @@ def get_plot_sample(call, company_id, request: GetHistorySampleRequest):
|
||||
"events.next_plot_sample", request_data_model=NextHistorySampleRequest,
|
||||
)
|
||||
def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_ids=[request.task], allow_public=True, only=("company",)
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, request.task, model_events=request.model_events,
|
||||
)[0]
|
||||
res = event_bll.plot_sample_history.get_next_sample(
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task=request.task,
|
||||
state_id=request.scroll_id,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
next_iteration=request.next_iteration,
|
||||
)
|
||||
call.result.data = attr.asdict(res, recurse=False)
|
||||
|
||||
|
||||
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
|
||||
def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
|
||||
task = task_bll.assert_exists(
|
||||
company_id,
|
||||
task_ids=request.tasks,
|
||||
allow_public=True,
|
||||
only=("company", "company_origin"),
|
||||
)[0]
|
||||
task_or_models = _assert_task_or_model_exists(
|
||||
company_id, request.tasks, model_events=request.model_events,
|
||||
)
|
||||
res = event_bll.metrics.get_task_metrics(
|
||||
task.get_index_company(), task_ids=request.tasks, event_type=request.event_type
|
||||
task_or_models[0].get_index_company(),
|
||||
task_ids=request.tasks,
|
||||
event_type=request.event_type,
|
||||
)
|
||||
call.result.data = {
|
||||
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
|
||||
@@ -898,7 +917,7 @@ def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
|
||||
|
||||
|
||||
@endpoint("events.delete_for_task", required_fields=["task"])
|
||||
def delete_for_task(call, company_id, req_model):
|
||||
def delete_for_task(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
allow_locked = call.data.get("allow_locked", False)
|
||||
|
||||
@@ -910,6 +929,19 @@ def delete_for_task(call, company_id, req_model):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.delete_for_model", required_fields=["model"])
|
||||
def delete_for_model(call: APICall, company_id: str, _):
|
||||
model_id = call.data["model"]
|
||||
allow_locked = call.data.get("allow_locked", False)
|
||||
|
||||
model_bll.assert_exists(company_id, model_id, return_models=False)
|
||||
call.result.data = dict(
|
||||
deleted=event_bll.delete_task_events(
|
||||
company_id, model_id, allow_locked=allow_locked, model=True
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.clear_task_log")
|
||||
def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
|
||||
task_id = request.task
|
||||
@@ -925,7 +957,9 @@ def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest)
|
||||
)
|
||||
|
||||
|
||||
def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
|
||||
def _get_top_iter_unique_events_per_task(
|
||||
events, max_iters: int, task_names: Mapping[str, str]
|
||||
):
|
||||
key = itemgetter("metric", "variant", "task", "iter")
|
||||
|
||||
unique_events = itertools.chain.from_iterable(
|
||||
@@ -938,7 +972,7 @@ def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
|
||||
def collect(evs, fields):
|
||||
if not fields:
|
||||
evs = list(evs)
|
||||
return {"name": tasks.get(evs[0].get("task")), "plots": evs}
|
||||
return {"name": task_names.get(evs[0].get("task")), "plots": evs}
|
||||
return {
|
||||
str(k): collect(group, fields[1:])
|
||||
for k, group in itertools.groupby(evs, key=itemgetter(fields[0]))
|
||||
@@ -1004,17 +1038,15 @@ def scalar_metrics_iter_raw(
|
||||
request.batch_size = request.batch_size or scroll.request.batch_size
|
||||
|
||||
task_id = request.task
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",),
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=request.model_events
|
||||
)[0]
|
||||
|
||||
metric_variants = _get_metric_variants_from_request([request.metric])
|
||||
|
||||
if request.count_total and total is None:
|
||||
total = event_bll.events_iterator.count_task_events(
|
||||
event_type=EventType.metrics_scalar,
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task_id=task_id,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
@@ -1030,7 +1062,7 @@ def scalar_metrics_iter_raw(
|
||||
for iteration in range(0, math.ceil(batch_size / 10_000)):
|
||||
res = event_bll.events_iterator.get_task_events(
|
||||
event_type=EventType.metrics_scalar,
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task_id=task_id,
|
||||
batch_size=min(batch_size, 10_000),
|
||||
navigate_earlier=False,
|
||||
|
||||
@@ -24,7 +24,7 @@ from apiserver.apimodels.models import (
|
||||
)
|
||||
from apiserver.bll.model import ModelBLL, Metadata
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.task_operations import publish_task
|
||||
from apiserver.bll.util import run_batch_operation
|
||||
@@ -51,8 +51,8 @@ from apiserver.services.utils import (
|
||||
ModelsBackwardsCompatibility,
|
||||
unescape_metadata,
|
||||
escape_metadata,
|
||||
process_include_subprojects,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
@@ -107,30 +107,18 @@ def get_by_task_id(call: APICall, company_id, _):
|
||||
call.result.data = {"model": model_dict}
|
||||
|
||||
|
||||
def _process_include_subprojects(call_data: dict):
|
||||
include_subprojects = call_data.pop("include_subprojects", False)
|
||||
project_ids = call_data.get("project")
|
||||
if not project_ids or not include_subprojects:
|
||||
return
|
||||
|
||||
if not isinstance(project_ids, list):
|
||||
project_ids = [project_ids]
|
||||
call_data["project"] = project_ids_with_children(project_ids)
|
||||
|
||||
|
||||
@endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
|
||||
def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
|
||||
conform_tag_fields(call, call.data)
|
||||
_process_include_subprojects(call.data)
|
||||
process_include_subprojects(call.data)
|
||||
Metadata.escape_query_parameters(call)
|
||||
with TimingContext("mongo", "models_get_all_ex"):
|
||||
ret_params = {}
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
ret_params = {}
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
allow_public=request.allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
|
||||
@@ -139,10 +127,7 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
|
||||
return
|
||||
|
||||
model_ids = {model["id"] for model in models}
|
||||
stats = ModelBLL.get_model_stats(
|
||||
company=company_id,
|
||||
model_ids=list(model_ids),
|
||||
)
|
||||
stats = ModelBLL.get_model_stats(company=company_id, model_ids=list(model_ids),)
|
||||
|
||||
for model in models:
|
||||
model["stats"] = stats.get(model["id"])
|
||||
@@ -154,10 +139,9 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
|
||||
def get_by_id_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
Metadata.escape_query_parameters(call)
|
||||
with TimingContext("mongo", "models_get_by_id_ex"):
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
)
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
call.result.data = {"models": models}
|
||||
@@ -167,15 +151,14 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
Metadata.escape_query_parameters(call)
|
||||
with TimingContext("mongo", "models_get_all"):
|
||||
ret_params = {}
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
ret_params = {}
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
@@ -414,9 +397,7 @@ def validate_task(company_id, fields: dict):
|
||||
def edit(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
)
|
||||
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
fields = prepare_update_fields(call, company_id, fields)
|
||||
@@ -424,11 +405,7 @@ def edit(call: APICall, company_id, _):
|
||||
for key in fields:
|
||||
field = getattr(model, key, None)
|
||||
value = fields[key]
|
||||
if (
|
||||
field
|
||||
and isinstance(value, dict)
|
||||
and isinstance(field, EmbeddedDocument)
|
||||
):
|
||||
if field and isinstance(value, dict) and isinstance(field, EmbeddedDocument):
|
||||
d = field.to_mongo(use_db_field=False).to_dict()
|
||||
d.update(value)
|
||||
fields[key] = d
|
||||
@@ -448,13 +425,9 @@ def edit(call: APICall, company_id, _):
|
||||
if updated:
|
||||
new_project = fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(
|
||||
company_id, projects=[new_project, model.project]
|
||||
)
|
||||
_reset_cached_tags(company_id, projects=[new_project, model.project])
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=fields
|
||||
)
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
conform_output_tags(call, fields)
|
||||
unescape_metadata(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
@@ -465,9 +438,7 @@ def edit(call: APICall, company_id, _):
|
||||
def _update_model(call: APICall, company_id, model_id=None):
|
||||
model_id = model_id or call.data["model"]
|
||||
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
)
|
||||
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
|
||||
data = prepare_update_fields(call, company_id, call.data)
|
||||
|
||||
@@ -511,6 +482,7 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
|
||||
updated, published_task = ModelBLL.publish_model(
|
||||
model_id=request.model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
force_publish_task=request.force_publish_task,
|
||||
publish_task_func=publish_task if request.publish_task else None,
|
||||
)
|
||||
@@ -529,6 +501,7 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
|
||||
func=partial(
|
||||
ModelBLL.publish_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
force_publish_task=request.force_publish_task,
|
||||
publish_task_func=publish_task if request.publish_task else None,
|
||||
),
|
||||
@@ -635,7 +608,7 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
|
||||
@endpoint("models.move", request_data_model=MoveRequest)
|
||||
def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
if not (request.project or request.project_name):
|
||||
if not ("project" in call.data or request.project_name):
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"project or project_name is required"
|
||||
)
|
||||
|
||||
@@ -6,14 +6,16 @@ from mongoengine import Q
|
||||
|
||||
from apiserver.apimodels.organization import TagsRequest, EntitiesCountRequest
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.database.model import User, AttributedDocument, EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.database.model.task.task import Task, TaskType
|
||||
from apiserver.service_repo import endpoint, APICall
|
||||
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
|
||||
|
||||
org_bll = OrgBLL()
|
||||
project_bll = ProjectBLL()
|
||||
|
||||
|
||||
@endpoint("organization.get_tags", request_data_model=TagsRequest)
|
||||
@@ -49,14 +51,15 @@ def get_user_companies(call: APICall, company_id: str, _):
|
||||
}
|
||||
|
||||
|
||||
@endpoint("organization.get_entities_count", request_data_model=EntitiesCountRequest)
|
||||
def get_entities_count(call: APICall, company, _):
|
||||
@endpoint("organization.get_entities_count")
|
||||
def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
|
||||
entity_classes: Mapping[str, Type[AttributedDocument]] = {
|
||||
"projects": Project,
|
||||
"tasks": Task,
|
||||
"models": Model,
|
||||
"pipelines": Project,
|
||||
"datasets": Project,
|
||||
"reports": Task,
|
||||
}
|
||||
ret = {}
|
||||
for field, entity_cls in entity_classes.items():
|
||||
@@ -64,12 +67,41 @@ def get_entities_count(call: APICall, company, _):
|
||||
if data is None:
|
||||
continue
|
||||
|
||||
if field == "reports":
|
||||
data["type"] = TaskType.report
|
||||
data["include_subprojects"] = True
|
||||
|
||||
if request.active_users:
|
||||
if entity_cls is Project:
|
||||
requested_ids = data.get("id")
|
||||
if isinstance(requested_ids, str):
|
||||
requested_ids = [requested_ids]
|
||||
ids, _ = project_bll.get_projects_with_selected_children(
|
||||
company=company,
|
||||
users=request.active_users,
|
||||
project_ids=requested_ids,
|
||||
allow_public=request.allow_public,
|
||||
)
|
||||
if not ids:
|
||||
ret[field] = 0
|
||||
continue
|
||||
data["id"] = ids
|
||||
elif not data.get("user"):
|
||||
data["user"] = request.active_users
|
||||
|
||||
query = Q()
|
||||
if entity_cls in (Project, Task) and not data.get("search_hidden"):
|
||||
if (
|
||||
entity_cls in (Project, Task)
|
||||
and field not in ("reports", "pipelines", "datasets")
|
||||
and not request.search_hidden
|
||||
):
|
||||
query &= Q(system_tags__ne=EntityVisibility.hidden.value)
|
||||
|
||||
ret[field] = entity_cls.get_count(
|
||||
company=company, query_dict=data, query=query, allow_public=True,
|
||||
company=company,
|
||||
query_dict=data,
|
||||
query=query,
|
||||
allow_public=request.allow_public,
|
||||
)
|
||||
|
||||
call.result.data = ret
|
||||
|
||||
@@ -60,6 +60,7 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
|
||||
queued, res = enqueue_task(
|
||||
task_id=task.id,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
queue_id=request.queue,
|
||||
status_message="Starting pipeline",
|
||||
status_reason="",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Sequence
|
||||
from typing import Sequence, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from mongoengine import Q
|
||||
@@ -18,9 +18,11 @@ from apiserver.apimodels.projects import (
|
||||
ProjectOrNoneRequest,
|
||||
ProjectRequest,
|
||||
ProjectModelMetadataValuesRequest,
|
||||
ProjectChildrenType,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL, ProjectQueries
|
||||
from apiserver.bll.project.project_bll import pipeline_tag, reports_tag
|
||||
from apiserver.bll.project.project_cleanup import (
|
||||
delete_project,
|
||||
validate_project_delete,
|
||||
@@ -28,6 +30,7 @@ from apiserver.bll.project.project_cleanup import (
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import TaskType
|
||||
from apiserver.database.utils import (
|
||||
parse_from_call,
|
||||
get_company_or_none_constraint,
|
||||
@@ -39,7 +42,6 @@ from apiserver.services.utils import (
|
||||
get_tags_filter_dictionary,
|
||||
sort_tags_response,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
org_bll = OrgBLL()
|
||||
project_bll = ProjectBLL()
|
||||
@@ -60,11 +62,8 @@ def get_by_id(call):
|
||||
project_id = call.data["project"]
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "projects_by_id"):
|
||||
query = Q(id=project_id) & get_company_or_none_constraint(
|
||||
call.identity.company
|
||||
)
|
||||
project = Project.objects(query).first()
|
||||
query = Q(id=project_id) & get_company_or_none_constraint(call.identity.company)
|
||||
project = Project.objects(query).first()
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
@@ -100,80 +99,135 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
|
||||
data["parent"] = [None]
|
||||
|
||||
|
||||
def _get_project_stats_filter(request: ProjectsGetRequest) -> Tuple[Optional[dict], bool]:
|
||||
if request.include_stats_filter or not request.children_type:
|
||||
return request.include_stats_filter, request.search_hidden
|
||||
|
||||
if request.children_type == ProjectChildrenType.pipeline:
|
||||
return {"system_tags": [pipeline_tag], "type": [TaskType.controller]}, True
|
||||
if request.children_type == ProjectChildrenType.report:
|
||||
return {"system_tags": [reports_tag], "type": [TaskType.report]}, True
|
||||
|
||||
return request.include_stats_filter, request.search_hidden
|
||||
|
||||
|
||||
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
|
||||
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
data = call.data
|
||||
conform_tag_fields(call, data)
|
||||
allow_public = not request.non_public
|
||||
allow_public = (
|
||||
data["allow_public"]
|
||||
if "allow_public" in data
|
||||
else not data["non_public"]
|
||||
if "non_public" in data
|
||||
else request.allow_public
|
||||
)
|
||||
|
||||
requested_ids = data.get("id")
|
||||
if isinstance(requested_ids, str):
|
||||
requested_ids = [requested_ids]
|
||||
_adjust_search_parameters(
|
||||
data, shallow_search=request.shallow_search,
|
||||
)
|
||||
with TimingContext("mongo", "projects_get_all"):
|
||||
user_active_project_ids = None
|
||||
if request.active_users:
|
||||
ids, user_active_project_ids = project_bll.get_projects_with_active_user(
|
||||
company=company_id,
|
||||
users=request.active_users,
|
||||
project_ids=requested_ids,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
if not ids:
|
||||
return {"projects": []}
|
||||
data["id"] = ids
|
||||
|
||||
ret_params = {}
|
||||
projects: Sequence[dict] = Project.get_many_with_join(
|
||||
selected_project_ids = None
|
||||
if request.active_users or request.children_type:
|
||||
ids, selected_project_ids = project_bll.get_projects_with_selected_children(
|
||||
company=company_id,
|
||||
query_dict=data,
|
||||
query=_hidden_query(search_hidden=request.search_hidden, ids=requested_ids),
|
||||
users=request.active_users,
|
||||
project_ids=requested_ids,
|
||||
allow_public=allow_public,
|
||||
ret_params=ret_params,
|
||||
children_type=request.children_type,
|
||||
)
|
||||
if not ids:
|
||||
return {"projects": []}
|
||||
data["id"] = ids
|
||||
|
||||
if request.check_own_contents and requested_ids:
|
||||
existing_requested_ids = {
|
||||
project["id"] for project in projects if project["id"] in requested_ids
|
||||
}
|
||||
if existing_requested_ids:
|
||||
contents = project_bll.calc_own_contents(
|
||||
company=company_id,
|
||||
project_ids=list(existing_requested_ids),
|
||||
filter_=request.include_stats_filter,
|
||||
users=request.active_users,
|
||||
)
|
||||
for project in projects:
|
||||
project.update(**contents.get(project["id"], {}))
|
||||
ret_params = {}
|
||||
|
||||
conform_output_tags(call, projects)
|
||||
if request.include_stats:
|
||||
project_ids = {project["id"] for project in projects}
|
||||
stats, children = project_bll.get_project_stats(
|
||||
remove_system_tags = False
|
||||
if request.search_hidden:
|
||||
only_fields = data.get("only_fields")
|
||||
if isinstance(only_fields, list) and "system_tags" not in only_fields:
|
||||
only_fields.append("system_tags")
|
||||
remove_system_tags = True
|
||||
|
||||
projects: Sequence[dict] = Project.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=data,
|
||||
query=_hidden_query(search_hidden=request.search_hidden, ids=requested_ids),
|
||||
allow_public=allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
if not projects:
|
||||
return {"projects": projects, **ret_params}
|
||||
|
||||
if request.search_hidden:
|
||||
for p in projects:
|
||||
system_tags = (
|
||||
p.pop("system_tags", [])
|
||||
if remove_system_tags
|
||||
else p.get("system_tags", [])
|
||||
)
|
||||
if EntityVisibility.hidden.value in system_tags:
|
||||
p["hidden"] = True
|
||||
|
||||
conform_output_tags(call, projects)
|
||||
project_ids = list({project["id"] for project in projects})
|
||||
|
||||
if request.check_own_contents:
|
||||
if request.children_type == ProjectChildrenType.dataset:
|
||||
contents = project_bll.calc_own_datasets(
|
||||
company=company_id,
|
||||
project_ids=list(project_ids),
|
||||
specific_state=request.stats_for_state,
|
||||
include_children=request.stats_with_children,
|
||||
search_hidden=request.search_hidden,
|
||||
project_ids=project_ids,
|
||||
filter_=request.include_stats_filter,
|
||||
users=request.active_users,
|
||||
user_active_project_ids=user_active_project_ids,
|
||||
)
|
||||
|
||||
for project in projects:
|
||||
project["stats"] = stats[project["id"]]
|
||||
project["sub_projects"] = children[project["id"]]
|
||||
|
||||
if request.include_dataset_stats:
|
||||
project_ids = {project["id"] for project in projects}
|
||||
dataset_stats = project_bll.get_dataset_stats(
|
||||
else:
|
||||
contents = project_bll.calc_own_contents(
|
||||
company=company_id,
|
||||
project_ids=list(project_ids),
|
||||
project_ids=project_ids,
|
||||
filter_=_get_project_stats_filter(request)[0],
|
||||
users=request.active_users,
|
||||
)
|
||||
for project in projects:
|
||||
project["dataset_stats"] = dataset_stats.get(project["id"])
|
||||
|
||||
call.result.data = {"projects": projects, **ret_params}
|
||||
for project in projects:
|
||||
project.update(**contents.get(project["id"], {}))
|
||||
|
||||
if request.include_stats:
|
||||
if request.children_type == ProjectChildrenType.dataset:
|
||||
stats, children = project_bll.get_project_dataset_stats(
|
||||
company=company_id,
|
||||
project_ids=project_ids,
|
||||
include_children=request.stats_with_children,
|
||||
filter_=request.include_stats_filter,
|
||||
users=request.active_users,
|
||||
selected_project_ids=selected_project_ids,
|
||||
)
|
||||
else:
|
||||
filter_, search_hidden = _get_project_stats_filter(request)
|
||||
stats, children = project_bll.get_project_stats(
|
||||
company=company_id,
|
||||
project_ids=project_ids,
|
||||
specific_state=request.stats_for_state,
|
||||
include_children=request.stats_with_children,
|
||||
search_hidden=search_hidden,
|
||||
filter_=filter_,
|
||||
users=request.active_users,
|
||||
selected_project_ids=selected_project_ids,
|
||||
)
|
||||
|
||||
for project in projects:
|
||||
project["stats"] = stats[project["id"]]
|
||||
project["sub_projects"] = children[project["id"]]
|
||||
|
||||
if request.include_dataset_stats:
|
||||
dataset_stats = project_bll.get_dataset_stats(
|
||||
company=company_id, project_ids=project_ids, users=request.active_users,
|
||||
)
|
||||
for project in projects:
|
||||
project["dataset_stats"] = dataset_stats.get(project["id"])
|
||||
|
||||
call.result.data = {"projects": projects, **ret_params}
|
||||
|
||||
|
||||
@endpoint("projects.get_all")
|
||||
@@ -183,20 +237,19 @@ def get_all(call: APICall):
|
||||
_adjust_search_parameters(
|
||||
data, shallow_search=data.get("shallow_search", False),
|
||||
)
|
||||
with TimingContext("mongo", "projects_get_all"):
|
||||
ret_params = {}
|
||||
projects = Project.get_many(
|
||||
company=call.identity.company,
|
||||
query_dict=data,
|
||||
query=_hidden_query(
|
||||
search_hidden=data.get("search_hidden"), ids=data.get("id")
|
||||
),
|
||||
parameters=data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, projects)
|
||||
call.result.data = {"projects": projects, **ret_params}
|
||||
ret_params = {}
|
||||
projects = Project.get_many(
|
||||
company=call.identity.company,
|
||||
query_dict=data,
|
||||
query=_hidden_query(
|
||||
search_hidden=data.get("search_hidden"), ids=data.get("id")
|
||||
),
|
||||
parameters=data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, projects)
|
||||
call.result.data = {"projects": projects, **ret_params}
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -282,6 +335,7 @@ def validate_delete(call: APICall, company_id: str, request: ProjectRequest):
|
||||
def delete(call: APICall, company_id: str, request: DeleteRequest):
|
||||
res, affected_projects = delete_project(
|
||||
company=company_id,
|
||||
user=call.identity.user,
|
||||
project_id=request.project,
|
||||
force=request.force,
|
||||
delete_contents=request.delete_contents,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.apimodels.queues import (
|
||||
GetDefaultResp,
|
||||
@@ -15,10 +17,12 @@ from apiserver.apimodels.queues import (
|
||||
DeleteMetadataRequest,
|
||||
GetNextTaskRequest,
|
||||
GetByIdRequest,
|
||||
GetAllRequest,
|
||||
)
|
||||
from apiserver.bll.model import Metadata
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.workers import WorkerBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.utils import (
|
||||
@@ -51,16 +55,33 @@ def get_by_id(call: APICall):
|
||||
call.result.data_model = GetDefaultResp(id=queue.id, name=queue.name)
|
||||
|
||||
|
||||
def _hidden_query(data: dict) -> Q:
    """
    Build the visibility condition for a queues listing request.

    Queues carrying one of the system tags configured under
    "services.queues.hidden_tags" are filtered out, unless no such tags are
    configured or the request sets "search_hidden", "id" or "name".

    :param data: the request parameters
    :return: an empty Q when no filtering applies, otherwise a Q that
        excludes queues tagged with any of the hidden tags
    """
    hidden_tags = config.get("services.queues.hidden_tags", [])
    explicit_lookup = any(
        data.get(key) for key in ("search_hidden", "id", "name")
    )
    if explicit_lookup or not hidden_tags:
        return Q()

    return Q(system_tags__nin=hidden_tags)
|
||||
|
||||
|
||||
@endpoint("queues.get_all_ex", min_version="2.4")
|
||||
def get_all_ex(call: APICall):
|
||||
def get_all_ex(call: APICall, company: str, request: GetAllRequest):
|
||||
conform_tag_fields(call, call.data)
|
||||
ret_params = {}
|
||||
|
||||
Metadata.escape_query_parameters(call)
|
||||
queues = queue_bll.get_queue_infos(
|
||||
company_id=call.identity.company,
|
||||
company_id=company,
|
||||
query_dict=call.data,
|
||||
max_task_entries=call.data.pop("max_task_entries", None),
|
||||
query=_hidden_query(call.data),
|
||||
max_task_entries=request.max_task_entries,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
@@ -69,14 +90,15 @@ def get_all_ex(call: APICall):
|
||||
|
||||
|
||||
@endpoint("queues.get_all", min_version="2.4")
|
||||
def get_all(call: APICall):
|
||||
def get_all(call: APICall, company: str, request: GetAllRequest):
|
||||
conform_tag_fields(call, call.data)
|
||||
ret_params = {}
|
||||
Metadata.escape_query_parameters(call)
|
||||
queues = queue_bll.get_all(
|
||||
company_id=call.identity.company,
|
||||
company_id=company,
|
||||
query_dict=call.data,
|
||||
max_task_entries=call.data.pop("max_task_entries", None),
|
||||
query=_hidden_query(call.data),
|
||||
max_task_entries=request.max_task_entries,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
@@ -120,7 +142,10 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
@endpoint("queues.delete", min_version="2.4", request_data_model=DeleteRequest)
|
||||
def delete(call: APICall, company_id, req_model: DeleteRequest):
|
||||
queue_bll.delete(
|
||||
company_id=company_id, queue_id=req_model.queue, force=req_model.force
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
queue_id=req_model.queue,
|
||||
force=req_model.force,
|
||||
)
|
||||
call.result.data = {"deleted": 1}
|
||||
|
||||
@@ -135,11 +160,13 @@ def add_task(call: APICall, company_id, req_model: TaskRequest):
|
||||
|
||||
|
||||
@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest)
|
||||
def get_next_task(call: APICall, company_id, req_model: GetNextTaskRequest):
|
||||
entry = queue_bll.get_next_task(company_id=company_id, queue_id=req_model.queue)
|
||||
def get_next_task(call: APICall, company_id, request: GetNextTaskRequest):
|
||||
entry = queue_bll.get_next_task(
|
||||
company_id=company_id, queue_id=request.queue, task_id=request.task
|
||||
)
|
||||
if entry:
|
||||
data = {"entry": entry.to_proper_dict()}
|
||||
if req_model.get_task_info:
|
||||
if request.get_task_info:
|
||||
task = Task.objects(id=entry.task).first()
|
||||
if task:
|
||||
data["task_info"] = {"company": task.company, "user": task.user}
|
||||
|
||||
375
apiserver/services/reports.py
Normal file
375
apiserver/services/reports.py
Normal file
@@ -0,0 +1,375 @@
|
||||
import textwrap
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.apimodels.reports import (
|
||||
CreateReportRequest,
|
||||
UpdateReportRequest,
|
||||
PublishReportRequest,
|
||||
ArchiveReportRequest,
|
||||
DeleteReportRequest,
|
||||
MoveReportRequest,
|
||||
GetTasksDataRequest,
|
||||
EventsRequest,
|
||||
GetAllRequest,
|
||||
)
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.bll.project.project_bll import reports_project_name, reports_tag
|
||||
from apiserver.services.utils import process_include_subprojects, sort_tags_response
|
||||
from apiserver.bll.organization import OrgBLL
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.task import TaskBLL, ChangeStatusRequest
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskType, TaskStatus
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.events import (
|
||||
_get_task_or_model_index_companies,
|
||||
event_bll,
|
||||
_get_metrics_response,
|
||||
_get_metric_variants_from_request,
|
||||
_get_multitask_plots,
|
||||
)
|
||||
from apiserver.services.tasks import (
|
||||
escape_execution_parameters,
|
||||
_hidden_query,
|
||||
unprepare_from_saved,
|
||||
)
|
||||
|
||||
# Module-level business logic helpers used by the report endpoints below
org_bll = OrgBLL()
project_bll = ProjectBLL()
task_bll = TaskBLL()
|
||||
|
||||
|
||||
# Names of the request fields that reports.update is allowed to write;
# any other key in the call data is ignored by update_report
update_fields = {"name", "tags", "comment", "report", "report_assets"}
|
||||
|
||||
|
||||
def _assert_report(company_id, task_id, only_fields=None, requires_write_access=True):
    """
    Fetch a task and verify that it is a report.

    :param company_id: company passed to the task access check
    :param task_id: id of the task to fetch
    :param only_fields: optional collection of field names to load. "type" is
        always added to a non-empty collection since it is read by the report
        check below. The collection passed by the caller is never modified
    :param requires_write_access: forwarded to TaskBLL.get_task_with_access
    :return: the fetched task
    :raise errors.bad_request.OperationSupportedOnReportsOnly: if the task
        type is not TaskType.report
    """
    if only_fields and "type" not in only_fields:
        # Build a new tuple instead of using "+=": the in-place form would
        # extend a list owned by the caller
        only_fields = (*only_fields, "type")

    task = TaskBLL.get_task_with_access(
        task_id=task_id,
        company_id=company_id,
        only=only_fields,
        requires_write_access=requires_write_access,
    )
    if task.type != TaskType.report:
        raise errors.bad_request.OperationSupportedOnReportsOnly(id=task_id)

    return task
|
||||
|
||||
|
||||
@endpoint("reports.update", response_data_model=UpdateResponse)
def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
    """
    Update the editable fields of a report.

    Only the keys listed in update_fields are taken from the call data. A
    report whose status is not "created" may only have its tags, name and
    comment changed; any other change raises InvalidTaskStatus.

    :return: UpdateResponse with the update result and the updated fields
        (tags sorted, report text shortened to 100 characters), or
        UpdateResponse(updated=0) when there was nothing to update
    """
    task = _assert_report(
        task_id=request.task, company_id=company_id, only_fields=("status",),
    )

    partial_update_dict = {
        key: val for key, val in call.data.items() if key in update_fields
    }
    if not partial_update_dict:
        return UpdateResponse(updated=0)

    # Fields outside this set count as a content change
    metadata_only_fields = {"tags", "name", "comment"}
    content_changed = not set(partial_update_dict.keys()).issubset(
        metadata_only_fields
    )
    if content_changed and task.status != TaskStatus.created:
        raise errors.bad_request.InvalidTaskStatus(
            expected=TaskStatus.created, status=task.status
        )

    now = datetime.utcnow()
    bookkeeping = {"last_change": now, "last_changed_by": call.identity.user}
    if content_changed:
        bookkeeping["last_update"] = now

    updated = task.update(upsert=False, **partial_update_dict, **bookkeeping)
    if not updated:
        return UpdateResponse(updated=0)

    # Post-process the values echoed back in the response
    new_tags = partial_update_dict.get("tags")
    if new_tags:
        partial_update_dict["tags"] = sorted(new_tags)
    new_report = partial_update_dict.get("report")
    if new_report:
        partial_update_dict["report"] = textwrap.shorten(new_report, width=100)

    return UpdateResponse(updated=updated, fields=partial_update_dict)
|
||||
|
||||
|
||||
def _ensure_reports_project(company: str, user: str, name: str):
    """
    Find or create the reports project that belongs under the given path.

    Leading and trailing slashes are stripped from the name. Unless its last
    path component already equals reports_project_name, that component is
    appended to the name.

    :return: the result of project_bll.find_or_create for the resulting name
    """
    name = name.strip("/")
    last_component = name.rsplit("/", 1)[-1]
    if last_component != reports_project_name:
        name = f"{name}/{reports_project_name}"

    project_tags = [reports_tag, EntityVisibility.hidden.value]
    return project_bll.find_or_create(
        user=user,
        company=company,
        project_name=name,
        description="Reports project",
        system_tags=project_tags,
    )
|
||||
|
||||
|
||||
@endpoint("reports.create")
def create_report(call: APICall, company_id: str, request: CreateReportRequest):
    """
    Create a new report task.

    When a project is passed in the request, its name is looked up (with write
    access) and used as the parent path; otherwise an empty path is used. The
    report is created in the reports project resolved for that path and the
    call result contains the new task id and that project id.
    """
    user_id = call.identity.user
    if request.project:
        parent = Project.get_for_writing(
            company=company_id, id=request.project, _only=("name",)
        )
        parent_name = parent.name
    else:
        parent_name = ""

    project_id = _ensure_reports_project(
        company=company_id, user=user_id, name=parent_name
    )

    report_fields = dict(
        project=project_id,
        name=request.name,
        tags=request.tags,
        comment=request.comment,
        type=TaskType.report,
        system_tags=[reports_tag, EntityVisibility.hidden.value],
    )
    task = task_bll.create(company=company_id, user=user_id, fields=report_fields)
    task.save()

    call.result.data = {"id": task.id, "project_id": project_id}
|
||||
|
||||
|
||||
def _delete_reports_project_if_empty(project_id):
    """
    Delete the given project if it is a reports project that has no tasks.

    Nothing happens when the project is not found, when its basename differs
    from reports_project_name, or when at least one task still references it.
    """
    project = Project.objects(id=project_id).only("basename").first()
    if not project:
        return
    if project.basename != reports_project_name:
        return

    remaining_tasks = Task.objects(project=project_id).count()
    if remaining_tasks == 0:
        project.delete()
|
||||
|
||||
|
||||
@endpoint("reports.get_all_ex")
def get_all_ex(call: APICall, company_id, request: GetAllRequest):
    """
    Return the reports matching the query passed in the call data.

    The call data is updated in place so that the query is always restricted
    to tasks of type report and always includes subprojects.
    """
    call_data = call.data
    call_data.update(type=TaskType.report, include_subprojects=True)
    process_include_subprojects(call_data)

    ret_params = {}
    reports = Task.get_many_with_join(
        company=company_id,
        query_dict=call_data,
        allow_public=request.allow_public,
        ret_params=ret_params,
    )
    unprepare_from_saved(call, reports)

    call.result.data = {"tasks": reports, **ret_params}
|
||||
|
||||
|
||||
def _get_task_metrics_from_request(
|
||||
task_ids: Sequence[str], request: EventsRequest
|
||||
) -> dict:
|
||||
task_metrics = {}
|
||||
for task in task_ids:
|
||||
task_dict = {}
|
||||
for mv in request.metrics:
|
||||
task_dict[mv.metric] = mv.variants
|
||||
task_metrics[task] = task_dict
|
||||
|
||||
return task_metrics
|
||||
|
||||
|
||||
@endpoint("reports.get_task_data")
def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
    """
    Return the tasks matching the query together with optional event data.

    The tasks are always returned. Debug images, plots and the scalar metrics
    histogram are added to the result only for the sections that are present
    in the request; when none is requested the tasks result is returned as is.
    """
    call_data = escape_execution_parameters(call)
    process_include_subprojects(call_data)

    ret_params = {}
    tasks = Task.get_many_with_join(
        company=company_id,
        query_dict=call_data,
        query=_hidden_query(call_data),
        allow_public=request.allow_public,
        ret_params=ret_params,
    )
    unprepare_from_saved(call, tasks)
    res = {"tasks": tasks, **ret_params}

    debug_images = request.debug_images
    plots = request.plots
    histogram = request.scalar_metrics_iter_histogram
    if not debug_images and not plots and not histogram:
        return res

    task_ids = [task["id"] for task in tasks]
    companies = _get_task_or_model_index_companies(company_id, task_ids=task_ids)

    if debug_images:
        # Flatten the per company task lists into a task id -> company mapping
        task_companies = {
            t.id: t.company for t in chain.from_iterable(companies.values())
        }
        result = event_bll.debug_images_iterator.get_task_events(
            companies=task_companies,
            task_metrics=_get_task_metrics_from_request(task_ids, debug_images),
            iter_count=debug_images.iters,
        )
        res["debug_images"] = [
            r.to_struct() for r in _get_metrics_response(result.metric_events)
        ]

    if plots:
        plots_result = _get_multitask_plots(
            companies=companies,
            last_iters=plots.iters,
            metrics=_get_metric_variants_from_request(plots.metrics),
        )
        # Only the first element of the returned value is put in the response
        res["plots"] = plots_result[0]

    if histogram:
        res[
            "scalar_metrics_iter_histogram"
        ] = event_bll.metrics.compare_scalar_metrics_average_per_iter(
            companies=companies,
            samples=histogram.samples,
            key=histogram.key,
            metric_variants=_get_metric_variants_from_request(histogram.metrics),
        )

    call.result.data = res
|
||||
|
||||
|
||||
@endpoint("reports.move")
def move(call: APICall, company_id: str, request: MoveReportRequest):
    """
    Move a report under the reports project of another location.

    The target path is request.project_name when given; otherwise the name of
    request.project (looked up with write access), or an empty path when the
    project value is empty. After the move the reports project that previously
    held the report is deleted if it became empty.

    :raise errors.bad_request.MissingRequiredFields: if the call data has no
        "project" key and no project name was passed
    :return: dictionary with the id of the target reports project
    """
    if "project" not in call.data and not request.project_name:
        raise errors.bad_request.MissingRequiredFields(
            "project or project_name is required"
        )

    task = _assert_report(
        company_id=company_id, task_id=request.task, only_fields=("project",),
    )
    user_id = call.identity.user

    target_name = request.project_name
    if not target_name and request.project:
        target_project = Project.get_for_writing(
            company=company_id, id=request.project, _only=("name",)
        )
        target_name = target_project.name
    elif not target_name:
        target_name = ""

    project_id = _ensure_reports_project(
        company=company_id, user=user_id, name=target_name
    )

    project_bll.move_under_project(
        entity_cls=Task,
        user=call.identity.user,
        company=company_id,
        ids=[request.task],
        project=project_id,
    )

    # task.project still holds the project the report was loaded with
    _delete_reports_project_if_empty(task.project)

    return {"project_id": project_id}
|
||||
|
||||
|
||||
@endpoint(
    "reports.publish", response_data_model=UpdateResponse,
)
def publish(call: APICall, company_id, request: PublishReportRequest):
    """
    Change the status of a report to published.

    The status change is forced and records the request message as the status
    message. The updates returned by the status change are sent back as an
    UpdateResponse.
    """
    report = _assert_report(company_id=company_id, task_id=request.task)
    status_change = ChangeStatusRequest(
        task=report,
        new_status=TaskStatus.published,
        force=True,
        status_reason="",
        status_message=request.message,
        user_id=call.identity.user,
    )
    updates = status_change.execute(published=datetime.utcnow())

    call.result.data_model = UpdateResponse(**updates)
|
||||
|
||||
|
||||
@endpoint("reports.archive")
def archive(call: APICall, company_id, request: ArchiveReportRequest):
    """
    Archive a report by adding the archived value to its system tags.

    The request message is stored as the status message, the status reason is
    cleared and the last change time and user are recorded.

    :return: dictionary with the result of the task update under "archived"
    """
    report = _assert_report(company_id=company_id, task_id=request.task)
    changes = dict(
        status_message=request.message,
        status_reason="",
        add_to_set__system_tags=EntityVisibility.archived.value,
        last_change=datetime.utcnow(),
        last_changed_by=call.identity.user,
    )

    return {"archived": report.update(**changes)}
|
||||
|
||||
|
||||
@endpoint("reports.unarchive")
def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
    """
    Unarchive a report by pulling the archived value from its system tags.

    The request message is stored as the status message, the status reason is
    cleared and the last change time and user are recorded.

    :return: dictionary with the result of the task update under "unarchived"
    """
    report = _assert_report(company_id=company_id, task_id=request.task)
    changes = dict(
        status_message=request.message,
        status_reason="",
        pull__system_tags=EntityVisibility.archived.value,
        last_change=datetime.utcnow(),
        last_changed_by=call.identity.user,
    )

    return {"unarchived": report.update(**changes)}
|
||||
|
||||
|
||||
# @endpoint("reports.share")
|
||||
# def share(call: APICall, company_id, request: ShareReportRequest):
|
||||
# _assert_report(
|
||||
# company_id=company_id, user_id=call.identity.user, task_id=request.task
|
||||
# )
|
||||
# call.result.data = {
|
||||
# "changed": task_bll.share_task(
|
||||
# company_id=company_id, task_ids=[request.task], share=request.share
|
||||
# )
|
||||
# }
|
||||
|
||||
|
||||
@endpoint("reports.delete")
def delete(call: APICall, company_id, request: DeleteReportRequest):
    """
    Delete a report and remove its reports project if that leaves it empty.

    A report can be deleted when its status is "created", when it is archived,
    or when the request sets force.

    :raise errors.bad_request.TaskCannotBeDeleted: if the report is neither in
        "created" status nor archived and force was not passed
    """
    task = _assert_report(
        company_id=company_id,
        task_id=request.task,
        # status and system_tags are read by the check below, so they are
        # requested together with project.
        # NOTE(review): assumes that fields not listed here are not populated
        # from the stored document - confirm against TaskBLL.get_task_with_access
        only_fields=("project", "status", "system_tags"),
    )
    if (
        task.status != TaskStatus.created
        and EntityVisibility.archived.value not in task.system_tags
        and not request.force
    ):
        raise errors.bad_request.TaskCannotBeDeleted(
            "due to status, use force=True",
            task=task.id,
            expected=TaskStatus.created,
            current=task.status,
        )

    task.delete()
    _delete_reports_project_if_empty(task.project)

    call.result.data = {"deleted": 1}
|
||||
|
||||
|
||||
@endpoint("reports.get_tags")
def get_tags(call: APICall, company_id: str, _):
    """
    Return the distinct tags used by the report tasks of the company.

    The tags are passed through sort_tags_response before being returned.
    """
    company_reports = Task.objects(company=company_id, type=TaskType.report)
    call.result.data = sort_tags_response(
        {"tags": company_reports.distinct(field="tags")}
    )
|
||||
@@ -64,11 +64,12 @@ from apiserver.apimodels.tasks import (
|
||||
ResetBatchItem,
|
||||
CompletedRequest,
|
||||
CompletedResponse,
|
||||
GetAllReq,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.task import (
|
||||
TaskBLL,
|
||||
@@ -81,7 +82,6 @@ from apiserver.bll.task.artifacts import (
|
||||
Artifacts,
|
||||
)
|
||||
from apiserver.bll.task.hyperparams import HyperParams
|
||||
from apiserver.bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
|
||||
from apiserver.bll.task.param_utils import (
|
||||
params_prepare_for_save,
|
||||
params_unprepare_from_saved,
|
||||
@@ -119,8 +119,8 @@ from apiserver.services.utils import (
|
||||
DockerCmdBackwardsCompatibility,
|
||||
escape_dict_field,
|
||||
unescape_dict_field,
|
||||
process_include_subprojects,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
|
||||
@@ -135,11 +135,9 @@ queue_bll = QueueBLL()
|
||||
org_bll = OrgBLL()
|
||||
project_bll = ProjectBLL()
|
||||
|
||||
NonResponsiveTasksWatchdog.start()
|
||||
|
||||
|
||||
def set_task_status_from_call(
|
||||
request: UpdateRequest, company_id, new_status=None, **set_fields
|
||||
request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields
|
||||
) -> dict:
|
||||
fields_resolver = SetFieldsResolver(set_fields)
|
||||
task = TaskBLL.get_task_with_access(
|
||||
@@ -174,6 +172,7 @@ def set_task_status_from_call(
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
force=force,
|
||||
user_id=user_id,
|
||||
).execute(**fields_resolver.get_fields(task))
|
||||
|
||||
|
||||
@@ -207,17 +206,6 @@ def escape_execution_parameters(call: APICall) -> dict:
|
||||
return call_data
|
||||
|
||||
|
||||
def _process_include_subprojects(call_data: dict):
|
||||
include_subprojects = call_data.pop("include_subprojects", False)
|
||||
project_ids = call_data.get("project")
|
||||
if not project_ids or not include_subprojects:
|
||||
return
|
||||
|
||||
if not isinstance(project_ids, list):
|
||||
project_ids = [project_ids]
|
||||
call_data["project"] = project_ids_with_children(project_ids)
|
||||
|
||||
|
||||
def _hidden_query(data: dict) -> Q:
|
||||
"""
|
||||
1. Add only non-hidden tasks search condition (unless specifically specified differently)
|
||||
@@ -228,22 +216,21 @@ def _hidden_query(data: dict) -> Q:
|
||||
return Q(system_tags__ne=EntityVisibility.hidden.value)
|
||||
|
||||
|
||||
@endpoint("tasks.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
@endpoint("tasks.get_all_ex")
|
||||
def get_all_ex(call: APICall, company_id, request: GetAllReq):
|
||||
conform_tag_fields(call, call.data)
|
||||
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with TimingContext("mongo", "task_get_all_ex"):
|
||||
_process_include_subprojects(call_data)
|
||||
ret_params = {}
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call_data,
|
||||
query=_hidden_query(call_data),
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
process_include_subprojects(call_data)
|
||||
ret_params = {}
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call_data,
|
||||
query=_hidden_query(call_data),
|
||||
allow_public=request.allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks, **ret_params}
|
||||
|
||||
@@ -254,10 +241,9 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with TimingContext("mongo", "task_get_by_id_ex"):
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
@@ -269,16 +255,15 @@ def get_all(call: APICall, company_id, _):
|
||||
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with TimingContext("mongo", "task_get_all"):
|
||||
ret_params = {}
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
parameters=call_data,
|
||||
query_dict=call_data,
|
||||
query=_hidden_query(call_data),
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
ret_params = {}
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
parameters=call_data,
|
||||
query_dict=call_data,
|
||||
query=_hidden_query(call_data),
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks, **ret_params}
|
||||
|
||||
@@ -308,6 +293,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**stop_task(
|
||||
task_id=req_model.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
user_name=call.identity.user_name,
|
||||
status_reason=req_model.status_reason,
|
||||
force=req_model.force,
|
||||
@@ -325,6 +311,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
|
||||
func=partial(
|
||||
stop_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
user_name=call.identity.user_name,
|
||||
status_reason=request.status_reason,
|
||||
force=request.force,
|
||||
@@ -346,7 +333,8 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
|
||||
call.result.data_model = UpdateResponse(
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
new_status=TaskStatus.stopped,
|
||||
completed=datetime.utcnow(),
|
||||
)
|
||||
@@ -362,7 +350,8 @@ def started(call: APICall, company_id, req_model: UpdateRequest):
|
||||
res = StartedResponse(
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
new_status=TaskStatus.in_progress,
|
||||
min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value
|
||||
)
|
||||
@@ -376,7 +365,12 @@ def started(call: APICall, company_id, req_model: UpdateRequest):
|
||||
)
|
||||
def failed(call: APICall, company_id, req_model: UpdateRequest):
|
||||
call.result.data_model = UpdateResponse(
|
||||
**set_task_status_from_call(req_model, company_id, new_status=TaskStatus.failed)
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
new_status=TaskStatus.failed,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -385,7 +379,11 @@ def failed(call: APICall, company_id, req_model: UpdateRequest):
|
||||
)
|
||||
def close(call: APICall, company_id, req_model: UpdateRequest):
|
||||
call.result.data_model = UpdateResponse(
|
||||
**set_task_status_from_call(req_model, company_id, new_status=TaskStatus.closed)
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
new_status=TaskStatus.closed)
|
||||
)
|
||||
|
||||
|
||||
@@ -490,12 +488,13 @@ def prepare_create_fields(
|
||||
def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
|
||||
with translate_errors_context(
|
||||
field_does_not_exist_cls=errors.bad_request.ValidationError
|
||||
), TimingContext("code", "parse_call"):
|
||||
):
|
||||
fields = prepare_create_fields(call, **kwargs)
|
||||
task = task_bll.create(call, fields)
|
||||
task = task_bll.create(
|
||||
company=call.identity.company, user=call.identity.user, fields=fields
|
||||
)
|
||||
|
||||
with TimingContext("code", "validate"):
|
||||
task_bll.validate(task)
|
||||
task_bll.validate(task)
|
||||
|
||||
return task, fields
|
||||
|
||||
@@ -528,7 +527,7 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
|
||||
def create(call: APICall, company_id, req_model: CreateRequest):
|
||||
task, fields = _validate_and_get_task_from_call(call)
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "save_task"):
|
||||
with translate_errors_context():
|
||||
task.save()
|
||||
_update_cached_tags(company_id, project=task.project, fields=fields)
|
||||
update_project_time(task.project)
|
||||
@@ -596,7 +595,9 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
company_id=company_id,
|
||||
id=task_id,
|
||||
partial_update_dict=partial_update_dict,
|
||||
injected_update=dict(last_change=datetime.utcnow()),
|
||||
injected_update=dict(
|
||||
last_change=datetime.utcnow(), last_changed_by=call.identity.user,
|
||||
),
|
||||
)
|
||||
if updated_count:
|
||||
new_project = updated_fields.get("project", task.project)
|
||||
@@ -629,7 +630,11 @@ def set_requirements(call: APICall, company_id, req_model: SetRequirementsReques
|
||||
raise errors.bad_request.MissingTaskFields(
|
||||
"Task has no script field", task=task.id
|
||||
)
|
||||
res = update_task(task, update_cmds=dict(script__requirements=requirements))
|
||||
res = update_task(
|
||||
task,
|
||||
user_id=call.identity.user,
|
||||
update_cmds=dict(script__requirements=requirements),
|
||||
)
|
||||
call.result.data_model = UpdateResponse(updated=res)
|
||||
if res:
|
||||
call.result.data_model.fields = {"script.requirements": requirements}
|
||||
@@ -664,7 +669,9 @@ def update_batch(call: APICall, company_id, _):
|
||||
partial_update_dict = Task.get_safe_update_dict(fields)
|
||||
if not partial_update_dict:
|
||||
continue
|
||||
partial_update_dict.update(last_change=now)
|
||||
partial_update_dict.update(
|
||||
last_change=now, last_changed_by=call.identity.user,
|
||||
)
|
||||
update_op = UpdateOne(
|
||||
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
|
||||
)
|
||||
@@ -711,7 +718,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
|
||||
with translate_errors_context(
|
||||
field_does_not_exist_cls=errors.bad_request.ValidationError
|
||||
), TimingContext("code", "parse_and_validate"):
|
||||
):
|
||||
fields = prepare_create_fields(
|
||||
call, valid_fields=edit_fields, output=task.output, previous_task=task
|
||||
)
|
||||
@@ -728,7 +735,11 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
d.update(value)
|
||||
fields[key] = d
|
||||
|
||||
task_bll.validate(task_bll.create(call, fields))
|
||||
task_bll.validate(
|
||||
task_bll.create(
|
||||
company=call.identity.company, user=call.identity.user, fields=fields
|
||||
)
|
||||
)
|
||||
|
||||
# make sure field names do not end in mongoengine comparison operators
|
||||
fixed_fields = {
|
||||
@@ -737,7 +748,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
}
|
||||
if fixed_fields:
|
||||
now = datetime.utcnow()
|
||||
last_change = dict(last_change=now)
|
||||
last_change = dict(last_change=now, last_changed_by=call.identity.user)
|
||||
if not set(fields).issubset(Task.user_set_allowed()):
|
||||
last_change.update(last_update=now)
|
||||
fields.update(**last_change)
|
||||
@@ -774,6 +785,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_params(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
replace_hyperparams=request.replace_hyperparams,
|
||||
@@ -787,6 +799,7 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_params(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
force=request.force,
|
||||
@@ -831,6 +844,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_configuration(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
replace_configuration=request.replace_configuration,
|
||||
@@ -846,6 +860,7 @@ def delete_configuration(
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_configuration(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
force=request.force,
|
||||
@@ -862,12 +877,18 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
|
||||
queued, res = enqueue_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
queue_id=request.queue,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
queue_name=request.queue_name,
|
||||
force=request.force,
|
||||
)
|
||||
if request.verify_watched_queue:
|
||||
res_queue = nested_get(res, ("fields", "execution.queue"))
|
||||
if res_queue:
|
||||
res["queue_watched"] = queue_bll.check_for_workers(company_id, res_queue)
|
||||
|
||||
call.result.data_model = EnqueueResponse(queued=queued, **res)
|
||||
|
||||
|
||||
@@ -881,6 +902,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
|
||||
func=partial(
|
||||
enqueue_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
queue_id=request.queue,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
@@ -889,12 +911,20 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
|
||||
),
|
||||
ids=request.ids,
|
||||
)
|
||||
extra = {}
|
||||
if request.verify_watched_queue and results:
|
||||
_id, (queued, res) = results[0]
|
||||
res_queue = nested_get(res, ("fields", "execution.queue"))
|
||||
if res_queue:
|
||||
extra["queue_watched"] = queue_bll.check_for_workers(company_id, res_queue)
|
||||
|
||||
call.result.data_model = EnqueueManyResponse(
|
||||
succeeded=[
|
||||
EnqueueBatchItem(id=_id, queued=bool(queued), **res)
|
||||
for _id, (queued, res) in results
|
||||
],
|
||||
failed=failures,
|
||||
**extra,
|
||||
)
|
||||
|
||||
|
||||
@@ -907,6 +937,7 @@ def dequeue(call: APICall, company_id, request: UpdateRequest):
|
||||
dequeued, res = dequeue_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
)
|
||||
@@ -923,6 +954,7 @@ def dequeue_many(call: APICall, company_id, request: TaskBatchRequest):
|
||||
func=partial(
|
||||
dequeue_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
),
|
||||
@@ -944,10 +976,12 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
||||
dequeued, cleanup_res, updates = reset_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
clear_all=request.clear_all,
|
||||
delete_external_artifacts=request.delete_external_artifacts,
|
||||
)
|
||||
|
||||
res = ResetResponse(**updates, dequeued=dequeued)
|
||||
@@ -970,10 +1004,12 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
|
||||
func=partial(
|
||||
reset_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
clear_all=request.clear_all,
|
||||
delete_external_artifacts=request.delete_external_artifacts,
|
||||
),
|
||||
ids=request.ids,
|
||||
)
|
||||
@@ -1014,6 +1050,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
|
||||
for task in tasks:
|
||||
archived += archive_task(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
task=task,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
@@ -1032,6 +1069,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
|
||||
func=partial(
|
||||
archive_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
),
|
||||
@@ -1053,6 +1091,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
|
||||
func=partial(
|
||||
unarchive_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
),
|
||||
@@ -1071,12 +1110,14 @@ def delete(call: APICall, company_id, request: DeleteRequest):
|
||||
deleted, task, cleanup_res = delete_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
move_to_trash=request.move_to_trash,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
delete_external_artifacts=request.delete_external_artifacts,
|
||||
)
|
||||
if deleted:
|
||||
if request.move_to_trash:
|
||||
@@ -1091,12 +1132,14 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
|
||||
func=partial(
|
||||
delete_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
move_to_trash=request.move_to_trash,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
delete_external_artifacts=request.delete_external_artifacts,
|
||||
),
|
||||
ids=request.ids,
|
||||
)
|
||||
@@ -1127,6 +1170,7 @@ def publish(call: APICall, company_id, request: PublishRequest):
|
||||
updates = publish_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
force=request.force,
|
||||
publish_model_func=ModelBLL.publish_model if request.publish_model else None,
|
||||
status_reason=request.status_reason,
|
||||
@@ -1145,6 +1189,7 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
|
||||
func=partial(
|
||||
publish_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
force=request.force,
|
||||
publish_model_func=ModelBLL.publish_model
|
||||
if request.publish_model
|
||||
@@ -1171,7 +1216,8 @@ def completed(call: APICall, company_id, request: CompletedRequest):
|
||||
res = CompletedResponse(
|
||||
**set_task_status_from_call(
|
||||
request,
|
||||
company_id,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
new_status=TaskStatus.completed,
|
||||
completed=datetime.utcnow(),
|
||||
)
|
||||
@@ -1181,6 +1227,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
|
||||
publish_res = publish_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
force=request.force,
|
||||
publish_model_func=ModelBLL.publish_model,
|
||||
status_reason=request.status_reason,
|
||||
@@ -1212,6 +1259,7 @@ def add_or_update_artifacts(
|
||||
call.result.data = {
|
||||
"updated": Artifacts.add_or_update_artifacts(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
task_id=request.task,
|
||||
artifacts=request.artifacts,
|
||||
force=True,
|
||||
@@ -1228,6 +1276,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
|
||||
call.result.data = {
|
||||
"deleted": Artifacts.delete_artifacts(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
task_id=request.task,
|
||||
artifact_ids=request.artifacts,
|
||||
force=True,
|
||||
@@ -1251,7 +1300,7 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
|
||||
@endpoint("tasks.move", request_data_model=MoveRequest)
|
||||
def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
if not (request.project or request.project_name):
|
||||
if not ("project" in call.data or request.project_name):
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"project or project_name is required"
|
||||
)
|
||||
@@ -1295,7 +1344,7 @@ def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequ
|
||||
|
||||
|
||||
@endpoint("tasks.delete_models", min_version="2.13")
|
||||
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
|
||||
def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
|
||||
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
|
||||
|
||||
delete_names = {
|
||||
@@ -1308,5 +1357,5 @@ def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
|
||||
if names
|
||||
}
|
||||
|
||||
updated = task.update(last_change=datetime.utcnow(), **commands,)
|
||||
updated = task.update(last_change=datetime.utcnow(), last_changed_by=call.identity.user, **commands,)
|
||||
return {"updated": updated}
|
||||
|
||||
@@ -12,7 +12,7 @@ from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.user import UserBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.auth import Role, User as AuthUser
|
||||
from apiserver.database.model.auth import Role
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.user import User
|
||||
from apiserver.database.utils import parse_from_call
|
||||
@@ -114,12 +114,6 @@ def get_current_user(call: APICall, company_id, _):
|
||||
user = res[0]
|
||||
user["role"] = call.identity.role
|
||||
|
||||
auth_user: AuthUser = AuthUser.objects(id=user_id, company=company_id).first()
|
||||
if not auth_user:
|
||||
raise errors.bad_request.InvalidUser("failed loading user")
|
||||
|
||||
user["created"] = auth_user.created
|
||||
|
||||
resp = {
|
||||
"user": user,
|
||||
"getting_started": config.get("apiserver.getting_started_info", None),
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Union, Sequence, Tuple
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.organization import Filter
|
||||
from apiserver.bll.project import project_ids_with_children
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
|
||||
from apiserver.database.utils import partition_tags
|
||||
@@ -12,6 +13,17 @@ from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
|
||||
|
||||
def process_include_subprojects(call_data: dict):
|
||||
include_subprojects = call_data.pop("include_subprojects", False)
|
||||
project_ids = call_data.get("project")
|
||||
if not project_ids or not include_subprojects:
|
||||
return
|
||||
|
||||
if not isinstance(project_ids, list):
|
||||
project_ids = [project_ids]
|
||||
call_data["project"] = project_ids_with_children(project_ids)
|
||||
|
||||
|
||||
def get_tags_filter_dictionary(input_: Filter) -> dict:
|
||||
if not input_:
|
||||
return {}
|
||||
|
||||
@@ -42,7 +42,10 @@ worker_bll = WorkerBLL()
|
||||
def get_all(call: APICall, company_id: str, request: GetAllRequest):
|
||||
call.result.data_model = GetAllResponse(
|
||||
workers=worker_bll.get_all_with_projection(
|
||||
company_id, request.last_seen, tags=request.tags
|
||||
company_id,
|
||||
request.last_seen,
|
||||
tags=request.tags,
|
||||
system_tags=request.system_tags,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -53,6 +56,9 @@ def register(call: APICall, company_id, request: RegisterRequest):
|
||||
timeout = request.timeout
|
||||
queues = request.queues
|
||||
|
||||
if not timeout:
|
||||
timeout = config.get("apiserver.workers.default_timeout", 10 * 60)
|
||||
|
||||
if not timeout or timeout <= 0:
|
||||
raise bad_request.WorkerRegistrationFailed(
|
||||
"invalid timeout", timeout=timeout, worker=worker
|
||||
@@ -66,6 +72,7 @@ def register(call: APICall, company_id, request: RegisterRequest):
|
||||
queues=queues,
|
||||
timeout=timeout,
|
||||
tags=request.tags,
|
||||
system_tags=request.system_tags,
|
||||
)
|
||||
|
||||
|
||||
@@ -84,6 +91,7 @@ def status_report(call: APICall, company_id, request: StatusReportRequest):
|
||||
ip=call.real_ip,
|
||||
report=request,
|
||||
tags=request.tags,
|
||||
system_tags=request.system_tags,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -60,12 +60,12 @@ class TestService(TestCase, TestServiceInterface):
|
||||
def update_missing(target: dict, **update):
|
||||
target.update({k: v for k, v in update.items() if k not in target})
|
||||
|
||||
def create_temp(self, service, *, client=None, delete_params=None, **kwargs) -> str:
|
||||
def create_temp(self, service, *, client=None, delete_params=None, object_name="", **kwargs) -> str:
|
||||
return self._create_temp_helper(
|
||||
service=service,
|
||||
create_endpoint="create",
|
||||
delete_endpoint="delete",
|
||||
object_name=service.rstrip("s"),
|
||||
object_name=object_name or service.rstrip("s"),
|
||||
create_params=kwargs,
|
||||
client=client,
|
||||
delete_params=delete_params,
|
||||
|
||||
@@ -7,9 +7,6 @@ class TestBatchOperations(TestService):
|
||||
comment = "this is a comment"
|
||||
delete_params = dict(can_fail=True, force=True)
|
||||
|
||||
def setUp(self, version="2.13"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def test_tasks(self):
|
||||
tasks = [self._temp_task() for _ in range(2)]
|
||||
models = [
|
||||
|
||||
@@ -30,6 +30,11 @@ class TestMoveUnderProject(TestService):
|
||||
self.assertEqual(p2_name, projects[0].name)
|
||||
self.api.projects.delete(project=project2, force=True)
|
||||
|
||||
# move to the root project
|
||||
self.assertEqual(None, self.api.tasks.move(ids=[task], project=None).project_id)
|
||||
tasks = self.api.tasks.get_all_ex(id=[task]).tasks
|
||||
self.assertEqual(None, tasks[0].get("project"))
|
||||
|
||||
# model move into existing project referenced by name
|
||||
model = self._temp_model()
|
||||
self.api.models.move(ids=[model], project_name=self.entity_name)
|
||||
|
||||
50
apiserver/tests/automated/test_pipelines.py
Normal file
50
apiserver/tests/automated/test_pipelines.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import Tuple
|
||||
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestPipelines(TestService):
|
||||
def test_start_pipeline(self):
|
||||
queue = self.api.queues.get_default().id
|
||||
task_name = "pipelines test"
|
||||
project, task = self._temp_project_and_task(name=task_name)
|
||||
args = [{"name": "hello", "value": "test"}]
|
||||
|
||||
res = self.api.pipelines.start_pipeline(task=task, queue=queue, args=args)
|
||||
pipeline_task = res.pipeline
|
||||
try:
|
||||
self.assertTrue(res.enqueued)
|
||||
pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0]
|
||||
self.assertTrue(pipeline.name.startswith(task_name))
|
||||
self.assertEqual(pipeline.status, "queued")
|
||||
self.assertEqual(pipeline.project.id, project)
|
||||
self.assertEqual(
|
||||
pipeline.hyperparams.Args,
|
||||
{
|
||||
a["name"]: {
|
||||
"section": "Args",
|
||||
"name": a["name"],
|
||||
"value": a["value"],
|
||||
}
|
||||
for a in args
|
||||
},
|
||||
)
|
||||
finally:
|
||||
self.api.tasks.delete(task=pipeline_task, force=True)
|
||||
|
||||
def _temp_project_and_task(self, name) -> Tuple[str, str]:
|
||||
project = self.create_temp(
|
||||
"projects", name=name, description="test", delete_params=dict(force=True),
|
||||
)
|
||||
|
||||
return (
|
||||
project,
|
||||
self.create_temp(
|
||||
"tasks",
|
||||
name=name,
|
||||
type="testing",
|
||||
input=dict(view=dict()),
|
||||
project=project,
|
||||
system_tags=["pipeline"],
|
||||
),
|
||||
)
|
||||
@@ -21,7 +21,7 @@ class TestQueues(TestService):
|
||||
|
||||
def test_queue_metrics(self):
|
||||
queue_id = self._temp_queue("TestTempQueue")
|
||||
task1 = self._create_temp_queued_task("temp task 1", queue_id)
|
||||
self._create_temp_queued_task("temp task 1", queue_id)
|
||||
time.sleep(1)
|
||||
task2 = self._create_temp_queued_task("temp task 2", queue_id)
|
||||
self.api.queues.get_next_task(queue=queue_id)
|
||||
@@ -36,6 +36,27 @@ class TestQueues(TestService):
|
||||
)
|
||||
self.assertMetricQueues(res["queues"], queue_id)
|
||||
|
||||
def test_hidden_queues(self):
|
||||
hidden_name = "TestHiddenQueue"
|
||||
hidden_queue = self._temp_queue(hidden_name, system_tags=["k8s-glue"])
|
||||
non_hidden_queue = self._temp_queue("TestNonHiddenQueue")
|
||||
|
||||
queues = self.api.queues.get_all_ex().queues
|
||||
ids = {q.id for q in queues}
|
||||
self.assertFalse(hidden_queue in ids)
|
||||
self.assertTrue(non_hidden_queue in ids)
|
||||
|
||||
queues = self.api.queues.get_all_ex(search_hidden=True).queues
|
||||
ids = {q.id for q in queues}
|
||||
self.assertTrue(hidden_queue in ids)
|
||||
self.assertTrue(non_hidden_queue in ids)
|
||||
|
||||
queues = self.api.queues.get_all_ex(name=f"^{hidden_name}$").queues
|
||||
self.assertEqual(hidden_queue, queues[0].id)
|
||||
|
||||
queues = self.api.queues.get_all_ex(id=[hidden_queue]).queues
|
||||
self.assertEqual(hidden_queue, queues[0].id)
|
||||
|
||||
def test_reset_task(self):
|
||||
queue = self._temp_queue("TestTempQueue")
|
||||
task = self._temp_task("TempTask", is_development=True)
|
||||
@@ -60,6 +81,19 @@ class TestQueues(TestService):
|
||||
self.assertQueueTasks(res.queue, [task])
|
||||
self.assertTaskTags(task, system_tags=[])
|
||||
|
||||
def test_dequeue_from_deleted_queue(self):
|
||||
queue = self._temp_queue("TestTempQueue")
|
||||
task_name = "TempDevTask"
|
||||
task = self._temp_task(task_name)
|
||||
|
||||
self.api.tasks.enqueue(task=task, queue=queue)
|
||||
res = self.api.tasks.get_by_id(task=task)
|
||||
self.assertEqual(res.task.status, "queued")
|
||||
|
||||
self.api.queues.delete(queue=queue, force=True)
|
||||
res = self.api.tasks.get_by_id(task=task)
|
||||
self.assertEqual(res.task.status, "created")
|
||||
|
||||
def test_max_queue_entries(self):
|
||||
queue = self._temp_queue("TestTempQueue")
|
||||
tasks = [
|
||||
@@ -193,8 +227,8 @@ class TestQueues(TestService):
|
||||
sorted(queue.workers, key=sort_key), sorted(workers, key=sort_key)
|
||||
)
|
||||
|
||||
def _temp_queue(self, queue_name, tags=None):
|
||||
return self.create_temp("queues", name=queue_name, tags=tags)
|
||||
def _temp_queue(self, queue_name, **kwargs):
|
||||
return self.create_temp("queues", name=queue_name, **kwargs)
|
||||
|
||||
def _temp_task(self, task_name, is_testing=False, is_development=False):
|
||||
task_input = dict(
|
||||
|
||||
208
apiserver/tests/automated/test_reports.py
Normal file
208
apiserver/tests/automated/test_reports.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import re
|
||||
|
||||
from boltons.iterutils import first
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.tests.automated import TestService
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
|
||||
class TestReports(TestService):
|
||||
def _delete_project(self, name):
|
||||
existing_project = first(
|
||||
self.api.projects.get_all_ex(
|
||||
name=f"^{re.escape(name)}$", search_hidden=True
|
||||
).projects
|
||||
)
|
||||
if existing_project:
|
||||
self.api.projects.delete(
|
||||
project=existing_project.id, force=True, delete_contents=True
|
||||
)
|
||||
|
||||
def test_create_update_move(self):
|
||||
task_name = "Rep1"
|
||||
comment = "My report"
|
||||
tags = ["hello"]
|
||||
|
||||
# report creates a hidden task under hidden .reports subproject
|
||||
self._delete_project(".reports")
|
||||
task_id = self._temp_report(name=task_name, comment=comment, tags=tags)
|
||||
task = self.api.tasks.get_all_ex(id=[task_id]).tasks[0]
|
||||
self.assertEqual(task.name, task_name)
|
||||
self.assertEqual(task.comment, comment)
|
||||
self.assertEqual(set(task.tags), set(tags))
|
||||
self.assertEqual(task.type, "report")
|
||||
self.assertEqual(set(task.system_tags), {"hidden", "reports"})
|
||||
projects = self.api.projects.get_all_ex(name=r"^\.reports$").projects
|
||||
self.assertEqual(len(projects), 0)
|
||||
project = self.api.projects.get_all_ex(
|
||||
name=r"^\.reports$", search_hidden=True
|
||||
).projects[0]
|
||||
self.assertEqual(project.id, task.project.id)
|
||||
self.assertEqual(set(project.system_tags), {"hidden", "reports"})
|
||||
ret = self.api.reports.get_tags()
|
||||
self.assertEqual(ret.tags, sorted(tags))
|
||||
|
||||
# update is working on draft reports
|
||||
new_comment = "My new comment"
|
||||
res = self.api.reports.update(
|
||||
task=task_id,
|
||||
comment=new_comment,
|
||||
tags=[],
|
||||
report_assets=["file://test.jpg"],
|
||||
)
|
||||
self.assertEqual(res.updated, 1)
|
||||
task = self.api.tasks.get_all_ex(id=[task_id]).tasks[0]
|
||||
self.assertEqual(task.name, task_name)
|
||||
self.assertEqual(task.comment, new_comment)
|
||||
self.assertEqual(task.tags, [])
|
||||
ret = self.api.reports.get_tags()
|
||||
self.assertEqual(ret.tags, [])
|
||||
self.assertEqual(task.report_assets, ["file://test.jpg"])
|
||||
self.api.reports.publish(task=task_id)
|
||||
with self.api.raises(errors.bad_request.InvalidTaskStatus):
|
||||
self.api.reports.update(task=task_id, report="New report text")
|
||||
|
||||
# update on tags or rename can be done for published report too
|
||||
self.api.reports.update(
|
||||
task=task_id, name="new name", tags=["test"], comment="Yet another comment"
|
||||
)
|
||||
task = self.api.tasks.get_all_ex(id=[task_id]).tasks[0]
|
||||
self.assertEqual(task.tags, ["test"])
|
||||
self.assertEqual(task.name, "new name")
|
||||
self.assertEqual(task.comment, "Yet another comment")
|
||||
|
||||
# move under another project autodeletes the empty project
|
||||
new_project_name = "Reports Test"
|
||||
self._delete_project(new_project_name)
|
||||
task2_id = self._temp_report(name="Rep2")
|
||||
new_project_id = self.api.reports.move(
|
||||
task=task_id, project_name=new_project_name
|
||||
).project_id
|
||||
new_project = self.api.projects.get_all_ex(id=[new_project_id]).projects[0]
|
||||
self.assertEqual(new_project.name, f"{new_project_name}/.reports")
|
||||
self.assertEqual(set(new_project.system_tags), {"hidden", "reports"})
|
||||
self.assertEqual(len(self.api.projects.get_all_ex(id=project.id).projects), 1)
|
||||
self.api.reports.move(task=task2_id, project=new_project_id)
|
||||
self.assertEqual(len(self.api.projects.get_all_ex(id=project.id).projects), 0)
|
||||
tasks = self.api.tasks.get_all_ex(
|
||||
project=new_project_id, search_hidden=True
|
||||
).tasks
|
||||
self.assertTrue({task_id, task2_id}.issubset({t.id for t in tasks}))
|
||||
|
||||
project_id = self.api.reports.move(task=task2_id, project=None).project_id
|
||||
project = self.api.projects.get_all_ex(id=[project_id]).projects[0]
|
||||
self.assertEqual(project.get("parent"), None)
|
||||
self.assertEqual(project.name, ".reports")
|
||||
|
||||
def test_reports_search(self):
|
||||
report_task = self._temp_report(name="Rep1")
|
||||
non_report_task = self._temp_task(name="hello")
|
||||
res = self.api.reports.get_all_ex(
|
||||
_any_={"pattern": "hello", "fields": ["name", "id", "tags", "report"]}
|
||||
).tasks
|
||||
self.assertEqual(len(res), 0)
|
||||
|
||||
self.api.reports.update(task=report_task, report="hello world")
|
||||
res = self.api.reports.get_all_ex(
|
||||
_any_={"pattern": "hello", "fields": ["name", "id", "tags", "report"]}
|
||||
).tasks
|
||||
self.assertEqual(len(res), 1)
|
||||
self.assertEqual(res[0].id, report_task)
|
||||
|
||||
def test_reports_task_data(self):
|
||||
report_task = self._temp_report(name="Rep1")
|
||||
non_report_task = self._temp_task(name="hello")
|
||||
debug_image_events = [
|
||||
self._create_task_event(
|
||||
task=non_report_task,
|
||||
type_="training_debug_image",
|
||||
iteration=1,
|
||||
metric=f"Metric_{m}",
|
||||
variant=f"Variant_{v}",
|
||||
url=f"{m}_{v}",
|
||||
)
|
||||
for m in range(2)
|
||||
for v in range(2)
|
||||
]
|
||||
plot_events = [
|
||||
self._create_task_event(
|
||||
task=non_report_task,
|
||||
type_="plot",
|
||||
iteration=1,
|
||||
metric=f"Metric_{m}",
|
||||
variant=f"Variant_{v}",
|
||||
plot_str=f"Hello plot",
|
||||
)
|
||||
for m in range(2)
|
||||
for v in range(2)
|
||||
]
|
||||
self.send_batch([*debug_image_events, *plot_events])
|
||||
|
||||
res = self.api.reports.get_task_data(
|
||||
id=[non_report_task], only_fields=["name"],
|
||||
)
|
||||
self.assertEqual(len(res.tasks), 1)
|
||||
self.assertEqual(res.tasks[0].id, non_report_task)
|
||||
self.assertFalse(any(field in res for field in ("plots", "debug_images")))
|
||||
|
||||
res = self.api.reports.get_task_data(
|
||||
id=[non_report_task],
|
||||
only_fields=["name"],
|
||||
debug_images={"metrics": []},
|
||||
plots={"metrics": [{"metric": "Metric_1"}]},
|
||||
)
|
||||
self.assertEqual(len(res.debug_images), 1)
|
||||
task_events = res.debug_images[0]
|
||||
self.assertEqual(task_events.task, non_report_task)
|
||||
self.assertEqual(len(task_events.iterations), 1)
|
||||
self.assertEqual(len(task_events.iterations[0].events), 4)
|
||||
|
||||
self.assertEqual(len(res.plots), 1)
|
||||
for m, v in (("Metric_1", "Variant_0"), ("Metric_1", "Variant_1")):
|
||||
tasks = nested_get(res.plots, (m, v))
|
||||
self.assertEqual(len(tasks), 1)
|
||||
task_plots = tasks[non_report_task]
|
||||
self.assertEqual(len(task_plots), 1)
|
||||
iter_plots = task_plots["1"]
|
||||
self.assertEqual(iter_plots.name, "hello")
|
||||
self.assertEqual(len(iter_plots.plots), 1)
|
||||
ev = iter_plots.plots[0]
|
||||
self.assertEqual(ev["metric"], m)
|
||||
self.assertEqual(ev["variant"], v)
|
||||
self.assertEqual(ev["task"], non_report_task)
|
||||
self.assertEqual(ev["iter"], 1)
|
||||
|
||||
@staticmethod
|
||||
def _create_task_event(type_, task, iteration, **kwargs):
|
||||
return {
|
||||
"worker": "test",
|
||||
"type": type_,
|
||||
"task": task,
|
||||
"iter": iteration,
|
||||
"timestamp": kwargs.get("timestamp") or es_factory.get_timestamp_millis(),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
def _temp_report(self, name, **kwargs):
|
||||
return self.create_temp(
|
||||
"reports",
|
||||
name=name,
|
||||
object_name="task",
|
||||
delete_params={"force": True},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _temp_task(self, name, **kwargs):
|
||||
return self.create_temp(
|
||||
"tasks",
|
||||
name=name,
|
||||
type="training",
|
||||
delete_params={"force": True},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def send_batch(self, events):
|
||||
_, data = self.api.send_batch("events.add_batch", events)
|
||||
return data
|
||||
@@ -14,17 +14,91 @@ from apiserver.tests.automated import TestService
|
||||
class TestSubProjects(TestService):
|
||||
def test_dataset_stats(self):
|
||||
project = self._temp_project(name="Dataset test", system_tags=["dataset"])
|
||||
res = self.api.organization.get_entities_count(datasets={"system_tags": ["dataset"]})
|
||||
res = self.api.organization.get_entities_count(
|
||||
datasets={"system_tags": ["dataset"]}
|
||||
)
|
||||
self.assertEqual(res.datasets, 1)
|
||||
|
||||
task = self._temp_task(project=project)
|
||||
data = self.api.projects.get_all_ex(id=[project], include_dataset_stats=True).projects[0]
|
||||
data = self.api.projects.get_all_ex(
|
||||
id=[project], include_dataset_stats=True
|
||||
).projects[0]
|
||||
self.assertIsNone(data.dataset_stats)
|
||||
|
||||
self.api.tasks.edit(task=task, runtime={"ds_file_count": 2, "ds_total_size": 1000})
|
||||
data = self.api.projects.get_all_ex(id=[project], include_dataset_stats=True).projects[0]
|
||||
self.api.tasks.edit(
|
||||
task=task, runtime={"ds_file_count": 2, "ds_total_size": 1000}
|
||||
)
|
||||
data = self.api.projects.get_all_ex(
|
||||
id=[project], include_dataset_stats=True
|
||||
).projects[0]
|
||||
self.assertEqual(data.dataset_stats, {"file_count": 2, "total_size": 1000})
|
||||
|
||||
def test_query_children(self):
|
||||
test_root_name = "TestQueryChildren"
|
||||
test_root = self._temp_project(name=test_root_name)
|
||||
dataset_tags = ["hello", "world"]
|
||||
dataset_project = self._temp_project(
|
||||
name=f"{test_root_name}/Project1/Dataset",
|
||||
system_tags=["dataset"],
|
||||
tags=dataset_tags,
|
||||
)
|
||||
self._temp_task(
|
||||
name="dataset task",
|
||||
type="data_processing",
|
||||
system_tags=["dataset"],
|
||||
project=dataset_project,
|
||||
)
|
||||
self._temp_task(name="regular task", project=dataset_project)
|
||||
pipeline_project = self._temp_project(
|
||||
name=f"{test_root_name}/Project2/Pipeline", system_tags=["pipeline"]
|
||||
)
|
||||
self._temp_task(
|
||||
name="pipeline task",
|
||||
type="controller",
|
||||
system_tags=["pipeline"],
|
||||
project=pipeline_project,
|
||||
)
|
||||
self._temp_task(name="regular task", project=pipeline_project)
|
||||
report_project = self._temp_project(name=f"{test_root_name}/Project3")
|
||||
self._temp_report(name="test report", project=report_project)
|
||||
self._temp_task(name="regular task", project=report_project)
|
||||
|
||||
projects = self.api.projects.get_all_ex(
|
||||
parent=[test_root], shallow_search=True, include_stats=True
|
||||
).projects
|
||||
self.assertEqual(
|
||||
{p.basename for p in projects}, {f"Project{idx+1}" for idx in range(3)}
|
||||
)
|
||||
for p in projects:
|
||||
self.assertEqual(
|
||||
p.stats.active.total_tasks,
|
||||
2
|
||||
if p.basename in ("Project1", "Project2")
|
||||
else 1
|
||||
)
|
||||
|
||||
for i, type_ in enumerate(("dataset", "pipeline", "report")):
|
||||
projects = self.api.projects.get_all_ex(
|
||||
parent=[test_root],
|
||||
children_type=type_,
|
||||
shallow_search=True,
|
||||
include_stats=True,
|
||||
check_own_contents=True,
|
||||
).projects
|
||||
self.assertEqual({p.basename for p in projects}, {f"Project{i+1}"})
|
||||
p = projects[0]
|
||||
if type_ in ("dataset",):
|
||||
self.assertEqual(p.own_datasets, 1)
|
||||
self.assertIsNone(p.get("own_tasks"))
|
||||
self.assertEqual(p.stats.datasets.count, 1)
|
||||
self.assertEqual(p.stats.datasets.tags, dataset_tags)
|
||||
else:
|
||||
self.assertEqual(p.own_tasks, 0)
|
||||
self.assertIsNone(p.get("own_datasets"))
|
||||
self.assertEqual(
|
||||
p.stats.active.total_tasks, 1 if p.basename != "Project4" else 0
|
||||
)
|
||||
|
||||
def test_project_aggregations(self):
|
||||
"""This test requires user with user_auth_only... credentials in db"""
|
||||
user2_client = APIClient(
|
||||
@@ -52,6 +126,10 @@ class TestSubProjects(TestService):
|
||||
self.assertEqual(res.types, [])
|
||||
res = self.api.projects.get_task_parents(projects=[project])
|
||||
self.assertEqual(res.parents, [])
|
||||
res = self.api.organization.get_entities_count(
|
||||
projects={"id": [project]}, active_users=[user]
|
||||
)
|
||||
self.assertEqual(res.projects, 0)
|
||||
|
||||
# test aggregations with non-empty subprojects
|
||||
task1 = self._temp_task(project=child)
|
||||
@@ -64,7 +142,9 @@ class TestSubProjects(TestService):
|
||||
res = self.api.projects.get_all_ex(id=[project], include_stats=True)
|
||||
self._assert_ids(res.projects, [project])
|
||||
self.assertEqual(res.projects[0].stats.active.total_tasks, 3)
|
||||
res = self.api.projects.get_all_ex(id=[project], active_users=[user], include_stats=True)
|
||||
res = self.api.projects.get_all_ex(
|
||||
id=[project], active_users=[user], include_stats=True
|
||||
)
|
||||
self._assert_ids(res.projects, [project])
|
||||
self.assertEqual(res.projects[0].stats.active.total_tasks, 2)
|
||||
res = self.api.projects.get_task_parents(projects=[project])
|
||||
@@ -73,6 +153,10 @@ class TestSubProjects(TestService):
|
||||
self.assertEqual(res.frameworks, [framework])
|
||||
res = self.api.tasks.get_types(projects=[project])
|
||||
self.assertEqual(res.types, ["testing"])
|
||||
res = self.api.organization.get_entities_count(
|
||||
projects={"id": [project]}, active_users=[user]
|
||||
)
|
||||
self.assertEqual(res.projects, 1)
|
||||
|
||||
def _assert_ids(self, actual: Sequence[dict], expected: Sequence[str]):
|
||||
self.assertEqual([a["id"] for a in actual], expected)
|
||||
@@ -280,12 +364,21 @@ class TestSubProjects(TestService):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _temp_task(self, client=None, **kwargs):
|
||||
def _temp_report(self, name, **kwargs):
|
||||
return self.create_temp(
|
||||
"reports",
|
||||
name=name,
|
||||
object_name="task",
|
||||
delete_params=self.delete_params,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _temp_task(self, client=None, name=None, type=None, **kwargs):
|
||||
return self.create_temp(
|
||||
"tasks",
|
||||
delete_params=self.delete_params,
|
||||
type="testing",
|
||||
name=db_id(),
|
||||
type=type or "testing",
|
||||
name=name or db_id(),
|
||||
input=dict(view=dict()),
|
||||
client=client,
|
||||
**kwargs,
|
||||
|
||||
@@ -6,17 +6,26 @@ from typing import Sequence, Optional, Tuple
|
||||
|
||||
from boltons.iterutils import first
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors.errors.bad_request import EventsNotAdded
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestTaskEvents(TestService):
|
||||
delete_params = dict(can_fail=True, force=True)
|
||||
|
||||
def _temp_task(self, name="test task events"):
|
||||
task_input = dict(
|
||||
name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
|
||||
)
|
||||
return self.create_temp("tasks", **task_input)
|
||||
return self.create_temp(
|
||||
"tasks", delete_paramse=self.delete_params, **task_input
|
||||
)
|
||||
|
||||
def _temp_model(self, name="test model events", **kwargs):
|
||||
self.update_missing(kwargs, name=name, uri="file:///a/b", labels={})
|
||||
return self.create_temp("models", delete_params=self.delete_params, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _create_task_event(type_, task, iteration, **kwargs):
|
||||
@@ -97,9 +106,7 @@ class TestTaskEvents(TestService):
|
||||
res.variants[variant]["iter"],
|
||||
[x or special_iteration for x in range(iter_count)],
|
||||
)
|
||||
self.assertEqual(
|
||||
res.variants[variant]["y"], list(range(iter_count))
|
||||
)
|
||||
self.assertEqual(res.variants[variant]["y"], list(range(iter_count)))
|
||||
|
||||
# but not in the histogram
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=task)
|
||||
@@ -133,8 +140,7 @@ class TestTaskEvents(TestService):
|
||||
task=task, batch_size=100, metric=metric_param, count_total=True
|
||||
)
|
||||
self.assertEqual(
|
||||
res.variants[variant]["y"],
|
||||
[y or new_value for y in range(iter_count)],
|
||||
res.variants[variant]["y"], [y or new_value for y in range(iter_count)],
|
||||
)
|
||||
|
||||
task_data = self.api.tasks.get_by_id(task=task).task
|
||||
@@ -170,7 +176,49 @@ class TestTaskEvents(TestService):
|
||||
metric_data = first(first(task_data.last_metrics.values()).values())
|
||||
self.assertEqual(iter_count - 1, metric_data.value)
|
||||
self.assertEqual(iter_count - 1, metric_data.max_value)
|
||||
self.assertEqual(iter_count - 1, metric_data.max_value_iteration)
|
||||
self.assertEqual(0, metric_data.min_value)
|
||||
self.assertEqual(0, metric_data.min_value_iteration)
|
||||
|
||||
res = self.api.events.get_task_latest_scalar_values(task=task)
|
||||
self.assertEqual(iter_count - 1, res.last_iter)
|
||||
|
||||
def test_model_events(self):
|
||||
model = self._temp_model(ready=False)
|
||||
|
||||
# task log events are not allowed
|
||||
log_event = self._create_task_event(
|
||||
"log",
|
||||
task=model,
|
||||
iteration=0,
|
||||
msg=f"This is a log message",
|
||||
model_event=True,
|
||||
)
|
||||
with self.api.raises(errors.bad_request.EventsNotAdded):
|
||||
self.send(log_event)
|
||||
|
||||
events = [
|
||||
{
|
||||
**self._create_task_event("training_stats_scalar", model, iteration),
|
||||
"metric": f"Metric{metric_idx}",
|
||||
"variant": f"Variant{variant_idx}",
|
||||
"value": iteration,
|
||||
"model_event": True,
|
||||
}
|
||||
for iteration in range(2)
|
||||
for metric_idx in range(5)
|
||||
for variant_idx in range(5)
|
||||
]
|
||||
self.send_batch(events)
|
||||
data = self.api.events.scalar_metrics_iter_histogram(
|
||||
task=model, model_events=True
|
||||
)
|
||||
self.assertEqual(list(data), [f"Metric{idx}" for idx in range(5)])
|
||||
metric_data = data.Metric0
|
||||
self.assertEqual(list(metric_data), [f"Variant{idx}" for idx in range(5)])
|
||||
variant_data = metric_data.Variant0
|
||||
self.assertEqual(variant_data.x, [0, 1])
|
||||
self.assertEqual(variant_data.y, [0.0, 1.0])
|
||||
|
||||
def test_error_events(self):
|
||||
task = self._temp_task()
|
||||
@@ -555,7 +603,8 @@ class TestTaskEvents(TestService):
|
||||
return data
|
||||
|
||||
def send(self, event):
|
||||
self.api.send("events.add", event)
|
||||
_, data = self.api.send("events.add", event)
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -26,57 +26,55 @@ class TestTaskPlots(TestService):
|
||||
def test_get_plot_sample(self):
|
||||
task = self._temp_task()
|
||||
metric = "Metric1"
|
||||
variant = "Variant1"
|
||||
variants = ["Variant1", "Variant2"]
|
||||
|
||||
# test empty
|
||||
res = self.api.events.get_plot_sample(
|
||||
task=task, metric=metric, variant=variant
|
||||
)
|
||||
res = self.api.events.get_plot_sample(task=task, metric=metric)
|
||||
self.assertEqual(res.min_iteration, None)
|
||||
self.assertEqual(res.max_iteration, None)
|
||||
self.assertEqual(res.event, None)
|
||||
self.assertEqual(res.events, [])
|
||||
|
||||
# test existing events
|
||||
iterations = 10
|
||||
iterations = 5
|
||||
events = [
|
||||
self._create_task_event(
|
||||
task=task,
|
||||
iteration=n,
|
||||
iteration=n // len(variants),
|
||||
metric=metric,
|
||||
variant=variant,
|
||||
variant=variants[n % len(variants)],
|
||||
plot_str=f"Test plot str {n}",
|
||||
)
|
||||
for n in range(iterations)
|
||||
for n in range(iterations * len(variants))
|
||||
]
|
||||
self.send_batch(events)
|
||||
|
||||
# if iteration is not specified then return the event from the last one
|
||||
res = self.api.events.get_plot_sample(
|
||||
task=task, metric=metric, variant=variant
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-1])
|
||||
res = self.api.events.get_plot_sample(task=task, metric=metric)
|
||||
self._assertEqualEvents(res.events, events[-len(variants) :])
|
||||
self.assertEqual(res.max_iteration, iterations - 1)
|
||||
self.assertEqual(res.min_iteration, 0)
|
||||
self.assertTrue(res.scroll_id)
|
||||
|
||||
# else from the specific iteration
|
||||
iteration = 8
|
||||
iteration = 3
|
||||
res = self.api.events.get_plot_sample(
|
||||
task=task,
|
||||
metric=metric,
|
||||
variant=variant,
|
||||
iteration=iteration,
|
||||
scroll_id=res.scroll_id,
|
||||
task=task, metric=metric, iteration=iteration, scroll_id=res.scroll_id,
|
||||
)
|
||||
self._assertEqualEvents(
|
||||
res.events,
|
||||
events[iteration * len(variants) : (iteration + 1) * len(variants)],
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[iteration])
|
||||
|
||||
def test_next_plot_sample(self):
|
||||
task = self._temp_task()
|
||||
metric1 = "Metric1"
|
||||
variant1 = "Variant1"
|
||||
metric2 = "Metric2"
|
||||
variant2 = "Variant2"
|
||||
metrics = [(metric1, variant1), (metric2, variant2)]
|
||||
metrics = [
|
||||
(metric1, "variant1"),
|
||||
(metric1, "variant2"),
|
||||
(metric2, "variant3"),
|
||||
(metric2, "variant4"),
|
||||
]
|
||||
# test existing events
|
||||
events = [
|
||||
self._create_task_event(
|
||||
@@ -93,57 +91,72 @@ class TestTaskPlots(TestService):
|
||||
|
||||
# single metric navigation
|
||||
# init scroll
|
||||
res = self.api.events.get_plot_sample(
|
||||
task=task, metric=metric1, variant=variant1
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-2])
|
||||
res = self.api.events.get_plot_sample(task=task, metric=metric1)
|
||||
self._assertEqualEvents(res.events, events[-4:-2])
|
||||
|
||||
# navigate forwards
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id, navigate_earlier=False
|
||||
)
|
||||
self.assertEqual(res.event, None)
|
||||
self.assertEqual(res.events, [])
|
||||
|
||||
# navigate backwards
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-4])
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id
|
||||
)
|
||||
self._assertEqualEvent(res.event, None)
|
||||
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
|
||||
self._assertEqualEvents(res.events, events[-8:-6])
|
||||
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
|
||||
self._assertEqualEvents(res.events, [])
|
||||
|
||||
# all metrics navigation
|
||||
# init scroll
|
||||
res = self.api.events.get_plot_sample(
|
||||
task=task, metric=metric1, variant=variant1, navigate_current_metric=False
|
||||
task=task, metric=metric1, navigate_current_metric=False
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-2])
|
||||
self._assertEqualEvents(res.events, events[-4:-2])
|
||||
|
||||
# navigate forwards
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id, navigate_earlier=False
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-1])
|
||||
self._assertEqualEvents(res.events, events[-2:])
|
||||
|
||||
# navigate backwards
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-2])
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-3])
|
||||
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
|
||||
self._assertEqualEvents(res.events, events[-4:-2])
|
||||
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
|
||||
self._assertEqualEvents(res.events, events[-6:-4])
|
||||
|
||||
def _assertEqualEvent(self, ev1: dict, ev2: Optional[dict]):
|
||||
if ev2 is None:
|
||||
self.assertIsNone(ev1)
|
||||
return
|
||||
self.assertIsNotNone(ev1)
|
||||
for field in ("iter", "timestamp", "metric", "variant", "plot_str", "task"):
|
||||
self.assertEqual(ev1[field], ev2[field])
|
||||
# next_iteration
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id, next_iteration=True
|
||||
)
|
||||
self._assertEqualEvents(res.events, [])
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task,
|
||||
scroll_id=res.scroll_id,
|
||||
next_iteration=True,
|
||||
navigate_earlier=False,
|
||||
)
|
||||
self._assertEqualEvents(res.events, events[-4:-2])
|
||||
self.assertTrue(all(ev.iter == 1 for ev in res.events))
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task,
|
||||
scroll_id=res.scroll_id,
|
||||
next_iteration=True,
|
||||
navigate_earlier=False,
|
||||
)
|
||||
self._assertEqualEvents(res.events, [])
|
||||
|
||||
def _assertEqualEvents(
|
||||
self, ev_source: Sequence[dict], ev_target: Sequence[Optional[dict]]
|
||||
):
|
||||
self.assertEqual(len(ev_source), len(ev_target))
|
||||
|
||||
def compare_event(ev1, ev2):
|
||||
for field in ("iter", "timestamp", "metric", "variant", "plot_str", "task"):
|
||||
self.assertEqual(ev1[field], ev2[field])
|
||||
|
||||
for e1, e2 in zip(ev_source, ev_target):
|
||||
compare_event(e1, e2)
|
||||
|
||||
def test_task_plots(self):
|
||||
task = self._temp_task()
|
||||
@@ -223,12 +236,15 @@ class TestTaskPlots(TestService):
|
||||
self.assertTrue(all(m.iterations == [] for m in res.metrics))
|
||||
return res.scroll_id
|
||||
|
||||
expected_variants = set((m, var) for m, vars_ in expected_metrics.items() for var in vars_)
|
||||
expected_variants = set(
|
||||
(m, var) for m, vars_ in expected_metrics.items() for var in vars_
|
||||
)
|
||||
for metric_data in res.metrics:
|
||||
self.assertEqual(len(metric_data.iterations), iterations)
|
||||
for it_data in metric_data.iterations:
|
||||
self.assertEqual(
|
||||
set((e.metric, e.variant) for e in it_data.events), expected_variants
|
||||
set((e.metric, e.variant) for e in it_data.events),
|
||||
expected_variants,
|
||||
)
|
||||
|
||||
return res.scroll_id
|
||||
@@ -266,7 +282,7 @@ class TestTaskPlots(TestService):
|
||||
task=task,
|
||||
metric=metric,
|
||||
iterations=iterations,
|
||||
variants=len(variants)
|
||||
variants=len(variants),
|
||||
)
|
||||
|
||||
# test forward navigation
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user