mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
38 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
702b6dc9c8 | ||
|
|
db15f235e4 | ||
|
|
8c347f8fa9 | ||
|
|
768c3d80ff | ||
|
|
a5c3ef6385 | ||
|
|
11b7a384af | ||
|
|
9a70ade4a6 | ||
|
|
91ce140901 | ||
|
|
49084a9c49 | ||
|
|
8a99eb6812 | ||
|
|
811ab2bf4f | ||
|
|
3752db122b | ||
|
|
439911b84c | ||
|
|
262a301e28 | ||
|
|
a604451b01 | ||
|
|
88a7773621 | ||
|
|
35c4061992 | ||
|
|
4684fd5b74 | ||
|
|
e08123fcc0 | ||
|
|
e713e876eb | ||
|
|
c2cc788319 | ||
|
|
da8315d0db | ||
|
|
4ac6f88278 | ||
|
|
a7865ccbec | ||
|
|
ec14f327c6 | ||
|
|
a03b24d6b6 | ||
|
|
cb71ef8e47 | ||
|
|
8678fbc995 | ||
|
|
58df8f201a | ||
|
|
f4bf16c156 | ||
|
|
942f996237 | ||
|
|
c1e7f8f9c1 | ||
|
|
274c487b37 | ||
|
|
cc0129a800 | ||
|
|
388dd1b01f | ||
|
|
d62ecb5e6e | ||
|
|
6d507616b3 | ||
|
|
d0252a6dd9 |
@@ -41,6 +41,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
)
|
||||
],
|
||||
)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
@@ -148,18 +149,23 @@ class MultiTasksRequestBase(Base):
|
||||
|
||||
|
||||
class SingleValueMetricsRequest(MultiTasksRequestBase):
|
||||
pass
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
|
||||
class TaskMetricsRequest(MultiTasksRequestBase):
|
||||
event_type: EventType = ActualEnumField(EventType, required=True)
|
||||
|
||||
|
||||
class MultiTaskMetricsRequest(MultiTasksRequestBase):
|
||||
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
|
||||
|
||||
|
||||
class MultiTaskPlotsRequest(MultiTasksRequestBase):
|
||||
iters: int = IntField(default=1)
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
last_iters_per_task_metric: bool = BoolField(default=True)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
|
||||
class TaskPlotsRequest(Base):
|
||||
|
||||
@@ -5,8 +5,9 @@ from apiserver.apimodels import DictField, callable_default
|
||||
|
||||
|
||||
class GetSupportedModesRequest(Base):
|
||||
state = StringField(help_text="ASCII base64 encoded application state")
|
||||
callback_url_prefix = StringField()
|
||||
pass
|
||||
# state = StringField(help_text="ASCII base64 encoded application state")
|
||||
# callback_url_prefix = StringField()
|
||||
|
||||
|
||||
class BasicGuestMode(Base):
|
||||
|
||||
@@ -18,8 +18,4 @@ class StartPipelineRequest(models.Base):
|
||||
task = fields.StringField(required=True)
|
||||
queue = fields.StringField(required=True)
|
||||
args = ListField(Arg)
|
||||
|
||||
|
||||
class StartPipelineResponse(models.Base):
|
||||
pipeline = fields.StringField(required=True)
|
||||
enqueued = fields.BoolField(required=True)
|
||||
verify_watched_queue = fields.BoolField(default=False)
|
||||
|
||||
@@ -33,6 +33,7 @@ class ProjectOrNoneRequest(models.Base):
|
||||
|
||||
class GetUniqueMetricsRequest(ProjectOrNoneRequest):
|
||||
model_metrics = fields.BoolField(default=False)
|
||||
ids = fields.ListField(str)
|
||||
|
||||
|
||||
class GetParamsRequest(ProjectOrNoneRequest):
|
||||
@@ -72,6 +73,7 @@ class MultiProjectPagedRequest(MultiProjectRequest):
|
||||
class ProjectHyperparamValuesRequest(MultiProjectPagedRequest):
|
||||
section = fields.StringField(required=True)
|
||||
name = fields.StringField(required=True)
|
||||
pattern = fields.StringField()
|
||||
|
||||
|
||||
class ProjectModelMetadataValuesRequest(MultiProjectPagedRequest):
|
||||
@@ -98,3 +100,4 @@ class ProjectsGetRequest(models.Base):
|
||||
allow_public = fields.BoolField(default=True)
|
||||
children_type = ActualEnumField(ProjectChildrenType)
|
||||
children_tags = fields.ListField(str)
|
||||
children_tags_filter = DictField()
|
||||
|
||||
@@ -333,3 +333,8 @@ class DeleteModelsRequest(TaskRequest):
|
||||
class GetAllReq(models.Base):
|
||||
allow_public = BoolField(default=True)
|
||||
search_hidden = BoolField(default=False)
|
||||
|
||||
|
||||
class UpdateTagsRequest(BatchRequest):
|
||||
add_tags = ListField([str])
|
||||
remove_tags = ListField([str])
|
||||
|
||||
@@ -13,8 +13,7 @@ from jsonmodels.fields import (
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField, EnumField, JsonSerializableMixin
|
||||
|
||||
DEFAULT_TIMEOUT = 10 * 60
|
||||
from apiserver.config_repo import config
|
||||
|
||||
|
||||
class WorkerRequest(Base):
|
||||
@@ -24,7 +23,10 @@ class WorkerRequest(Base):
|
||||
|
||||
|
||||
class RegisterRequest(WorkerRequest):
|
||||
timeout = IntField(default=0) # registration timeout in seconds (if not specified, default is 10min)
|
||||
timeout = IntField(
|
||||
default=int(config.get("services.workers.default_worker_timeout_sec", 10 * 60))
|
||||
)
|
||||
""" registration timeout in seconds (default is 10min) """
|
||||
queues = ListField(six.string_types) # list of queues this worker listens to
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import zlib
|
||||
from collections import defaultdict
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
|
||||
|
||||
import elasticsearch
|
||||
@@ -24,6 +23,7 @@ from apiserver.bll.event.event_common import (
|
||||
get_metric_variants_condition,
|
||||
uncompress_plot,
|
||||
get_max_metric_and_variant_counts,
|
||||
PlotFields,
|
||||
)
|
||||
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
|
||||
from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIterator
|
||||
@@ -31,6 +31,7 @@ from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
|
||||
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
|
||||
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.task.utils import get_many_tasks_for_writing
|
||||
from apiserver.bll.util import parallel_chunked_decorator
|
||||
from apiserver.database import utils as dbutils
|
||||
from apiserver.database.model.model import Model
|
||||
@@ -42,26 +43,23 @@ from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.utilities.json import loads
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
|
||||
EVENT_TYPES: Set[str] = set(et.value for et in EventType if et != EventType.all)
|
||||
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
|
||||
MAX_LONG = 2 ** 63 - 1
|
||||
MIN_LONG = -(2 ** 63)
|
||||
MAX_LONG = 2**63 - 1
|
||||
MIN_LONG = -(2**63)
|
||||
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class PlotFields:
|
||||
valid_plot = "valid_plot"
|
||||
plot_len = "plot_len"
|
||||
plot_str = "plot_str"
|
||||
plot_data = "plot_data"
|
||||
source_urls = "source_urls"
|
||||
async_task_events_delete = config.get("services.tasks.async_events_delete", False)
|
||||
async_delete_threshold = config.get(
|
||||
"services.tasks.async_events_delete_threshold", 100_000
|
||||
)
|
||||
|
||||
|
||||
class EventBLL(object):
|
||||
@@ -103,7 +101,9 @@ class EventBLL(object):
|
||||
return self._metrics
|
||||
|
||||
@staticmethod
|
||||
def _get_valid_entities(company_id, ids: Mapping[str, bool], model=False) -> Set:
|
||||
def _get_valid_entities(
|
||||
company_id, ids: Mapping[str, bool], identity: Identity, model=False
|
||||
) -> Set:
|
||||
"""Verify that task or model exists and can be updated"""
|
||||
if not ids:
|
||||
return set()
|
||||
@@ -122,16 +122,34 @@ class EventBLL(object):
|
||||
):
|
||||
if not requested_ids:
|
||||
continue
|
||||
query = Q(id__in=requested_ids, company=company_id)
|
||||
res.update(
|
||||
(Model if model else Task).objects(query & locked_q).scalar("id")
|
||||
)
|
||||
|
||||
query = Q(id__in=requested_ids) & locked_q
|
||||
if model:
|
||||
ids = Model.objects(query & Q(company=company_id)).scalar("id")
|
||||
else:
|
||||
ids = {
|
||||
t.id
|
||||
for t in get_many_tasks_for_writing(
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
query=query,
|
||||
only=("id",),
|
||||
throw_on_forbidden=False,
|
||||
)
|
||||
}
|
||||
|
||||
res.update(ids)
|
||||
|
||||
return res
|
||||
|
||||
def add_events(
|
||||
self, company_id, events, worker
|
||||
self,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
events: Sequence[dict],
|
||||
worker: str,
|
||||
) -> Tuple[int, int, dict]:
|
||||
user_id = identity.user
|
||||
task_ids = {}
|
||||
model_ids = {}
|
||||
for event in events:
|
||||
@@ -163,8 +181,12 @@ class EventBLL(object):
|
||||
"Inconsistent model_event setting in the passed events",
|
||||
tasks=found_in_both,
|
||||
)
|
||||
valid_models = self._get_valid_entities(company_id, ids=model_ids, model=True)
|
||||
valid_tasks = self._get_valid_entities(company_id, ids=task_ids)
|
||||
valid_models = self._get_valid_entities(
|
||||
company_id, ids=model_ids, identity=identity, model=True
|
||||
)
|
||||
valid_tasks = self._get_valid_entities(
|
||||
company_id, ids=task_ids, identity=identity
|
||||
)
|
||||
|
||||
actions: List[dict] = []
|
||||
used_task_ids = set()
|
||||
@@ -268,11 +290,13 @@ class EventBLL(object):
|
||||
else:
|
||||
used_task_ids.add(task_or_model_id)
|
||||
self._update_last_metric_events_for_task(
|
||||
last_events=task_last_events[task_or_model_id], event=event,
|
||||
last_events=task_last_events[task_or_model_id],
|
||||
event=event,
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_or_model_id], event=event,
|
||||
last_events=task_last_scalar_events[task_or_model_id],
|
||||
event=event,
|
||||
)
|
||||
|
||||
actions.append(es_action)
|
||||
@@ -311,20 +335,23 @@ class EventBLL(object):
|
||||
else:
|
||||
errors_per_type["Error when indexing events batch"] += 1
|
||||
|
||||
now = datetime.utcnow()
|
||||
for model_id in used_model_ids:
|
||||
ModelBLL.update_statistics(
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
model_id=model_id,
|
||||
last_update=now,
|
||||
last_iteration_max=task_iteration.get(model_id),
|
||||
last_scalar_events=task_last_scalar_events.get(model_id),
|
||||
)
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in used_task_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update
|
||||
# all of them and not only those who's events were successful
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
task_id=task_id,
|
||||
now=now,
|
||||
iter_max=task_iteration.get(task_id),
|
||||
@@ -336,7 +363,12 @@ class EventBLL(object):
|
||||
continue
|
||||
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
|
||||
TaskBLL.set_last_update(
|
||||
remaining_tasks,
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
last_update=now,
|
||||
)
|
||||
|
||||
# this is for backwards compatibility with streaming bulk throwing exception on those
|
||||
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
|
||||
@@ -466,9 +498,10 @@ class EventBLL(object):
|
||||
|
||||
def _update_task(
|
||||
self,
|
||||
company_id,
|
||||
task_id,
|
||||
now,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
task_id: str,
|
||||
now: datetime,
|
||||
iter_max=None,
|
||||
last_scalar_events=None,
|
||||
last_events=None,
|
||||
@@ -484,8 +517,9 @@ class EventBLL(object):
|
||||
return False
|
||||
|
||||
return TaskBLL.update_statistics(
|
||||
task_id,
|
||||
company_id,
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
last_update=now,
|
||||
last_iteration_max=iter_max,
|
||||
last_scalar_events=last_scalar_events,
|
||||
@@ -569,7 +603,8 @@ class EventBLL(object):
|
||||
query = {"bool": {"must": must}}
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
query=query,
|
||||
**search_args,
|
||||
)
|
||||
max_variants = int(max_variants // last_iterations_per_plot)
|
||||
|
||||
@@ -636,9 +671,11 @@ class EventBLL(object):
|
||||
return events, total_events, next_scroll_id
|
||||
|
||||
def get_debug_image_urls(
|
||||
self, company_id: str, task_id: str, after_key: dict = None
|
||||
self, company_id: str, task_ids: Sequence[str], after_key: dict = None
|
||||
) -> Tuple[Sequence[str], Optional[dict]]:
|
||||
if check_empty_data(self.es, company_id, EventType.metrics_image):
|
||||
if not task_ids or check_empty_data(
|
||||
self.es, company_id, EventType.metrics_image
|
||||
):
|
||||
return [], None
|
||||
|
||||
es_req = {
|
||||
@@ -654,7 +691,10 @@ class EventBLL(object):
|
||||
},
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [{"term": {"task": task_id}}, {"exists": {"field": "url"}}]
|
||||
"must": [
|
||||
{"terms": {"task": task_ids}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -672,9 +712,13 @@ class EventBLL(object):
|
||||
return [bucket["key"]["url"] for bucket in res["buckets"]], res.get("after_key")
|
||||
|
||||
def get_plot_image_urls(
|
||||
self, company_id: str, task_id: str, scroll_id: Optional[str]
|
||||
self, company_id: str, task_ids: Sequence[str], scroll_id: Optional[str]
|
||||
) -> Tuple[Sequence[dict], Optional[str]]:
|
||||
if scroll_id == self.empty_scroll:
|
||||
if (
|
||||
scroll_id == self.empty_scroll
|
||||
or not task_ids
|
||||
or check_empty_data(self.es, company_id, EventType.metrics_plot)
|
||||
):
|
||||
return [], None
|
||||
|
||||
if scroll_id:
|
||||
@@ -689,7 +733,7 @@ class EventBLL(object):
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"terms": {"task": task_ids}},
|
||||
{"exists": {"field": PlotFields.source_urls}},
|
||||
]
|
||||
}
|
||||
@@ -825,7 +869,8 @@ class EventBLL(object):
|
||||
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
query=query,
|
||||
**search_args,
|
||||
)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
@@ -879,7 +924,8 @@ class EventBLL(object):
|
||||
}
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
query=query,
|
||||
**search_args,
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req = {
|
||||
@@ -1023,9 +1069,9 @@ class EventBLL(object):
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": must}},
|
||||
@@ -1091,7 +1137,10 @@ class EventBLL(object):
|
||||
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_ids, event_type=event_type, body=es_req,
|
||||
self.es,
|
||||
company_id=company_ids,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
if "aggregations" not in es_res:
|
||||
@@ -1142,18 +1191,26 @@ class EventBLL(object):
|
||||
|
||||
return {"refresh": True}
|
||||
|
||||
def delete_task_events(
|
||||
self, company_id, task_id, allow_locked=False, model=False, async_delete=False,
|
||||
):
|
||||
def delete_task_events(self, company_id, task_id, allow_locked=False, model=False):
|
||||
if model:
|
||||
self._validate_model_state(
|
||||
company_id=company_id, model_id=task_id, allow_locked=allow_locked,
|
||||
company_id=company_id,
|
||||
model_id=task_id,
|
||||
allow_locked=allow_locked,
|
||||
)
|
||||
else:
|
||||
self._validate_task_state(
|
||||
company_id=company_id, task_id=task_id, allow_locked=allow_locked
|
||||
)
|
||||
|
||||
async_delete = async_task_events_delete
|
||||
if async_delete:
|
||||
total = self.events_iterator.count_task_events(
|
||||
event_type=EventType.all,
|
||||
company_id=company_id,
|
||||
task_ids=[task_id],
|
||||
)
|
||||
if total <= async_delete_threshold:
|
||||
async_delete = False
|
||||
es_req = {"query": {"term": {"task": task_id}}}
|
||||
with translate_errors_context():
|
||||
es_res = delete_company_events(
|
||||
@@ -1211,14 +1268,23 @@ class EventBLL(object):
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
def delete_multi_task_events(
|
||||
self, company_id: str, task_ids: Sequence[str], async_delete=False
|
||||
self, company_id: str, task_ids: Sequence[str], model=False
|
||||
):
|
||||
"""
|
||||
Delete mutliple task events. No check is done for tasks write access
|
||||
Delete multiple task events. No check is done for tasks write access
|
||||
so it should be checked by the calling code
|
||||
"""
|
||||
deleted = 0
|
||||
with translate_errors_context():
|
||||
async_delete = async_task_events_delete
|
||||
if async_delete and len(task_ids) < 100:
|
||||
total = self.events_iterator.count_task_events(
|
||||
event_type=EventType.all,
|
||||
company_id=company_id,
|
||||
task_ids=task_ids,
|
||||
)
|
||||
if total <= async_delete_threshold:
|
||||
async_delete = False
|
||||
for tasks in chunked_iter(task_ids, 100):
|
||||
es_req = {"query": {"terms": {"task": tasks}}}
|
||||
es_res = delete_company_events(
|
||||
@@ -1232,7 +1298,7 @@ class EventBLL(object):
|
||||
deleted += es_res.get("deleted", 0)
|
||||
|
||||
if not async_delete:
|
||||
return es_res.get("deleted", 0)
|
||||
return deleted
|
||||
|
||||
def clear_scroll(self, scroll_id: str):
|
||||
if scroll_id == self.empty_scroll:
|
||||
|
||||
@@ -21,6 +21,7 @@ from apiserver.bll.event.event_common import (
|
||||
TaskCompanies,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||
from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.tools import safe_get
|
||||
@@ -161,7 +162,9 @@ class EventMetrics:
|
||||
return res
|
||||
|
||||
def get_task_single_value_metrics(
|
||||
self, companies: TaskCompanies
|
||||
self,
|
||||
companies: TaskCompanies,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> Mapping[str, dict]:
|
||||
"""
|
||||
For the requested tasks return all the events delivered for the single iteration (-2**31)
|
||||
@@ -179,7 +182,13 @@ class EventMetrics:
|
||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||
task_events = list(
|
||||
itertools.chain.from_iterable(
|
||||
pool.map(self._get_task_single_value_metrics, companies.items())
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_task_single_value_metrics,
|
||||
metric_variants=metric_variants,
|
||||
),
|
||||
companies.items(),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -195,19 +204,19 @@ class EventMetrics:
|
||||
}
|
||||
|
||||
def _get_task_single_value_metrics(
|
||||
self, tasks: Tuple[str, Sequence[str]]
|
||||
self, tasks: Tuple[str, Sequence[str]], metric_variants: MetricVariants = None
|
||||
) -> Sequence[dict]:
|
||||
company_id, task_ids = tasks
|
||||
must = [
|
||||
{"terms": {"task": task_ids}},
|
||||
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
|
||||
]
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"terms": {"task": task_ids}},
|
||||
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
@@ -280,7 +289,8 @@ class EventMetrics:
|
||||
query = {"bool": {"must": must}}
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
query=query,
|
||||
**search_args,
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req = {
|
||||
@@ -366,7 +376,8 @@ class EventMetrics:
|
||||
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
query=query,
|
||||
**search_args,
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req = {
|
||||
@@ -432,7 +443,9 @@ class EventMetrics:
|
||||
|
||||
@classmethod
|
||||
def _get_task_metrics_query(
|
||||
cls, task_id: str, metrics: Sequence[Tuple[str, str]],
|
||||
cls,
|
||||
task_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
):
|
||||
must = cls._task_conditions(task_id)
|
||||
if metrics:
|
||||
@@ -451,12 +464,96 @@ class EventMetrics:
|
||||
|
||||
return {"bool": {"must": must}}
|
||||
|
||||
def get_multi_task_metrics(self, companies: TaskCompanies, event_type: EventType) -> Mapping[str, list]:
|
||||
"""
|
||||
For the requested tasks return reported metrics and variants
|
||||
"""
|
||||
tasks_ids = {
|
||||
company: [t.id for t in tasks]
|
||||
for company, tasks in companies.items()
|
||||
}
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
companies_res: Sequence = list(
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_multi_task_metrics,
|
||||
event_type=event_type,
|
||||
),
|
||||
tasks_ids.items(),
|
||||
)
|
||||
)
|
||||
|
||||
if len(companies_res) == 1:
|
||||
return companies_res[0]
|
||||
|
||||
res = defaultdict(set)
|
||||
for c_res in companies_res:
|
||||
for m, vars_ in c_res.items():
|
||||
res[m].update(vars_)
|
||||
|
||||
return {
|
||||
k: list(v)
|
||||
for k, v in res.items()
|
||||
}
|
||||
|
||||
def _get_multi_task_metrics(
|
||||
self, company_tasks: Tuple[str, Sequence[str]], event_type: EventType
|
||||
) -> Mapping[str, list]:
|
||||
company_id, task_ids = company_tasks
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return {}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
)
|
||||
query = QueryBuilder.terms("task", task_ids)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query,
|
||||
**search_args,
|
||||
)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
body=es_req,
|
||||
**search_args,
|
||||
)
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
return {
|
||||
mb["key"]: [vb["key"] for vb in mb["variants"]["buckets"]]
|
||||
for mb in aggs_result["metrics"]["buckets"]
|
||||
}
|
||||
|
||||
def get_task_metrics(
|
||||
self, company_id, task_ids: Sequence, event_type: EventType
|
||||
) -> Sequence:
|
||||
"""
|
||||
For the requested tasks return all the metrics that
|
||||
reported events of the requested types
|
||||
For the requested tasks return reported metrics per task
|
||||
"""
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return {}
|
||||
|
||||
@@ -64,13 +64,13 @@ class EventsIterator:
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
task_ids: Sequence[str],
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> int:
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return 0
|
||||
|
||||
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
|
||||
query, _ = self._get_initial_query_and_must(task_ids, metric_variants)
|
||||
es_req = {
|
||||
"query": query,
|
||||
}
|
||||
@@ -100,7 +100,7 @@ class EventsIterator:
|
||||
For the last key-field value all the events are brought (even if the resulting size exceeds batch_size)
|
||||
so that events with this value will not be lost between the calls.
|
||||
"""
|
||||
query, must = self._get_initial_query_and_must(task_id, metric_variants)
|
||||
query, must = self._get_initial_query_and_must([task_id], metric_variants)
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
@@ -158,14 +158,14 @@ class EventsIterator:
|
||||
|
||||
@staticmethod
|
||||
def _get_initial_query_and_must(
|
||||
task_id: str, metric_variants: MetricVariants = None
|
||||
task_ids: Sequence[str], metric_variants: MetricVariants = None
|
||||
) -> Tuple[dict, list]:
|
||||
if not metric_variants:
|
||||
must = [{"term": {"task": task_id}}]
|
||||
query = {"term": {"task": task_id}}
|
||||
query = {"terms": {"task": task_ids}}
|
||||
must = [query]
|
||||
else:
|
||||
must = [
|
||||
{"term": {"task": task_id}},
|
||||
{"terms": {"task": task_ids}},
|
||||
get_metric_variants_condition(metric_variants),
|
||||
]
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
@@ -10,6 +10,7 @@ from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from .metadata import Metadata
|
||||
|
||||
|
||||
@@ -57,14 +58,15 @@ class ModelBLL:
|
||||
cls,
|
||||
model_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
force_publish_task: bool = False,
|
||||
publish_task_func: Callable[[str, str, str, bool], dict] = None,
|
||||
publish_task_func: Callable[[str, str, Identity, bool], dict] = None,
|
||||
) -> Tuple[int, ModelTaskPublishResponse]:
|
||||
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
if model.ready:
|
||||
raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
|
||||
|
||||
user_id = identity.user
|
||||
published_task = None
|
||||
if model.task and publish_task_func:
|
||||
task = (
|
||||
@@ -74,13 +76,20 @@ class ModelBLL:
|
||||
)
|
||||
if task and task.status != TaskStatus.published:
|
||||
task_publish_res = publish_task_func(
|
||||
model.task, company_id, user_id, force_publish_task
|
||||
model.task, company_id, identity, force_publish_task
|
||||
)
|
||||
published_task = ModelTaskPublishResponse(
|
||||
id=model.task, data=task_publish_res
|
||||
)
|
||||
|
||||
updated = model.update(upsert=False, ready=True, last_update=datetime.utcnow())
|
||||
now = datetime.utcnow()
|
||||
updated = model.update(
|
||||
upsert=False,
|
||||
ready=True,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by=user_id,
|
||||
)
|
||||
return updated, published_task
|
||||
|
||||
@classmethod
|
||||
@@ -125,6 +134,7 @@ class ModelBLL:
|
||||
"models.output.$[elem].model": deleted_model_id,
|
||||
"output.error": f"model deleted on {now.isoformat()}",
|
||||
"last_change": now,
|
||||
"last_changed_by": user_id,
|
||||
},
|
||||
},
|
||||
array_filters=[{"elem.model": model_id}],
|
||||
@@ -132,7 +142,9 @@ class ModelBLL:
|
||||
)
|
||||
else:
|
||||
task.update(
|
||||
pull__models__output__model=model_id, set__last_change=now
|
||||
pull__models__output__model=model_id,
|
||||
set__last_change=now,
|
||||
set__last_changed_by=user_id,
|
||||
)
|
||||
|
||||
delete_external_artifacts = delete_external_artifacts and config.get(
|
||||
@@ -167,25 +179,29 @@ class ModelBLL:
|
||||
return del_count, model
|
||||
|
||||
@classmethod
|
||||
def archive_model(cls, model_id: str, company_id: str):
|
||||
def archive_model(cls, model_id: str, company_id: str, user_id: str):
|
||||
cls.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id, only_fields=("id",)
|
||||
)
|
||||
now = datetime.utcnow()
|
||||
archived = Model.objects(company=company_id, id=model_id).update(
|
||||
add_to_set__system_tags=EntityVisibility.archived.value,
|
||||
last_update=datetime.utcnow(),
|
||||
last_change=now,
|
||||
last_changed_by=user_id,
|
||||
)
|
||||
|
||||
return archived
|
||||
|
||||
@classmethod
|
||||
def unarchive_model(cls, model_id: str, company_id: str):
|
||||
def unarchive_model(cls, model_id: str, company_id: str, user_id: str):
|
||||
cls.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id, only_fields=("id",)
|
||||
)
|
||||
now = datetime.utcnow()
|
||||
unarchived = Model.objects(company=company_id, id=model_id).update(
|
||||
pull__system_tags=EntityVisibility.archived.value,
|
||||
last_update=datetime.utcnow(),
|
||||
last_change=now,
|
||||
last_changed_by=user_id,
|
||||
)
|
||||
|
||||
return unarchived
|
||||
@@ -218,11 +234,18 @@ class ModelBLL:
|
||||
@staticmethod
|
||||
def update_statistics(
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
model_id: str,
|
||||
last_update: datetime = None,
|
||||
last_iteration_max: int = None,
|
||||
last_scalar_events: Dict[str, Dict[str, dict]] = None,
|
||||
):
|
||||
updates = {"last_update": datetime.utcnow()}
|
||||
last_update = last_update or datetime.utcnow()
|
||||
updates = {
|
||||
"last_update": datetime.utcnow(),
|
||||
"last_change": last_update,
|
||||
"last_changed_by": user_id,
|
||||
}
|
||||
if last_iteration_max is not None:
|
||||
updates.update(max__last_iteration=last_iteration_max)
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import Sequence, Dict
|
||||
from typing import Sequence, Dict, Type
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import AttributedDocument
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
@@ -22,6 +25,51 @@ class OrgBLL:
|
||||
self._task_tags = _TagsCache(Task, self.redis)
|
||||
self._model_tags = _TagsCache(Model, self.redis)
|
||||
|
||||
def edit_entity_tags(
|
||||
self,
|
||||
company_id,
|
||||
entity_cls: Type[AttributedDocument],
|
||||
entity_ids: Sequence[str],
|
||||
add_tags: Sequence[str],
|
||||
remove_tags: Sequence[str],
|
||||
) -> int:
|
||||
if entity_cls not in (Task, Model):
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Tags editing can be called on tasks or models only"
|
||||
)
|
||||
if not entity_ids:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"No entity ids provided for editing tags"
|
||||
)
|
||||
if not (add_tags or remove_tags):
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Either add tags or remove tags should be provided"
|
||||
)
|
||||
|
||||
updated = 0
|
||||
if add_tags:
|
||||
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
|
||||
add_to_set__tags=add_tags
|
||||
)
|
||||
if remove_tags:
|
||||
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
|
||||
pull_all__tags=remove_tags
|
||||
)
|
||||
if not updated:
|
||||
return 0
|
||||
|
||||
projects = entity_cls.objects(company=company_id, id__in=entity_ids).distinct(
|
||||
"project"
|
||||
)
|
||||
update_project_time(project_ids=projects)
|
||||
self.update_tags(
|
||||
company_id,
|
||||
entity=Tags.Task if entity_cls is Task else Tags.Model,
|
||||
projects=projects,
|
||||
tags=add_tags or remove_tags
|
||||
)
|
||||
return updated
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company_id: str,
|
||||
@@ -50,10 +98,10 @@ class OrgBLL:
|
||||
return ret
|
||||
|
||||
def update_tags(
|
||||
self, company_id: str, entity: Tags, project: str, tags=None, system_tags=None,
|
||||
self, company_id: str, entity: Tags, projects: Sequence[str], tags=None, system_tags=None,
|
||||
):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.update_tags(company_id, project, tags, system_tags)
|
||||
tags_cache.update_tags(company_id, projects, tags, system_tags)
|
||||
|
||||
def reset_tags(self, company_id: str, entity: Tags, projects: Sequence[str]):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
|
||||
@@ -107,7 +107,7 @@ class _TagsCache:
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(self, company_id: str, project: str, tags=None, system_tags=None):
|
||||
def update_tags(self, company_id: str, projects: Sequence[str], tags=None, system_tags=None):
|
||||
"""
|
||||
Updates tags. If reset is set then both tags and system_tags
|
||||
are recalculated. Otherwise only those that are not 'None'
|
||||
@@ -123,7 +123,7 @@ class _TagsCache:
|
||||
if not fields:
|
||||
return
|
||||
|
||||
self._delete_redis_keys(company_id, projects=[project], fields=fields)
|
||||
self._delete_redis_keys(company_id, projects=projects, fields=fields)
|
||||
|
||||
def reset_tags(self, company_id: str, projects: Sequence[str]):
|
||||
self._delete_redis_keys(
|
||||
|
||||
@@ -315,11 +315,12 @@ class ProjectBLL:
|
||||
description="",
|
||||
)
|
||||
|
||||
extra = (
|
||||
{"set__last_change": datetime.utcnow()}
|
||||
if hasattr(entity_cls, "last_change")
|
||||
else {}
|
||||
)
|
||||
extra = {}
|
||||
if hasattr(entity_cls, "last_change"):
|
||||
extra["set__last_change"] = datetime.utcnow()
|
||||
if hasattr(entity_cls, "last_changed_by"):
|
||||
extra["set__last_changed_by"] = user
|
||||
|
||||
entity_cls.objects(company=company, id__in=ids).update(
|
||||
set__project=project, **extra
|
||||
)
|
||||
@@ -340,6 +341,17 @@ class ProjectBLL:
|
||||
) -> Tuple[Sequence, Sequence]:
|
||||
archived = EntityVisibility.archived.value
|
||||
|
||||
def project_task_fields():
|
||||
return {
|
||||
"$project": {
|
||||
"project": 1,
|
||||
"status": 1,
|
||||
"system_tags": 1,
|
||||
"started": 1,
|
||||
"completed": 1,
|
||||
}
|
||||
}
|
||||
|
||||
def ensure_valid_fields():
|
||||
"""
|
||||
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
|
||||
@@ -367,6 +379,7 @@ class ProjectBLL:
|
||||
users=users,
|
||||
)
|
||||
},
|
||||
project_task_fields(),
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
"$group": {
|
||||
@@ -515,6 +528,7 @@ class ProjectBLL:
|
||||
users=users,
|
||||
)
|
||||
},
|
||||
project_task_fields(),
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
# for each project
|
||||
@@ -550,7 +564,10 @@ class ProjectBLL:
|
||||
|
||||
@classmethod
|
||||
def get_dataset_stats(
|
||||
cls, company: str, project_ids: Sequence[str], users: Sequence[str] = None,
|
||||
cls,
|
||||
company: str,
|
||||
project_ids: Sequence[str],
|
||||
users: Sequence[str] = None,
|
||||
) -> Dict[str, dict]:
|
||||
if not project_ids:
|
||||
return {}
|
||||
@@ -584,7 +601,9 @@ class ProjectBLL:
|
||||
|
||||
@staticmethod
|
||||
def _get_projects_children(
|
||||
project_ids: Sequence[str], search_hidden: bool, allowed_ids: Sequence[str],
|
||||
project_ids: Sequence[str],
|
||||
search_hidden: bool,
|
||||
allowed_ids: Sequence[str],
|
||||
) -> Tuple[ProjectsChildren, Set[str]]:
|
||||
child_projects = _get_sub_projects(
|
||||
project_ids,
|
||||
@@ -628,7 +647,9 @@ class ProjectBLL:
|
||||
project_ids_with_children = set(project_ids)
|
||||
if include_children:
|
||||
child_projects, children_ids = cls._get_projects_children(
|
||||
project_ids, search_hidden=True, allowed_ids=selected_project_ids,
|
||||
project_ids,
|
||||
search_hidden=True,
|
||||
allowed_ids=selected_project_ids,
|
||||
)
|
||||
project_ids_with_children |= children_ids
|
||||
|
||||
@@ -902,6 +923,7 @@ class ProjectBLL:
|
||||
allow_public: bool = True,
|
||||
children_type: ProjectChildrenType = None,
|
||||
children_tags: Sequence[str] = None,
|
||||
children_tags_filter: dict = None,
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
"""
|
||||
Get the projects ids matching children_condition (if passed) or where the passed user created any tasks
|
||||
@@ -922,11 +944,15 @@ class ProjectBLL:
|
||||
query &= Q(user__in=users)
|
||||
|
||||
project_query = None
|
||||
child_query = (
|
||||
query & GetMixin.get_list_field_query("tags", children_tags)
|
||||
if children_tags
|
||||
else query
|
||||
)
|
||||
if children_tags_filter:
|
||||
child_query = query & GetMixin.get_list_filter_query(
|
||||
"tags", children_tags_filter
|
||||
)
|
||||
elif children_tags:
|
||||
child_query = query & GetMixin.get_list_field_query("tags", children_tags)
|
||||
else:
|
||||
child_query = query
|
||||
|
||||
if children_type == ProjectChildrenType.dataset:
|
||||
child_queries = {
|
||||
Project: child_query
|
||||
@@ -1086,39 +1112,50 @@ class ProjectBLL:
|
||||
|
||||
or_conditions = []
|
||||
for field, field_filter in filter_.items():
|
||||
if not (
|
||||
field_filter
|
||||
and isinstance(field_filter, list)
|
||||
and all(isinstance(t, str) for t in field_filter)
|
||||
):
|
||||
if not (field_filter and isinstance(field_filter, (list, dict))):
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"List of strings expected for the field: {field}"
|
||||
f"Non empty list or dictionary expected for the field: {field}"
|
||||
)
|
||||
helper = GetMixin.NewListFieldBucketHelper(
|
||||
field, data=field_filter, legacy=True
|
||||
)
|
||||
field_conditions = {}
|
||||
for action, values in helper.actions.items():
|
||||
value = list(set(values))
|
||||
for key in reversed(action.split("__")):
|
||||
value = {f"${key}": value}
|
||||
field_conditions.update(value)
|
||||
if (
|
||||
helper.explicit_operator
|
||||
and helper.global_operator == Q.OR
|
||||
and len(field_conditions) > 1
|
||||
):
|
||||
or_conditions.append(
|
||||
[{field: {op: cond}} for op, cond in field_conditions.items()]
|
||||
|
||||
if isinstance(field_filter, list):
|
||||
if not all(isinstance(t, str) for t in field_filter):
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"Only string values are allowed in the list filter: {field}"
|
||||
)
|
||||
helper = GetMixin.NewListFieldBucketHelper(
|
||||
field, data=field_filter, legacy=True
|
||||
)
|
||||
op = helper.global_operator
|
||||
db_query = {op: helper.actions}
|
||||
else:
|
||||
conditions[field] = field_conditions
|
||||
helper = GetMixin.ListQueryFilter.from_data(field, field_filter)
|
||||
db_query = helper.db_query
|
||||
|
||||
for op, actions in db_query.items():
|
||||
field_conditions = {}
|
||||
for action, values in actions.items():
|
||||
value = list(set(values)) if isinstance(values, list) else values
|
||||
for key in reversed(action.split("__")):
|
||||
value = {f"${key}": value}
|
||||
field_conditions.update(value)
|
||||
|
||||
if op == Q.OR and len(field_conditions) > 1:
|
||||
or_conditions.append(
|
||||
{
|
||||
"$or": [
|
||||
{field: {db_modifier: cond}}
|
||||
for db_modifier, cond in field_conditions.items()
|
||||
]
|
||||
}
|
||||
)
|
||||
else:
|
||||
conditions[field] = field_conditions
|
||||
|
||||
if or_conditions:
|
||||
if len(or_conditions) == 1:
|
||||
conditions["$or"] = next(iter(or_conditions))
|
||||
conditions = next(iter(or_conditions))
|
||||
else:
|
||||
conditions["$and"] = [{"$or": c} for c in or_conditions]
|
||||
conditions["$and"] = [c for c in or_conditions]
|
||||
|
||||
return conditions
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ from .sub_projects import _ids_with_children
|
||||
|
||||
log = config.logger(__file__)
|
||||
event_bll = EventBLL()
|
||||
async_events_delete = config.get("services.tasks.async_events_delete", False)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
@@ -83,7 +82,8 @@ def validate_project_delete(company: str, project_id: str):
|
||||
ret["pipelines"] = 0
|
||||
if dataset_ids:
|
||||
datasets_with_data = Task.objects(
|
||||
project__in=dataset_ids, system_tags__nin=[EntityVisibility.archived.value],
|
||||
project__in=dataset_ids,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).distinct("project")
|
||||
ret["datasets"] = len(datasets_with_data)
|
||||
else:
|
||||
@@ -185,10 +185,10 @@ def delete_project(
|
||||
res = DeleteProjectResult(disassociated_tasks=disassociated[Task])
|
||||
else:
|
||||
deleted_models, model_event_urls, model_urls = _delete_models(
|
||||
company=company, projects=project_ids
|
||||
company=company, user=user, projects=project_ids
|
||||
)
|
||||
deleted_tasks, task_event_urls, artifact_urls = _delete_tasks(
|
||||
company=company, projects=project_ids
|
||||
company=company, user=user, projects=project_ids
|
||||
)
|
||||
event_urls = task_event_urls | model_event_urls
|
||||
if delete_external_artifacts:
|
||||
@@ -217,7 +217,9 @@ def delete_project(
|
||||
return res, affected
|
||||
|
||||
|
||||
def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]:
|
||||
def _delete_tasks(
|
||||
company: str, user: str, projects: Sequence[str]
|
||||
) -> Tuple[int, Set, Set]:
|
||||
"""
|
||||
Delete only the task themselves and their non published version.
|
||||
Child models under the same project are deleted separately.
|
||||
@@ -228,14 +230,24 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
|
||||
if not tasks:
|
||||
return 0, set(), set()
|
||||
|
||||
task_ids = {t.id for t in tasks}
|
||||
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
|
||||
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
|
||||
task_ids = list({t.id for t in tasks})
|
||||
now = datetime.utcnow()
|
||||
Task.objects(parent__in=task_ids, project__nin=projects).update(
|
||||
parent=None,
|
||||
last_change=now,
|
||||
last_changed_by=user,
|
||||
)
|
||||
Model.objects(task__in=task_ids, project__nin=projects).update(
|
||||
task=None,
|
||||
last_change=now,
|
||||
last_changed_by=user,
|
||||
)
|
||||
|
||||
event_urls, artifact_urls = set(), set()
|
||||
event_urls = collect_debug_image_urls(company, task_ids) | collect_plot_image_urls(
|
||||
company, task_ids
|
||||
)
|
||||
artifact_urls = set()
|
||||
for task in tasks:
|
||||
event_urls.update(collect_debug_image_urls(company, task.id))
|
||||
event_urls.update(collect_plot_image_urls(company, task.id))
|
||||
if task.execution and task.execution.artifacts:
|
||||
artifact_urls.update(
|
||||
{
|
||||
@@ -245,15 +257,13 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
|
||||
}
|
||||
)
|
||||
|
||||
event_bll.delete_multi_task_events(
|
||||
company, list(task_ids), async_delete=async_events_delete
|
||||
)
|
||||
event_bll.delete_multi_task_events(company, task_ids)
|
||||
deleted = tasks.delete()
|
||||
return deleted, event_urls, artifact_urls
|
||||
|
||||
|
||||
def _delete_models(
|
||||
company: str, projects: Sequence[str]
|
||||
company: str, user: str, projects: Sequence[str]
|
||||
) -> Tuple[int, Set[str], Set[str]]:
|
||||
"""
|
||||
Delete project models and update the tasks from other projects
|
||||
@@ -287,25 +297,31 @@ def _delete_models(
|
||||
"status": TaskStatus.published,
|
||||
},
|
||||
update={
|
||||
"$set": {"models.output.$[elem].model": deleted, "last_change": now,}
|
||||
"$set": {
|
||||
"models.output.$[elem].model": deleted,
|
||||
"last_change": now,
|
||||
"last_changed_by": user,
|
||||
}
|
||||
},
|
||||
array_filters=[{"elem.model": {"$in": model_ids}}],
|
||||
upsert=False,
|
||||
)
|
||||
# update unpublished tasks
|
||||
Task.objects(
|
||||
id__in=model_tasks, project__nin=projects, status__ne=TaskStatus.published,
|
||||
).update(pull__models__output__model__in=model_ids, set__last_change=now)
|
||||
id__in=model_tasks,
|
||||
project__nin=projects,
|
||||
status__ne=TaskStatus.published,
|
||||
).update(
|
||||
pull__models__output__model__in=model_ids,
|
||||
set__last_change=now,
|
||||
set__last_changed_by=user,
|
||||
)
|
||||
|
||||
event_urls, model_urls = set(), set()
|
||||
for m in models:
|
||||
event_urls.update(collect_debug_image_urls(company, m.id))
|
||||
event_urls.update(collect_plot_image_urls(company, m.id))
|
||||
if m.uri:
|
||||
model_urls.add(m.uri)
|
||||
|
||||
event_bll.delete_multi_task_events(
|
||||
company, model_ids, async_delete=async_events_delete
|
||||
event_urls = collect_debug_image_urls(company, model_ids) | collect_plot_image_urls(
|
||||
company, model_ids
|
||||
)
|
||||
model_urls = {m.uri for m in models if m.uri}
|
||||
|
||||
event_bll.delete_multi_task_events(company, model_ids, model=True)
|
||||
deleted = models.delete()
|
||||
return deleted, event_urls, model_urls
|
||||
|
||||
@@ -140,6 +140,7 @@ class ProjectQueries:
|
||||
name: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
pattern: str = None,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> ParamValues:
|
||||
@@ -164,7 +165,20 @@ class ProjectQueries:
|
||||
if not last_updated_task:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}_{page}_{page_size}"
|
||||
redis_key = "_".join(
|
||||
str(part)
|
||||
for part in (
|
||||
"hyperparam_values",
|
||||
company_id,
|
||||
"_".join(project_ids),
|
||||
section,
|
||||
name,
|
||||
allow_public,
|
||||
pattern,
|
||||
page,
|
||||
page_size,
|
||||
)
|
||||
)
|
||||
last_update = last_updated_task.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_param_values(
|
||||
key=redis_key,
|
||||
@@ -176,14 +190,22 @@ class ProjectQueries:
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
match_condition = {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
if pattern:
|
||||
match_condition["$expr"] = {
|
||||
"$regexMatch": {
|
||||
"input": f"${key_path}.value",
|
||||
"regex": pattern,
|
||||
"options": "i",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
pipeline = [
|
||||
{"$match": match_condition},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
@@ -217,6 +239,7 @@ class ProjectQueries:
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
ids: Sequence[str],
|
||||
model_metrics: bool = False,
|
||||
):
|
||||
pipeline = [
|
||||
@@ -224,6 +247,7 @@ class ProjectQueries:
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
**({"_id": {"$in": ids}} if ids else {}),
|
||||
}
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
|
||||
@@ -152,7 +152,7 @@ class QueueBLL(object):
|
||||
|
||||
for item in queue.entries:
|
||||
try:
|
||||
task = Task.get_for_writing(
|
||||
task = Task.get(
|
||||
company=company_id,
|
||||
id=item.task,
|
||||
_only=[
|
||||
|
||||
@@ -254,6 +254,14 @@ class StatisticsReporter:
|
||||
**({"last_worker": {"$in": workers}} if workers else {}),
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"last_worker": 1,
|
||||
"last_update": 1,
|
||||
"started": 1,
|
||||
"last_iteration": 1,
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$last_worker" if workers else None,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from .task_bll import TaskBLL
|
||||
from .utils import (
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
validate_status_change,
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
|
||||
from apiserver.database.utils import hash_field_name
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
|
||||
|
||||
@@ -48,12 +49,14 @@ class Artifacts:
|
||||
def add_or_update_artifacts(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
artifacts: Sequence[ApiArtifact],
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
artifacts = {
|
||||
get_artifact_id(a): Artifact(**a)
|
||||
@@ -64,18 +67,20 @@ class Artifacts:
|
||||
f"set__execution__artifacts__{mongoengine_safe(name)}": value
|
||||
for name, value in artifacts.items()
|
||||
}
|
||||
return update_task(task, user_id=user_id, update_cmds=update_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_artifacts(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
artifact_ids: Sequence[ArtifactId],
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
artifact_ids = [
|
||||
get_artifact_id(a)
|
||||
@@ -85,4 +90,4 @@ class Artifacts:
|
||||
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
|
||||
}
|
||||
|
||||
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)
|
||||
|
||||
@@ -15,6 +15,7 @@ from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
@@ -31,7 +32,10 @@ class HyperParams:
|
||||
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
only = ("id", "hyperparams")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
company_id=company_id,
|
||||
task_ids=task_ids,
|
||||
only=only,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -63,7 +67,7 @@ class HyperParams:
|
||||
def delete_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamKey],
|
||||
force: bool,
|
||||
@@ -74,6 +78,7 @@ class HyperParams:
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
identity=identity,
|
||||
)
|
||||
|
||||
with_param, without_param = iterutils.partition(
|
||||
@@ -96,7 +101,7 @@ class HyperParams:
|
||||
|
||||
return update_task(
|
||||
task,
|
||||
user_id=user_id,
|
||||
user_id=identity.user,
|
||||
update_cmds=delete_cmds,
|
||||
set_last_update=not properties_only,
|
||||
)
|
||||
@@ -105,7 +110,7 @@ class HyperParams:
|
||||
def edit_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamItem],
|
||||
replace_hyperparams: str,
|
||||
@@ -117,6 +122,7 @@ class HyperParams:
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
identity=identity,
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
@@ -135,7 +141,7 @@ class HyperParams:
|
||||
|
||||
return update_task(
|
||||
task,
|
||||
user_id=user_id,
|
||||
user_id=identity.user,
|
||||
update_cmds=update_cmds,
|
||||
set_last_update=not properties_only,
|
||||
)
|
||||
@@ -163,7 +169,10 @@ class HyperParams:
|
||||
else:
|
||||
only.append("configuration")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
company_id=company_id,
|
||||
task_ids=task_ids,
|
||||
only=only,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -209,13 +218,15 @@ class HyperParams:
|
||||
def edit_configuration(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
configuration: Sequence[Configuration],
|
||||
replace_configuration: bool,
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
@@ -228,22 +239,24 @@ class HyperParams:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
|
||||
|
||||
return update_task(task, user_id=user_id, update_cmds=update_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_configuration(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
configuration: Sequence[str],
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
|
||||
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from datetime import timedelta, datetime
|
||||
from time import sleep
|
||||
|
||||
from apiserver.bll.task import update_project_time
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import TaskStatus, Task
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
@@ -85,6 +85,7 @@ class NonResponsiveTasksWatchdog:
|
||||
status_changed=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by="__apiserver__",
|
||||
)
|
||||
if updated:
|
||||
project_ids.add(task.project)
|
||||
|
||||
@@ -12,6 +12,7 @@ from apiserver.apimodels.tasks import TaskInputModel
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
@@ -31,7 +32,10 @@ from apiserver.database.model.task.task import (
|
||||
)
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
from apiserver.database.utils import (
|
||||
get_company_or_none_constraint,
|
||||
id as create_id,
|
||||
)
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
||||
@@ -39,7 +43,6 @@ from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import (
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
deleted_prefix,
|
||||
get_last_metric_updates,
|
||||
)
|
||||
@@ -55,30 +58,13 @@ class TaskBLL:
|
||||
self.events_es = events_es or es_factory.connect("events")
|
||||
self.redis: StrictRedis = redis or redman.connection("apiserver")
|
||||
|
||||
@staticmethod
|
||||
def get_task_with_access(
|
||||
task_id, company_id, only=None, allow_public=False, requires_write_access=False
|
||||
) -> Task:
|
||||
"""
|
||||
Gets a task that has a required write access
|
||||
:except errors.bad_request.InvalidTaskId: if the task is not found
|
||||
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
if requires_write_access:
|
||||
task = Task.get_for_writing(_only=only, **query)
|
||||
else:
|
||||
task = Task.get(_only=only, **query, include_public=allow_public)
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(
|
||||
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
|
||||
company_id,
|
||||
task_id,
|
||||
required_status=None,
|
||||
only_fields=None,
|
||||
allow_public=False,
|
||||
):
|
||||
if only_fields:
|
||||
if isinstance(only_fields, string_types):
|
||||
@@ -313,7 +299,7 @@ class TaskBLL:
|
||||
org_bll.update_tags(
|
||||
company_id,
|
||||
Tags.Task,
|
||||
project=new_task.project,
|
||||
projects=[new_task.project],
|
||||
tags=updated_tags,
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
@@ -356,6 +342,7 @@ class TaskBLL:
|
||||
def set_last_update(
|
||||
task_ids: Collection[str],
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
last_update: datetime,
|
||||
**extra_updates,
|
||||
):
|
||||
@@ -376,6 +363,7 @@ class TaskBLL:
|
||||
upsert=False,
|
||||
last_update=last_update,
|
||||
last_change=last_update,
|
||||
last_changed_by=user_id,
|
||||
**updates,
|
||||
)
|
||||
return count
|
||||
@@ -384,6 +372,7 @@ class TaskBLL:
|
||||
def update_statistics(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
last_update: datetime = None,
|
||||
last_iteration: int = None,
|
||||
last_iteration_max: int = None,
|
||||
@@ -440,6 +429,7 @@ class TaskBLL:
|
||||
ret = TaskBLL.set_last_update(
|
||||
task_ids=[task_id],
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
last_update=last_update,
|
||||
**extra_updates,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple
|
||||
from typing import Sequence, Set, Tuple, Union
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import partition, bucketize, first
|
||||
from boltons.iterutils import partition, bucketize, first, chunked_iter
|
||||
from furl import furl
|
||||
from mongoengine import NotUniqueError
|
||||
from pymongo.errors import DuplicateKeyError
|
||||
@@ -26,7 +26,6 @@ from apiserver.database.utils import id as db_id
|
||||
|
||||
log = config.logger(__file__)
|
||||
event_bll = EventBLL()
|
||||
async_events_delete = config.get("services.tasks.async_events_delete", False)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
@@ -69,37 +68,47 @@ class CleanupResult:
|
||||
)
|
||||
|
||||
|
||||
def collect_plot_image_urls(company: str, task_or_model: str) -> Set[str]:
|
||||
def collect_plot_image_urls(
|
||||
company: str, task_or_model: Union[str, Sequence[str]]
|
||||
) -> Set[str]:
|
||||
urls = set()
|
||||
next_scroll_id = None
|
||||
while True:
|
||||
events, next_scroll_id = event_bll.get_plot_image_urls(
|
||||
company_id=company, task_id=task_or_model, scroll_id=next_scroll_id
|
||||
)
|
||||
if not events:
|
||||
break
|
||||
for event in events:
|
||||
event_urls = event.get(PlotFields.source_urls)
|
||||
if event_urls:
|
||||
urls.update(set(event_urls))
|
||||
task_ids = task_or_model if isinstance(task_or_model, list) else [task_or_model]
|
||||
for tasks in chunked_iter(task_ids, 100):
|
||||
next_scroll_id = None
|
||||
while True:
|
||||
events, next_scroll_id = event_bll.get_plot_image_urls(
|
||||
company_id=company, task_ids=tasks, scroll_id=next_scroll_id
|
||||
)
|
||||
if not events:
|
||||
break
|
||||
for event in events:
|
||||
event_urls = event.get(PlotFields.source_urls)
|
||||
if event_urls:
|
||||
urls.update(set(event_urls))
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def collect_debug_image_urls(company: str, task_or_model: str) -> Set[str]:
|
||||
def collect_debug_image_urls(
|
||||
company: str, task_or_model: Union[str, Sequence[str]]
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Return the set of unique image urls
|
||||
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
|
||||
"""
|
||||
after_key = None
|
||||
urls = set()
|
||||
while True:
|
||||
res, after_key = event_bll.get_debug_image_urls(
|
||||
company_id=company, task_id=task_or_model, after_key=after_key,
|
||||
)
|
||||
urls.update(res)
|
||||
if not after_key:
|
||||
break
|
||||
task_ids = task_or_model if isinstance(task_or_model, list) else [task_or_model]
|
||||
for tasks in chunked_iter(task_ids, 100):
|
||||
after_key = None
|
||||
while True:
|
||||
res, after_key = event_bll.get_debug_image_urls(
|
||||
company_id=company,
|
||||
task_ids=tasks,
|
||||
after_key=after_key,
|
||||
)
|
||||
urls.update(res)
|
||||
if not after_key:
|
||||
break
|
||||
|
||||
return urls
|
||||
|
||||
@@ -122,7 +131,11 @@ supported_storage_types.update(
|
||||
|
||||
|
||||
def _schedule_for_delete(
|
||||
company: str, user: str, task_id: str, urls: Set[str], can_delete_folders: bool,
|
||||
company: str,
|
||||
user: str,
|
||||
task_id: str,
|
||||
urls: Set[str],
|
||||
can_delete_folders: bool,
|
||||
) -> Set[str]:
|
||||
urls_per_storage = bucketize(
|
||||
urls,
|
||||
@@ -222,8 +235,13 @@ def cleanup_task(
|
||||
|
||||
deleted_task_id = f"{deleted_prefix}{task.id}"
|
||||
updated_children = 0
|
||||
now = datetime.utcnow()
|
||||
if update_children:
|
||||
updated_children = Task.objects(parent=task.id).update(parent=deleted_task_id)
|
||||
updated_children = Task.objects(parent=task.id).update(
|
||||
parent=deleted_task_id,
|
||||
last_change=now,
|
||||
last_changed_by=user,
|
||||
)
|
||||
|
||||
deleted_models = 0
|
||||
updated_models = 0
|
||||
@@ -231,37 +249,41 @@ def cleanup_task(
|
||||
if not models:
|
||||
continue
|
||||
if delete_output_models and allow_delete:
|
||||
model_ids = set(m.id for m in models if m.id not in in_use_model_ids)
|
||||
for m_id in model_ids:
|
||||
model_ids = list({m.id for m in models if m.id not in in_use_model_ids})
|
||||
if model_ids:
|
||||
if return_file_urls or delete_external_artifacts:
|
||||
event_urls.update(collect_debug_image_urls(task.company, m_id))
|
||||
event_urls.update(collect_plot_image_urls(task.company, m_id))
|
||||
try:
|
||||
event_bll.delete_task_events(
|
||||
task.company,
|
||||
m_id,
|
||||
allow_locked=True,
|
||||
model=True,
|
||||
async_delete=async_events_delete,
|
||||
)
|
||||
except errors.bad_request.InvalidModelId as ex:
|
||||
log.info(f"Error deleting events for the model {m_id}: {str(ex)}")
|
||||
event_urls.update(collect_debug_image_urls(task.company, model_ids))
|
||||
event_urls.update(collect_plot_image_urls(task.company, model_ids))
|
||||
|
||||
event_bll.delete_multi_task_events(
|
||||
task.company,
|
||||
model_ids,
|
||||
model=True,
|
||||
)
|
||||
deleted_models += Model.objects(id__in=model_ids).delete()
|
||||
|
||||
deleted_models += Model.objects(id__in=list(model_ids)).delete()
|
||||
if in_use_model_ids:
|
||||
Model.objects(id__in=list(in_use_model_ids)).update(unset__task=1)
|
||||
Model.objects(id__in=list(in_use_model_ids)).update(
|
||||
unset__task=1,
|
||||
set__last_change=now,
|
||||
set__last_changed_by=user,
|
||||
)
|
||||
continue
|
||||
|
||||
if update_children:
|
||||
updated_models += Model.objects(id__in=[m.id for m in models]).update(
|
||||
task=deleted_task_id
|
||||
task=deleted_task_id,
|
||||
last_change=now,
|
||||
last_changed_by=user,
|
||||
)
|
||||
else:
|
||||
Model.objects(id__in=[m.id for m in models]).update(unset__task=1)
|
||||
Model.objects(id__in=[m.id for m in models]).update(
|
||||
unset__task=1,
|
||||
set__last_change=now,
|
||||
set__last_changed_by=user,
|
||||
)
|
||||
|
||||
event_bll.delete_task_events(
|
||||
task.company, task.id, allow_locked=force, async_delete=async_events_delete
|
||||
)
|
||||
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
|
||||
|
||||
if delete_external_artifacts:
|
||||
scheduled = _schedule_for_delete(
|
||||
@@ -304,7 +326,8 @@ def verify_task_children_and_ouptuts(
|
||||
|
||||
model_fields = ["id", "ready", "uri"]
|
||||
published_models, draft_models = partition(
|
||||
Model.objects(task=task.id).only(*model_fields), key=attrgetter("ready"),
|
||||
Model.objects(task=task.id).only(*model_fields),
|
||||
key=attrgetter("ready"),
|
||||
)
|
||||
if not force and published_models:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
|
||||
@@ -7,9 +7,10 @@ from apiserver.bll.task import (
|
||||
TaskBLL,
|
||||
validate_status_change,
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
)
|
||||
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
|
||||
from apiserver.bll.task.utils import get_task_with_write_access
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
@@ -24,6 +25,7 @@ from apiserver.database.model.task.task import (
|
||||
DEFAULT_LAST_ITERATION,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.dicts import nested_set
|
||||
|
||||
log = config.logger(__file__)
|
||||
@@ -33,7 +35,7 @@ queue_bll = QueueBLL()
|
||||
def archive_task(
|
||||
task: Union[str, Task],
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
) -> int:
|
||||
@@ -42,9 +44,10 @@ def archive_task(
|
||||
Return 1 if successful
|
||||
"""
|
||||
if isinstance(task, str):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task = get_task_with_write_access(
|
||||
task,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=(
|
||||
"id",
|
||||
"company",
|
||||
@@ -54,8 +57,9 @@ def archive_task(
|
||||
"system_tags",
|
||||
"enqueue_status",
|
||||
),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
user_id = identity.user
|
||||
try:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
@@ -79,27 +83,34 @@ def archive_task(
|
||||
|
||||
|
||||
def unarchive_task(
|
||||
task: str, company_id: str, user_id: str, status_message: str, status_reason: str,
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
) -> int:
|
||||
"""
|
||||
Unarchive task. Return 1 if successful
|
||||
"""
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task, company_id=company_id, only=("id",), requires_write_access=True,
|
||||
task = get_task_with_write_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=("id",),
|
||||
)
|
||||
return task.update(
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
pull__system_tags=EntityVisibility.archived.value,
|
||||
last_change=datetime.utcnow(),
|
||||
last_changed_by=user_id,
|
||||
last_changed_by=identity.user,
|
||||
)
|
||||
|
||||
|
||||
def dequeue_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
remove_from_all_queues: bool = False,
|
||||
@@ -112,7 +123,19 @@ def dequeue_task(
|
||||
task = Task.get(
|
||||
id=task_id,
|
||||
company=company_id,
|
||||
_only=(
|
||||
_only=("id",),
|
||||
include_public=True,
|
||||
)
|
||||
if not task:
|
||||
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
|
||||
return 1, {"updated": 0}
|
||||
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=(
|
||||
"id",
|
||||
"company",
|
||||
"execution",
|
||||
@@ -120,11 +143,7 @@ def dequeue_task(
|
||||
"project",
|
||||
"enqueue_status",
|
||||
),
|
||||
include_public=True,
|
||||
)
|
||||
if not task:
|
||||
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
|
||||
return 1, {"updated": 0}
|
||||
|
||||
res = TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
@@ -141,7 +160,7 @@ def dequeue_task(
|
||||
def enqueue_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
queue_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
@@ -166,11 +185,11 @@ def enqueue_task(
|
||||
# try to get default queue
|
||||
queue_id = queue_bll.get_default(company_id).id
|
||||
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(**query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
|
||||
user_id = identity.user
|
||||
if validate:
|
||||
TaskBLL.validate(task)
|
||||
|
||||
@@ -200,9 +219,9 @@ def enqueue_task(
|
||||
|
||||
# set the current queue ID in the task
|
||||
if task.execution:
|
||||
Task.objects(**query).update(execution__queue=queue_id, multi=False)
|
||||
Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
|
||||
else:
|
||||
Task.objects(**query).update(execution=Execution(queue=queue_id), multi=False)
|
||||
Task.objects(id=task_id).update(execution=Execution(queue=queue_id), multi=False)
|
||||
|
||||
nested_set(res, ("fields", "execution.queue"), queue_id)
|
||||
return 1, res
|
||||
@@ -235,7 +254,7 @@ def move_tasks_to_trash(tasks: Sequence[str]) -> int:
|
||||
def delete_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
move_to_trash: bool,
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
@@ -244,8 +263,9 @@ def delete_task(
|
||||
status_reason: str,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[int, Task, CleanupResult]:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -298,15 +318,16 @@ def delete_task(
|
||||
def reset_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
delete_output_models: bool,
|
||||
clear_all: bool,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[dict, CleanupResult, dict]:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
|
||||
if not force and task.status == TaskStatus.published:
|
||||
@@ -345,11 +366,17 @@ def reset_task(
|
||||
unset__output__error=1,
|
||||
unset__last_worker=1,
|
||||
unset__last_worker_report=1,
|
||||
unset__started=1,
|
||||
unset__completed=1,
|
||||
unset__published=1,
|
||||
unset__active_duration=1,
|
||||
unset__enqueue_status=1,
|
||||
)
|
||||
|
||||
if clear_all:
|
||||
updates.update(
|
||||
set__execution=Execution(), unset__script=1,
|
||||
set__execution=Execution(),
|
||||
unset__script=1,
|
||||
)
|
||||
else:
|
||||
updates.update(unset__execution__queue=1)
|
||||
@@ -370,11 +397,6 @@ def reset_task(
|
||||
status_message="reset",
|
||||
user_id=user_id,
|
||||
).execute(
|
||||
started=None,
|
||||
completed=None,
|
||||
published=None,
|
||||
active_duration=None,
|
||||
enqueue_status=None,
|
||||
**updates,
|
||||
)
|
||||
|
||||
@@ -384,14 +406,15 @@ def reset_task(
|
||||
def publish_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
force: bool,
|
||||
publish_model_func: Callable[[str, str, str], Any] = None,
|
||||
publish_model_func: Callable[[str, str, Identity], Any] = None,
|
||||
status_message: str = "",
|
||||
status_reason: str = "",
|
||||
) -> dict:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
if not force:
|
||||
validate_status_change(task.status, TaskStatus.published)
|
||||
@@ -414,7 +437,7 @@ def publish_task(
|
||||
.first()
|
||||
)
|
||||
if model and not model.ready:
|
||||
publish_model_func(model.id, company_id, user_id)
|
||||
publish_model_func(model.id, company_id, identity)
|
||||
|
||||
# set task status to published, and update (or set) it's new output (view and models)
|
||||
return ChangeStatusRequest(
|
||||
@@ -438,7 +461,7 @@ def publish_task(
|
||||
def stop_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
user_name: str,
|
||||
status_reason: str,
|
||||
force: bool,
|
||||
@@ -451,10 +474,11 @@ def stop_task(
|
||||
is set to 'stopping' to allow the worker to stop the task and report by itself
|
||||
:return: updated task fields
|
||||
"""
|
||||
|
||||
task = TaskBLL.get_task_with_access(
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=(
|
||||
"status",
|
||||
"project",
|
||||
@@ -464,7 +488,6 @@ def stop_task(
|
||||
"last_update",
|
||||
"execution.queue",
|
||||
),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
def is_run_by_worker(t: Task) -> bool:
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from datetime import datetime
|
||||
from typing import Sequence, Union
|
||||
from typing import Sequence
|
||||
|
||||
import attr
|
||||
import six
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.attrs import typed_attrs
|
||||
|
||||
valid_statuses = get_options(TaskStatus)
|
||||
@@ -158,25 +160,75 @@ def get_possible_status_changes(current_status):
|
||||
return possible
|
||||
|
||||
|
||||
def update_project_time(project_ids: Union[str, Sequence[str]]):
|
||||
if not project_ids:
|
||||
return
|
||||
def get_many_tasks_for_writing(
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
query: Q = None,
|
||||
only: Sequence = None,
|
||||
throw_on_forbidden: bool = True,
|
||||
) -> Sequence[Task]:
|
||||
if only:
|
||||
missing = [f for f in ("company", ) if f not in only]
|
||||
if missing:
|
||||
only = [*only, *missing]
|
||||
|
||||
if isinstance(project_ids, str):
|
||||
project_ids = [project_ids]
|
||||
result = list(
|
||||
Task.get_many(
|
||||
company=company_id,
|
||||
query=query,
|
||||
override_projection=only,
|
||||
allow_public=True,
|
||||
return_dicts=False,
|
||||
)
|
||||
)
|
||||
|
||||
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
|
||||
forbidden_tasks = {task.id for task in result if not task.company}
|
||||
if forbidden_tasks:
|
||||
if throw_on_forbidden:
|
||||
raise errors.forbidden.NoWritePermission(
|
||||
f"cannot modify public task(s), ids={tuple(forbidden_tasks)}"
|
||||
)
|
||||
result = [task for task in result if task.id not in forbidden_tasks]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_task_with_write_access(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
only=None,
|
||||
) -> Task:
|
||||
"""
|
||||
Gets a task that has a required write access
|
||||
:except errors.bad_request.InvalidTaskId: if the task is not found
|
||||
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
|
||||
"""
|
||||
query = dict(id=task_id, company=company_id)
|
||||
|
||||
task = Task.get_for_writing(_only=only, **query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
def get_task_for_update(
|
||||
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
identity: Identity,
|
||||
allow_all_statuses: bool = False,
|
||||
force: bool = False
|
||||
) -> Task:
|
||||
"""
|
||||
Loads only task id and return the task only if it is updatable (status == 'created')
|
||||
"""
|
||||
task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
only=("id", "status"),
|
||||
identity=identity,
|
||||
)
|
||||
|
||||
if allow_all_statuses:
|
||||
return task
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import functools
|
||||
import itertools
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Optional,
|
||||
Callable,
|
||||
@@ -8,11 +9,13 @@ from typing import (
|
||||
Tuple,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from boltons import iterutils
|
||||
|
||||
from apiserver.apierrors import APIError
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.settings import Settings
|
||||
|
||||
|
||||
@@ -77,3 +80,13 @@ def run_batch_operation(
|
||||
}
|
||||
)
|
||||
return results, failures
|
||||
|
||||
|
||||
def update_project_time(project_ids: Union[str, Sequence[str]]):
|
||||
if not project_ids:
|
||||
return
|
||||
|
||||
if isinstance(project_ids, str):
|
||||
project_ids = [project_ids]
|
||||
|
||||
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
|
||||
|
||||
@@ -5,13 +5,13 @@ from typing import Sequence, Set, Optional
|
||||
|
||||
import attr
|
||||
import elasticsearch.helpers
|
||||
from boltons.iterutils import partition
|
||||
from boltons.iterutils import partition, chunked_iter
|
||||
from pyhocon import ConfigTree
|
||||
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import APIError
|
||||
from apiserver.apierrors.errors import bad_request, server_error
|
||||
from apiserver.apimodels.workers import (
|
||||
DEFAULT_TIMEOUT,
|
||||
IdNameEntry,
|
||||
WorkerEntry,
|
||||
StatusReportRequest,
|
||||
@@ -30,12 +30,14 @@ from apiserver.redis_manager import redman
|
||||
from apiserver.tools import safe_get
|
||||
from .stats import WorkerStats
|
||||
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class WorkerBLL:
|
||||
def __init__(self, es=None, redis=None):
|
||||
self.es_client = es or es_factory.connect("workers")
|
||||
self.config = config.get("services.workers", ConfigTree())
|
||||
self.redis = redis or redman.connection("workers")
|
||||
self._stats = WorkerStats(self.es_client)
|
||||
|
||||
@@ -68,7 +70,7 @@ class WorkerBLL:
|
||||
"""
|
||||
key = WorkerBLL._get_worker_key(company_id, user_id, worker)
|
||||
|
||||
timeout = timeout or DEFAULT_TIMEOUT
|
||||
timeout = timeout or int(self.config.get("default_worker_timeout_sec", 10 * 60))
|
||||
queues = queues or []
|
||||
|
||||
with translate_errors_context():
|
||||
@@ -141,8 +143,6 @@ class WorkerBLL:
|
||||
|
||||
try:
|
||||
entry.ip = ip
|
||||
now = datetime.utcnow()
|
||||
entry.last_activity_time = now
|
||||
|
||||
if tags is not None:
|
||||
entry.tags = tags
|
||||
@@ -150,15 +150,16 @@ class WorkerBLL:
|
||||
entry.system_tags = system_tags
|
||||
|
||||
if report.machine_stats:
|
||||
self._log_stats_to_es(
|
||||
self.log_stats_to_es(
|
||||
company_id=company_id,
|
||||
company_name=entry.company.name,
|
||||
worker=entry.key,
|
||||
worker_id=report.worker,
|
||||
timestamp=report.timestamp,
|
||||
task=report.task,
|
||||
machine_stats=report.machine_stats,
|
||||
)
|
||||
|
||||
now = datetime.utcnow()
|
||||
entry.last_activity_time = now
|
||||
entry.queue = report.queue
|
||||
|
||||
if report.queues:
|
||||
@@ -175,6 +176,7 @@ class WorkerBLL:
|
||||
last_worker_report=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by=user_id,
|
||||
)
|
||||
# modify(new=True, ...) returns the modified object
|
||||
task = Task.objects(**query).modify(new=True, **update)
|
||||
@@ -253,18 +255,15 @@ class WorkerBLL:
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerResponseEntry]:
|
||||
|
||||
helpers = list(
|
||||
map(
|
||||
WorkerConversionHelper.from_worker_entry,
|
||||
self.get_all(
|
||||
company_id=company_id,
|
||||
last_seen=last_seen,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
),
|
||||
helpers = [
|
||||
WorkerConversionHelper.from_worker_entry(entry)
|
||||
for entry in self.get_all(
|
||||
company_id=company_id,
|
||||
last_seen=last_seen,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
task_ids = set(filter(None, (helper.task_id for helper in helpers)))
|
||||
all_queues = set(
|
||||
@@ -283,9 +282,7 @@ class WorkerBLL:
|
||||
}
|
||||
},
|
||||
]
|
||||
queues_info = {
|
||||
res["_id"]: res for res in Queue.objects.aggregate(projection)
|
||||
}
|
||||
queues_info = {res["_id"]: res for res in Queue.aggregate(projection)}
|
||||
task_ids = task_ids.union(
|
||||
filter(
|
||||
None,
|
||||
@@ -495,12 +492,15 @@ class WorkerBLL:
|
||||
"""Get worker entries matching the company and user, worker patterns"""
|
||||
|
||||
entries = []
|
||||
for key in self._get_keys(
|
||||
company, user=user, user_tags=user_tags, system_tags=system_tags
|
||||
for keys in chunked_iter(
|
||||
self._get_keys(
|
||||
company, user=user, user_tags=user_tags, system_tags=system_tags
|
||||
),
|
||||
1000,
|
||||
):
|
||||
data = self.redis.get(key)
|
||||
data = self.redis.mget(keys)
|
||||
if data:
|
||||
entries.append(WorkerEntry.from_json(data))
|
||||
entries.extend(WorkerEntry.from_json(d) for d in data if d)
|
||||
|
||||
return entries
|
||||
|
||||
@@ -509,18 +509,17 @@ class WorkerBLL:
|
||||
"""Get the index name suffix for storing current month data"""
|
||||
return datetime.utcnow().strftime("%Y-%m")
|
||||
|
||||
def _log_stats_to_es(
|
||||
def log_stats_to_es(
|
||||
self,
|
||||
company_id: str,
|
||||
company_name: str,
|
||||
worker: str,
|
||||
worker_id: str,
|
||||
timestamp: int,
|
||||
task: str,
|
||||
machine_stats: MachineStats,
|
||||
) -> bool:
|
||||
) -> int:
|
||||
"""
|
||||
Actually writing the worker statistics to Elastic
|
||||
:return: True if successful, False otherwise
|
||||
:return: The amount of logged documents
|
||||
"""
|
||||
es_index = (
|
||||
f"{self._stats.worker_stats_prefix_for_company(company_id)}"
|
||||
@@ -532,8 +531,7 @@ class WorkerBLL:
|
||||
_index=es_index,
|
||||
_source=dict(
|
||||
timestamp=timestamp,
|
||||
worker=worker,
|
||||
company=company_name,
|
||||
worker=worker_id,
|
||||
task=task,
|
||||
category=category,
|
||||
metric=metric,
|
||||
@@ -558,7 +556,7 @@ class WorkerBLL:
|
||||
|
||||
es_res = elasticsearch.helpers.bulk(self.es_client, actions)
|
||||
added, errors = es_res[:2]
|
||||
return (added == len(actions)) and not errors
|
||||
return added
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
|
||||
@@ -215,6 +215,10 @@ class WorkerStats:
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"fixed_interval": f"{interval}s",
|
||||
"extended_bounds": {
|
||||
"min": int(from_date) * 1000,
|
||||
"max": int(to_date) * 1000,
|
||||
}
|
||||
},
|
||||
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
|
||||
}
|
||||
|
||||
@@ -23,4 +23,6 @@ hyperparam_values {
|
||||
max_last_metrics: 2000
|
||||
|
||||
# if set then call to tasks.delete/cleanup does not wait for ES events deletion
|
||||
async_events_delete: false
|
||||
async_events_delete: true
|
||||
# do not use async_delete if the deleted task has amount of events lower than this threshold
|
||||
async_events_delete_threshold: 100000
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
from collections import namedtuple, defaultdict
|
||||
from datetime import datetime
|
||||
from functools import reduce, partial
|
||||
from typing import (
|
||||
Collection,
|
||||
@@ -145,9 +146,10 @@ class GetMixin(PropsMixin):
|
||||
"__$any": Q.OR,
|
||||
"__$or": Q.OR,
|
||||
}
|
||||
default_operator = Q.OR
|
||||
default_global_operator = Q.AND
|
||||
default_context = Q.OR
|
||||
# not_all modifier currently not supported due to the backwards compatibility
|
||||
mongo_modifiers = {
|
||||
# not_all modifier currently not supported due to the backwards compatibility
|
||||
Q.AND: {True: "all", False: "nin"},
|
||||
Q.OR: {True: "in", False: "nin"},
|
||||
}
|
||||
@@ -164,24 +166,22 @@ class GetMixin(PropsMixin):
|
||||
self.allow_empty = False
|
||||
self.global_operator = None
|
||||
self.actions = defaultdict(list)
|
||||
self.explicit_operator = False
|
||||
|
||||
self._support_legacy = legacy
|
||||
current_context = self.default_operator
|
||||
current_context = self.default_context
|
||||
for d in self._get_next_term(data):
|
||||
if d.operator is not None:
|
||||
current_context = d.operator
|
||||
self._support_legacy = False
|
||||
if self.global_operator is None:
|
||||
self.global_operator = d.operator
|
||||
self.explicit_operator = True
|
||||
continue
|
||||
|
||||
if self.global_operator is None:
|
||||
self.global_operator = self.default_operator
|
||||
self.global_operator = self.default_global_operator
|
||||
|
||||
if d.reset:
|
||||
current_context = self.default_operator
|
||||
current_context = self.default_context
|
||||
self._support_legacy = legacy
|
||||
continue
|
||||
|
||||
@@ -194,11 +194,9 @@ class GetMixin(PropsMixin):
|
||||
)
|
||||
|
||||
if self.global_operator is None:
|
||||
self.global_operator = self.default_operator
|
||||
self.global_operator = self.default_global_operator
|
||||
|
||||
def _get_next_term(
|
||||
self, data: Sequence[str]
|
||||
) -> Generator[Term, None, None]:
|
||||
def _get_next_term(self, data: Sequence[str]) -> Generator[Term, None, None]:
|
||||
unary_operator = None
|
||||
for value in data:
|
||||
if value is None:
|
||||
@@ -232,12 +230,18 @@ class GetMixin(PropsMixin):
|
||||
operator = self._operators.get(value)
|
||||
if operator is None:
|
||||
raise FieldsValueError(
|
||||
"Unsupported operator", field=self._field, operator=value,
|
||||
"Unsupported operator",
|
||||
field=self._field,
|
||||
operator=value,
|
||||
)
|
||||
yield self.Term(operator=operator)
|
||||
continue
|
||||
|
||||
if not unary_operator and self._support_legacy and value.startswith("-"):
|
||||
if (
|
||||
not unary_operator
|
||||
and self._support_legacy
|
||||
and value.startswith("-")
|
||||
):
|
||||
value = value[1:]
|
||||
if not value:
|
||||
raise FieldsValueError(
|
||||
@@ -402,12 +406,25 @@ class GetMixin(PropsMixin):
|
||||
parameters = {
|
||||
k: cls._get_fixed_field_value(k, v) for k, v in parameters.items()
|
||||
}
|
||||
filters = parameters.pop("filters", {})
|
||||
if not isinstance(filters, dict):
|
||||
raise FieldsValueError(
|
||||
"invalid value type, string expected",
|
||||
field=filters,
|
||||
value=str(filters),
|
||||
)
|
||||
opts = parameters_options
|
||||
for field in opts.pattern_fields:
|
||||
pattern = parameters.pop(field, None)
|
||||
if pattern:
|
||||
dict_query[field] = RegexWrapper(pattern)
|
||||
|
||||
for field, data in cls._pop_matching_params(
|
||||
patterns=opts.list_fields, parameters=filters
|
||||
).items():
|
||||
query &= cls.get_list_filter_query(field, data)
|
||||
parameters.pop(field, None)
|
||||
|
||||
for field, data in cls._pop_matching_params(
|
||||
patterns=opts.list_fields, parameters=parameters
|
||||
).items():
|
||||
@@ -531,6 +548,149 @@ class GetMixin(PropsMixin):
|
||||
|
||||
return q
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class ListQueryFilter:
|
||||
"""
|
||||
Deserialize filters data and build db_query object that represents it with the corresponding
|
||||
mongo engine operations
|
||||
Each part has include and exclude lists that map to mongoengine operations as following:
|
||||
"any"
|
||||
- include -> 'in'
|
||||
- exclude -> 'not_all'
|
||||
- combined by 'or' operation
|
||||
"all"
|
||||
- include -> 'all'
|
||||
- exclude -> 'nin'
|
||||
- combined by 'and' operation
|
||||
"op" optional parameter for combining "and" and "all" parts. Can be "and" or "or". The default is "and"
|
||||
"""
|
||||
|
||||
_and_op = "and"
|
||||
_or_op = "or"
|
||||
_allowed_op = [_and_op, _or_op]
|
||||
_db_modifiers: Mapping = {
|
||||
(Q.OR, True): "in",
|
||||
(Q.OR, False): "not__all",
|
||||
(Q.AND, True): "all",
|
||||
(Q.AND, False): "nin",
|
||||
}
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class ListFilter:
|
||||
include: Sequence[str] = []
|
||||
exclude: Sequence[str] = []
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: Mapping):
|
||||
if d is None:
|
||||
return None
|
||||
return cls(**d)
|
||||
|
||||
any: ListFilter = attr.ib(converter=ListFilter.from_dict, default=None)
|
||||
all: ListFilter = attr.ib(converter=ListFilter.from_dict, default=None)
|
||||
op: str = attr.ib(default="and")
|
||||
db_query: dict = attr.ib(init=False)
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
@op.validator
|
||||
def op_validator(self, _, value):
|
||||
if value not in self._allowed_op:
|
||||
raise ValueError(
|
||||
f"Invalid list query filter operator: {value}. "
|
||||
f"Should be one of {str(self._allowed_op)}"
|
||||
)
|
||||
|
||||
@property
|
||||
def and_op(self) -> bool:
|
||||
return self.op == self._and_op
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.db_query = {}
|
||||
for op, conditions in ((Q.OR, self.any), (Q.AND, self.all)):
|
||||
if not conditions:
|
||||
continue
|
||||
|
||||
operations = {}
|
||||
for vals, include in (
|
||||
(conditions.include, True),
|
||||
(conditions.exclude, False),
|
||||
):
|
||||
if not vals:
|
||||
continue
|
||||
|
||||
unique = set(vals)
|
||||
if None in unique:
|
||||
# noinspection PyTypeChecker
|
||||
unique.remove(None)
|
||||
if include:
|
||||
operations["size"] = 0
|
||||
else:
|
||||
operations["not__size"] = 0
|
||||
|
||||
if not unique:
|
||||
continue
|
||||
|
||||
operations[self._db_modifiers[(op, include)]] = list(unique)
|
||||
|
||||
self.db_query[op] = operations
|
||||
|
||||
@classmethod
|
||||
def from_data(cls, field, data: Mapping):
|
||||
if not isinstance(data, dict):
|
||||
raise errors.bad_request.ValidationError(
|
||||
"invalid filter for field, dictionary expected",
|
||||
field=field,
|
||||
value=str(data),
|
||||
)
|
||||
|
||||
try:
|
||||
return cls(**data)
|
||||
except Exception as ex:
|
||||
raise errors.bad_request.ValidationError(
|
||||
field=field,
|
||||
value=str(ex),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_list_filter_query(
|
||||
cls, field: str, data: Mapping
|
||||
) -> Union[RegexQ, RegexQCombination]:
|
||||
if not data:
|
||||
return RegexQ()
|
||||
|
||||
filter_ = cls.ListQueryFilter.from_data(field, data)
|
||||
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
queries = []
|
||||
for op, actions in filter_.db_query.items():
|
||||
if not actions:
|
||||
continue
|
||||
|
||||
ops = []
|
||||
for action, vals in actions.items():
|
||||
# cannot just check vals here since 0 is acceptable value
|
||||
if vals is None or vals == []:
|
||||
continue
|
||||
|
||||
ops.append(RegexQ(**{f"{mongoengine_field}__{action}": vals}))
|
||||
|
||||
if not ops:
|
||||
continue
|
||||
|
||||
if len(ops) == 1:
|
||||
queries.extend(ops)
|
||||
continue
|
||||
|
||||
queries.append(RegexQCombination(operation=op, children=ops))
|
||||
|
||||
if not queries:
|
||||
return RegexQ()
|
||||
if len(queries) == 1:
|
||||
return queries[0]
|
||||
|
||||
operation = Q.AND if filter_.and_op else Q.OR
|
||||
return RegexQCombination(operation=operation, children=queries)
|
||||
|
||||
@classmethod
|
||||
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
|
||||
"""
|
||||
@@ -639,7 +799,7 @@ class GetMixin(PropsMixin):
|
||||
|
||||
@classmethod
|
||||
def get_projection(cls, parameters, override_projection=None, **__):
|
||||
""" Extract a projection list from the provided dictionary. Supports an override projection. """
|
||||
"""Extract a projection list from the provided dictionary. Supports an override projection."""
|
||||
if override_projection is not None:
|
||||
return override_projection
|
||||
if not parameters:
|
||||
@@ -653,7 +813,8 @@ class GetMixin(PropsMixin):
|
||||
"""Return include and exclude lists based on passed projection and class definition"""
|
||||
if projection:
|
||||
include, exclude = partition(
|
||||
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
|
||||
projection,
|
||||
key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
|
||||
)
|
||||
else:
|
||||
include, exclude = [], []
|
||||
@@ -900,7 +1061,9 @@ class GetMixin(PropsMixin):
|
||||
projection_fields=projection_fields,
|
||||
)
|
||||
return cls.get_data_with_scroll_support(
|
||||
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
|
||||
query_dict=query_dict,
|
||||
data_getter=data_getter,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
return cls._get_many_no_company(
|
||||
@@ -913,7 +1076,9 @@ class GetMixin(PropsMixin):
|
||||
|
||||
@classmethod
|
||||
def get_many_public(
|
||||
cls, query: Q = None, projection: Collection[str] = None,
|
||||
cls,
|
||||
query: Q = None,
|
||||
projection: Collection[str] = None,
|
||||
):
|
||||
"""
|
||||
Fetch all public documents matching a provided query.
|
||||
@@ -1131,21 +1296,6 @@ class GetMixin(PropsMixin):
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_many_for_writing(cls, company, *args, **kwargs):
|
||||
result = cls.get_many(
|
||||
company=company,
|
||||
*args,
|
||||
**dict(return_dicts=False, **kwargs),
|
||||
allow_public=True,
|
||||
)
|
||||
forbidden_objects = {obj.id for obj in result if not obj.company}
|
||||
if forbidden_objects:
|
||||
object_name = cls.__name__.lower()
|
||||
raise errors.forbidden.NoWritePermission(
|
||||
f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class UpdateMixin(object):
|
||||
@@ -1206,7 +1356,7 @@ class UpdateMixin(object):
|
||||
|
||||
|
||||
class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
""" Provide convenience methods for a subclass of mongoengine.Document """
|
||||
"""Provide convenience methods for a subclass of mongoengine.Document"""
|
||||
|
||||
@classmethod
|
||||
def aggregate(
|
||||
@@ -1234,25 +1384,31 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
def set_public(
|
||||
cls: Type[Document],
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
ids: Sequence[str],
|
||||
invalid_cls: Type[BaseError],
|
||||
enabled: bool = True,
|
||||
):
|
||||
if enabled:
|
||||
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
|
||||
update = dict(set__company_origin=company_id, set__company="")
|
||||
update: dict = dict(set__company_origin=company_id, set__company="")
|
||||
else:
|
||||
items = list(
|
||||
cls.objects(
|
||||
id__in=ids, company__in=(None, ""), company_origin=company_id
|
||||
).only("id")
|
||||
)
|
||||
update = dict(set__company=company_id, unset__company_origin=1)
|
||||
update: dict = dict(set__company=company_id, unset__company_origin=1)
|
||||
|
||||
if len(items) < len(ids):
|
||||
missing = tuple(set(ids).difference(i.id for i in items))
|
||||
raise invalid_cls(ids=missing)
|
||||
|
||||
if hasattr(cls, "last_change"):
|
||||
update["set__last_change"] = datetime.utcnow()
|
||||
if hasattr(cls, "last_changed_by"):
|
||||
update["set__last_changed_by"] = user_id
|
||||
|
||||
return {"updated": cls.objects(id__in=ids).update(**update)}
|
||||
|
||||
|
||||
|
||||
@@ -90,6 +90,8 @@ class Model(AttributedDocument):
|
||||
labels = ModelLabels()
|
||||
ready = BooleanField(required=True)
|
||||
last_update = DateTimeField()
|
||||
last_change = DateTimeField()
|
||||
last_changed_by = StringField()
|
||||
ui_cache = SafeDictField(
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
|
||||
@@ -230,11 +230,12 @@ class Task(AttributedDocument):
|
||||
"project",
|
||||
"parent",
|
||||
"hyperparams.*",
|
||||
"execution.queue",
|
||||
),
|
||||
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
|
||||
datetime_fields=("status_changed", "last_update"),
|
||||
pattern_fields=("name", "comment", "report"),
|
||||
fields=("execution.queue", "runtime.*", "models.input.model"),
|
||||
fields=("runtime.*", "models.input.model"),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
|
||||
19
apiserver/documentation/api_versions.md
Normal file
19
apiserver/documentation/api_versions.md
Normal file
@@ -0,0 +1,19 @@
|
||||
### Supported api versions
|
||||
|
||||
| Release | ApiVersion |
|
||||
|---------|------------|
|
||||
| v1.13 | 2.27 |
|
||||
| v1.12 | 2.26 |
|
||||
| v1.11 | 2.25 |
|
||||
| v1.10 | 2.24 |
|
||||
| v1.9 | 2.23 |
|
||||
| v1.8 | 2.22 |
|
||||
| v1.7 | 2.21 |
|
||||
| v1.6 | 2.20 |
|
||||
| v1.5 | 2.19 |
|
||||
| v1.4 | 2.18 |
|
||||
| v1.3 | 2.17 |
|
||||
| v1.2 | 2.16 |
|
||||
| v1.1 | 2.15 |
|
||||
| v1.0 | 2.14 |
|
||||
| v0.17 | 2.13 |
|
||||
@@ -44,6 +44,7 @@ from apiserver.bll.task.param_utils import (
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.config.info import get_default_company
|
||||
from apiserver.database.model import EntityVisibility, User
|
||||
from apiserver.database.model.auth import Role, User as AuthUser
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import (
|
||||
@@ -54,6 +55,7 @@ from apiserver.database.model.task.task import (
|
||||
TaskModelNames,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
@@ -66,6 +68,7 @@ class PrePopulate:
|
||||
export_tag_prefix = "Exported:"
|
||||
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
|
||||
metadata_filename = "metadata.json"
|
||||
users_filename = "users.json"
|
||||
zip_args = dict(mode="w", compression=ZIP_BZIP2)
|
||||
artifacts_ext = ".artifacts"
|
||||
img_source_regex = re.compile(
|
||||
@@ -78,6 +81,7 @@ class PrePopulate:
|
||||
project_cls: Type[Project]
|
||||
model_cls: Type[Model]
|
||||
user_cls: Type[User]
|
||||
auth_user_cls: Type[AuthUser]
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
@classmethod
|
||||
@@ -90,6 +94,8 @@ class PrePopulate:
|
||||
cls.project_cls = cls._get_entity_type("database.model.project.Project")
|
||||
if not hasattr(cls, "user_cls"):
|
||||
cls.user_cls = cls._get_entity_type("database.model.User")
|
||||
if not hasattr(cls, "auth_user_cls"):
|
||||
cls.auth_user_cls = cls._get_entity_type("database.model.auth.User")
|
||||
|
||||
class JsonLinesWriter:
|
||||
def __init__(self, file: BinaryIO):
|
||||
@@ -205,6 +211,8 @@ class PrePopulate:
|
||||
task_statuses: Sequence[str] = None,
|
||||
tag_exported_entities: bool = False,
|
||||
metadata: Mapping[str, Any] = None,
|
||||
export_events: bool = True,
|
||||
export_users: bool = False,
|
||||
) -> Sequence[str]:
|
||||
cls._init_entity_types()
|
||||
|
||||
@@ -240,11 +248,15 @@ class PrePopulate:
|
||||
with ZipFile(file, **cls.zip_args) as zfile:
|
||||
if metadata:
|
||||
zfile.writestr(cls.metadata_filename, meta_str)
|
||||
if export_users:
|
||||
cls._export_users(zfile)
|
||||
artifacts = cls._export(
|
||||
zfile,
|
||||
entities=entities,
|
||||
hash_=hash_,
|
||||
tag_entities=tag_exported_entities,
|
||||
export_events=export_events,
|
||||
cleanup_users=not export_users,
|
||||
)
|
||||
|
||||
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
|
||||
@@ -265,6 +277,9 @@ class PrePopulate:
|
||||
metadata_hash=metadata_hash,
|
||||
)
|
||||
|
||||
if created_files:
|
||||
print("Created files:\n" + "\n".join(file for file in created_files))
|
||||
|
||||
return created_files
|
||||
|
||||
@classmethod
|
||||
@@ -296,18 +311,26 @@ class PrePopulate:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not user_id:
|
||||
user_id, user_name = "__allegroai__", "Allegro.ai"
|
||||
|
||||
# Make sure we won't end up with an invalid company ID
|
||||
if company_id is None:
|
||||
company_id = ""
|
||||
|
||||
user_mapping = cls._import_users(zfile, company_id)
|
||||
|
||||
if not user_id:
|
||||
user_id, user_name = "__allegroai__", "Allegro.ai"
|
||||
|
||||
existing_user = cls.user_cls.objects(id=user_id).only("id").first()
|
||||
if not existing_user:
|
||||
cls.user_cls(id=user_id, name=user_name, company=company_id).save()
|
||||
|
||||
cls._import(zfile, company_id, user_id, metadata)
|
||||
cls._import(
|
||||
zfile,
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
metadata=metadata,
|
||||
user_mapping=user_mapping,
|
||||
)
|
||||
|
||||
if artifacts_path and os.path.isdir(artifacts_path):
|
||||
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
|
||||
@@ -438,7 +461,7 @@ class PrePopulate:
|
||||
projects: Sequence[str] = None,
|
||||
task_statuses: Sequence[str] = None,
|
||||
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
|
||||
entities = defaultdict(set)
|
||||
entities: Dict[Any] = defaultdict(set)
|
||||
|
||||
if projects:
|
||||
print("Reading projects...")
|
||||
@@ -497,7 +520,6 @@ class PrePopulate:
|
||||
@classmethod
|
||||
def _cleanup_model(cls, model: Model):
|
||||
model.company = ""
|
||||
model.user = ""
|
||||
model.tags = cls._filter_out_export_tags(model.tags)
|
||||
|
||||
@classmethod
|
||||
@@ -505,7 +527,6 @@ class PrePopulate:
|
||||
task.comment = "Auto generated by Allegro.ai"
|
||||
task.status_message = ""
|
||||
task.status_reason = ""
|
||||
task.user = ""
|
||||
task.company = ""
|
||||
task.tags = cls._filter_out_export_tags(task.tags)
|
||||
if task.output:
|
||||
@@ -513,17 +534,32 @@ class PrePopulate:
|
||||
|
||||
@classmethod
|
||||
def _cleanup_project(cls, project: Project):
|
||||
project.user = ""
|
||||
project.company = ""
|
||||
project.tags = cls._filter_out_export_tags(project.tags)
|
||||
|
||||
@classmethod
|
||||
def _cleanup_entity(cls, entity_cls, entity):
|
||||
def _cleanup_auth_user(cls, user: AuthUser):
|
||||
user.company = ""
|
||||
for cred in user.credentials:
|
||||
if getattr(cred, "company", None):
|
||||
cred["company"] = ""
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
def _cleanup_be_user(cls, user: User):
|
||||
user.company = ""
|
||||
user.preferences = None
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
def _cleanup_entity(cls, entity_cls, entity, cleanup_users):
|
||||
if cleanup_users:
|
||||
entity.user = ""
|
||||
if entity_cls == cls.task_cls:
|
||||
cls._cleanup_task(entity)
|
||||
elif entity_cls == cls.model_cls:
|
||||
cls._cleanup_model(entity)
|
||||
elif entity == cls.project_cls:
|
||||
elif entity_cls == cls.project_cls:
|
||||
cls._cleanup_project(entity)
|
||||
|
||||
@classmethod
|
||||
@@ -633,6 +669,38 @@ class PrePopulate:
|
||||
else:
|
||||
print(f"Artifact {full_path} not found")
|
||||
|
||||
@classmethod
|
||||
def _export_users(cls, writer: ZipFile):
|
||||
auth_users = {
|
||||
user.id: cls._cleanup_auth_user(user)
|
||||
for user in cls.auth_user_cls.objects(role__in=(Role.admin, Role.user))
|
||||
}
|
||||
if not auth_users:
|
||||
return
|
||||
|
||||
be_users = {
|
||||
user.id: cls._cleanup_be_user(user)
|
||||
for user in cls.user_cls.objects(id__in=list(auth_users))
|
||||
}
|
||||
if not be_users:
|
||||
return
|
||||
|
||||
auth_users = {uid: data for uid, data in auth_users.items() if uid in be_users}
|
||||
print(f"Writing {len(auth_users)} users into {writer.filename}")
|
||||
data = {}
|
||||
for field, users in (("auth", auth_users), ("backend", be_users)):
|
||||
with BytesIO() as f:
|
||||
with cls.JsonLinesWriter(f) as w:
|
||||
for user in users.values():
|
||||
w.write(user.to_json())
|
||||
data[field] = f.getvalue()
|
||||
|
||||
def get_field_bytes(k: str, v: bytes) -> bytes:
|
||||
return f'"{k}": '.encode("utf-8") + v
|
||||
|
||||
data_str = b",\n".join(get_field_bytes(k, v) for k, v in data.items())
|
||||
writer.writestr(cls.users_filename, b"{\n" + data_str + b"\n}")
|
||||
|
||||
@classmethod
|
||||
def _get_base_filename(cls, cls_: type):
|
||||
name = f"{cls_.__module__}.{cls_.__name__}"
|
||||
@@ -642,7 +710,13 @@ class PrePopulate:
|
||||
|
||||
@classmethod
|
||||
def _export(
|
||||
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
|
||||
cls,
|
||||
writer: ZipFile,
|
||||
entities: dict,
|
||||
hash_,
|
||||
tag_entities: bool = False,
|
||||
export_events: bool = True,
|
||||
cleanup_users: bool = True,
|
||||
) -> Sequence[str]:
|
||||
"""
|
||||
Export the requested experiments, projects and models and return the list of artifact files
|
||||
@@ -656,18 +730,19 @@ class PrePopulate:
|
||||
if not items:
|
||||
continue
|
||||
base_filename = cls._get_base_filename(cls_)
|
||||
for item in items:
|
||||
artifacts.extend(
|
||||
cls._export_entity_related_data(
|
||||
cls_, item, base_filename, writer, hash_
|
||||
if export_events:
|
||||
for item in items:
|
||||
artifacts.extend(
|
||||
cls._export_entity_related_data(
|
||||
cls_, item, base_filename, writer, hash_
|
||||
)
|
||||
)
|
||||
)
|
||||
filename = base_filename + ".json"
|
||||
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
|
||||
with BytesIO() as f:
|
||||
with cls.JsonLinesWriter(f) as w:
|
||||
for item in items:
|
||||
cls._cleanup_entity(cls_, item)
|
||||
cls._cleanup_entity(cls_, item, cleanup_users=cleanup_users)
|
||||
w.write(item.to_json())
|
||||
data = f.getvalue()
|
||||
hash_.update(data)
|
||||
@@ -717,7 +792,10 @@ class PrePopulate:
|
||||
|
||||
@classmethod
|
||||
def _generate_new_ids(
|
||||
cls, reader: ZipFile, entity_files: Sequence, metadata: Mapping[str, Any],
|
||||
cls,
|
||||
reader: ZipFile,
|
||||
entity_files: Sequence,
|
||||
metadata: Mapping[str, Any],
|
||||
) -> Mapping[str, str]:
|
||||
if not metadata or not any(
|
||||
metadata.get(key) for key in ("new_ids", "example_ids", "private_ids")
|
||||
@@ -745,6 +823,68 @@ class PrePopulate:
|
||||
)
|
||||
return ids
|
||||
|
||||
@classmethod
|
||||
def _import_users(cls, reader: ZipFile, company_id: str = "") -> dict:
|
||||
"""
|
||||
Import users to db and return the mapping of old user ids to the new ones
|
||||
If no users were in the users file then the mapping was empty
|
||||
If the user in the file has the same email as one of the existing ones then this user is skipped
|
||||
and its id is mapped to the existing user with the same email
|
||||
If the user with the same id exists in backend or auth db then its creation is skipped
|
||||
"""
|
||||
users_file = first(
|
||||
fi for fi in reader.filelist if fi.orig_filename == cls.users_filename
|
||||
)
|
||||
if not users_file:
|
||||
return {}
|
||||
|
||||
existing_user_ids = set(cls.user_cls.objects().scalar("id")) | set(
|
||||
cls.auth_user_cls.objects().scalar("id")
|
||||
)
|
||||
existing_user_emails = {u.email: u.id for u in cls.auth_user_cls.objects()}
|
||||
user_id_mappings = {}
|
||||
|
||||
with reader.open(users_file) as f:
|
||||
data = json.loads(f.read())
|
||||
|
||||
auth_users = {u["_id"]: u for u in data["auth"]}
|
||||
be_users = {u["_id"]: u for u in data["backend"]}
|
||||
for uid, user in auth_users.items():
|
||||
email = user.get("email")
|
||||
existing_user_id = existing_user_emails.get(email)
|
||||
if existing_user_id:
|
||||
user_id_mappings[uid] = existing_user_id
|
||||
continue
|
||||
|
||||
user_id_mappings[uid] = uid
|
||||
if uid in existing_user_ids:
|
||||
continue
|
||||
|
||||
credentials = user.get("credentials", [])
|
||||
for c in credentials:
|
||||
if c.get("company") == "":
|
||||
c["company"] = company_id
|
||||
|
||||
if hasattr(cls.auth_user_cls, "sec_groups"):
|
||||
user_role = user.get("role", Role.user)
|
||||
if user_role == Role.user:
|
||||
user["sec_groups"] = ["30795571-a470-4717-a80d-e8705fc776bf"]
|
||||
else:
|
||||
user["sec_groups"] = [
|
||||
"c14a3cc6-1144-4896-8ea6-fb186ee19896",
|
||||
"30795571-a470-4717-a80d-e8705fc776bf",
|
||||
"30795571a4704717a80de8705897ytuyg",
|
||||
]
|
||||
|
||||
auth_user = cls.auth_user_cls.from_json(json.dumps(user), created=True)
|
||||
auth_user.company = company_id
|
||||
auth_user.save()
|
||||
be_user = cls.user_cls.from_json(json.dumps(be_users[uid]), created=True)
|
||||
be_user.company = company_id
|
||||
be_user.save()
|
||||
|
||||
return user_id_mappings
|
||||
|
||||
@classmethod
|
||||
def _import(
|
||||
cls,
|
||||
@@ -753,6 +893,7 @@ class PrePopulate:
|
||||
user_id: str = None,
|
||||
metadata: Mapping[str, Any] = None,
|
||||
sort_tasks_by_last_updated: bool = True,
|
||||
user_mapping: Mapping[str, str] = None,
|
||||
):
|
||||
"""
|
||||
Import entities and events from the zip file
|
||||
@@ -763,7 +904,7 @@ class PrePopulate:
|
||||
fi
|
||||
for fi in reader.filelist
|
||||
if not fi.orig_filename.endswith(event_file_ending)
|
||||
and fi.orig_filename != cls.metadata_filename
|
||||
and fi.orig_filename not in (cls.metadata_filename, cls.users_filename)
|
||||
]
|
||||
metadata = metadata or {}
|
||||
old_to_new_ids = cls._generate_new_ids(reader, entity_files, metadata)
|
||||
@@ -773,7 +914,13 @@ class PrePopulate:
|
||||
full_name = splitext(entity_file.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
res = cls._import_entity(
|
||||
f, full_name, company_id, user_id, metadata, old_to_new_ids
|
||||
f,
|
||||
full_name=full_name,
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
metadata=metadata,
|
||||
old_to_new_ids=old_to_new_ids,
|
||||
user_mapping=user_mapping,
|
||||
)
|
||||
if res:
|
||||
tasks = res
|
||||
@@ -794,7 +941,7 @@ class PrePopulate:
|
||||
with reader.open(events_file) as f:
|
||||
full_name = splitext(events_file.orig_filename)[0]
|
||||
print(f"Reading {reader.filename}:{full_name}...")
|
||||
cls._import_events(f, company_id, user_id, task.id)
|
||||
cls._import_events(f, company_id, task.user, task.id)
|
||||
|
||||
@classmethod
|
||||
def _get_entity_type(cls, full_name) -> Type[mongoengine.Document]:
|
||||
@@ -874,7 +1021,7 @@ class PrePopulate:
|
||||
):
|
||||
old_path = old_field.split(".")
|
||||
old_model = nested_get(task_data, old_path)
|
||||
new_models = models.get(type_, [])
|
||||
new_models = [m for m in models.get(type_, []) if m.get("model") is not None]
|
||||
name = TaskModelNames[type_]
|
||||
if old_model and not any(
|
||||
m
|
||||
@@ -908,7 +1055,9 @@ class PrePopulate:
|
||||
user_id: str,
|
||||
metadata: Mapping[str, Any],
|
||||
old_to_new_ids: Mapping[str, str] = None,
|
||||
user_mapping: Mapping[str, str] = None,
|
||||
) -> Optional[Sequence[Task]]:
|
||||
user_mapping = user_mapping or {}
|
||||
cls_ = cls._get_entity_type(full_name)
|
||||
print(f"Writing {cls_.__name__.lower()}s into database")
|
||||
tasks = []
|
||||
@@ -930,7 +1079,7 @@ class PrePopulate:
|
||||
|
||||
doc = cls_.from_json(item, created=True)
|
||||
if hasattr(doc, "user"):
|
||||
doc.user = user_id
|
||||
doc.user = user_mapping.get(doc.user, user_id) if doc.user else user_id
|
||||
if hasattr(doc, "company"):
|
||||
doc.company = company_id
|
||||
if isinstance(doc, cls.project_cls):
|
||||
@@ -960,7 +1109,7 @@ class PrePopulate:
|
||||
return tasks
|
||||
|
||||
@classmethod
|
||||
def _import_events(cls, f: IO[bytes], company_id: str, _, task_id: str):
|
||||
def _import_events(cls, f: IO[bytes], company_id: str, user_id: str, task_id: str):
|
||||
print(f"Writing events for task {task_id} into database")
|
||||
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
|
||||
events = [json.loads(item) for item in events_chunk]
|
||||
@@ -969,5 +1118,8 @@ class PrePopulate:
|
||||
ev["company_id"] = company_id
|
||||
ev["allow_locked"] = True
|
||||
cls.event_bll.add_events(
|
||||
company_id, events=events, worker=""
|
||||
company_id=company_id,
|
||||
identity=Identity(user_id, company=company_id, role=Role.admin),
|
||||
events=events,
|
||||
worker="",
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ elasticsearch==7.17.9
|
||||
fastjsonschema>=2.8
|
||||
flask-compress>=1.4.0
|
||||
flask-cors>=3.0.5
|
||||
flask>=2.3.2
|
||||
flask>=2.3.3
|
||||
furl>=2.0.0
|
||||
google-cloud-storage>=2.8.0
|
||||
gunicorn>=20.1.0
|
||||
@@ -33,4 +33,5 @@ semantic_version>=2.8.3,<3
|
||||
setuptools>=65.5.1
|
||||
six
|
||||
validators>=0.12.4
|
||||
urllib3>=1.26.16
|
||||
urllib3>=1.26.18
|
||||
werkzeug>=3.0.1
|
||||
@@ -1,3 +1,43 @@
|
||||
field_filter {
|
||||
type: object
|
||||
description: Filter on a field that includes combination of 'any' or 'all' included and excluded terms
|
||||
properties {
|
||||
any {
|
||||
type: object
|
||||
description: All the terms in 'any' condition are combined with 'or' operation
|
||||
properties {
|
||||
"include" {
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
exclude {
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
all {
|
||||
type: object
|
||||
description: All the terms in 'all' condition are combined with 'and' operation
|
||||
properties {
|
||||
"include" {
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
exclude {
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
type: string
|
||||
description: The operation between 'any' and 'all' parts of the filter if both are provided
|
||||
default: and
|
||||
enum: [and, or]
|
||||
}
|
||||
}
|
||||
}
|
||||
metadata_item {
|
||||
type: object
|
||||
properties {
|
||||
|
||||
@@ -414,7 +414,7 @@ task {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: [string, null] }
|
||||
additionalProperties { type: string }
|
||||
}
|
||||
models {
|
||||
description: "Task models"
|
||||
|
||||
@@ -11,8 +11,8 @@ _definitions {
|
||||
type: number
|
||||
}
|
||||
type {
|
||||
description: "training_stats_vector"
|
||||
const: "training_stats_scalar"
|
||||
description: "'training_stats_scalar'"
|
||||
type: string
|
||||
}
|
||||
task {
|
||||
description: "Task ID (required)"
|
||||
@@ -46,8 +46,8 @@ _definitions {
|
||||
type: number
|
||||
}
|
||||
type {
|
||||
description: "training_stats_vector"
|
||||
const: "training_stats_vector"
|
||||
description: "'training_stats_vector'"
|
||||
type: string
|
||||
}
|
||||
task {
|
||||
description: "Task ID (required)"
|
||||
@@ -82,8 +82,8 @@ _definitions {
|
||||
type: number
|
||||
}
|
||||
type {
|
||||
description: ""
|
||||
const: "training_debug_image"
|
||||
description: "'training_debug_image'"
|
||||
type: string
|
||||
}
|
||||
task {
|
||||
description: "Task ID (required)"
|
||||
@@ -123,7 +123,7 @@ _definitions {
|
||||
}
|
||||
type {
|
||||
description: "'plot'"
|
||||
const: "plot"
|
||||
type: string
|
||||
}
|
||||
task {
|
||||
description: "Task ID (required)"
|
||||
@@ -221,7 +221,7 @@ _definitions {
|
||||
}
|
||||
type {
|
||||
description: "'log'"
|
||||
const: "log"
|
||||
type: string
|
||||
}
|
||||
task {
|
||||
description: "Task ID (required)"
|
||||
@@ -754,6 +754,42 @@ get_task_metrics{
|
||||
}
|
||||
}
|
||||
}
|
||||
get_multi_task_metrics {
|
||||
"2.28" {
|
||||
description: """Get unique metrics and variants from the events of the specified type.
|
||||
Only events reported for the passed task or model ids are analyzed."""
|
||||
request {
|
||||
type: object
|
||||
required: [ tasks ]
|
||||
properties {
|
||||
tasks {
|
||||
description: task ids to get metrics from
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
model_events {
|
||||
description: If not set or set to false then passed ids are task ids otherwise model ids
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
event_type {
|
||||
"description": Event type. If not specified then metrics are collected from the reported events of all types
|
||||
"$ref": "#/definitions/event_type_enum"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_log {
|
||||
"1.5" {
|
||||
description: "Get all 'log' events for this task"
|
||||
@@ -971,10 +1007,17 @@ get_task_events {
|
||||
}
|
||||
}
|
||||
"2.22": ${get_task_events."2.1"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then get retrieving model events. Otherwise task events
|
||||
default: false
|
||||
request.properties {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then get retrieving model events. Otherwise task events
|
||||
default: false
|
||||
}
|
||||
metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1156,6 +1199,13 @@ get_multi_task_plots {
|
||||
default: true
|
||||
}
|
||||
}
|
||||
"2.28": ${get_multi_task_plots."2.26"} {
|
||||
request.properties.metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
get_vector_metrics_and_variants {
|
||||
"2.1" {
|
||||
@@ -1342,6 +1392,13 @@ multi_task_scalar_metrics_iter_histogram {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.28": ${multi_task_scalar_metrics_iter_histogram."2.22"} {
|
||||
request.properties.metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_single_value_metrics {
|
||||
"2.20" {
|
||||
@@ -1369,6 +1426,13 @@ get_task_single_value_metrics {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.28": ${get_task_single_value_metrics."2.22"} {
|
||||
request.properties.metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_latest_scalar_values {
|
||||
"2.1" {
|
||||
@@ -1470,6 +1534,10 @@ get_scalar_metric_data {
|
||||
type: string
|
||||
description: type of metric
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID of previous call (used for getting more results)"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -1492,7 +1560,7 @@ get_scalar_metric_data {
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID of previous call (used for getting more results)"
|
||||
description: "Scroll ID for getting more results"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,21 +6,12 @@ _default {
|
||||
}
|
||||
|
||||
supported_modes {
|
||||
authorize: false
|
||||
authorize: null
|
||||
"2.9" {
|
||||
description: """ Return supported login modes."""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
state {
|
||||
description: "ASCII base64 encoded application state"
|
||||
type: string
|
||||
}
|
||||
callback_url_prefix {
|
||||
description: "URL prefix used to generate the callback URL for each supported SSO provider"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
additionalProperties: false
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
@@ -59,7 +50,7 @@ supported_modes {
|
||||
description: "SSO authentication providers"
|
||||
type: object
|
||||
additionalProperties {
|
||||
desctiprion: "Provider redirect URL"
|
||||
description: "Provider redirect URL"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
@@ -95,7 +86,7 @@ supported_modes {
|
||||
}
|
||||
|
||||
logout {
|
||||
authorize: false
|
||||
authorize: null
|
||||
allow_roles = [ "*" ]
|
||||
"2.13" {
|
||||
description: """ Logout (including SSO, if used)) """
|
||||
|
||||
@@ -261,6 +261,14 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.27": ${get_all_ex."2.23"} {
|
||||
request.properties {
|
||||
filters {
|
||||
type: object
|
||||
additionalProperties: ${_definitions.field_filter}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
@@ -357,9 +365,6 @@ get_all {
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
}
|
||||
}
|
||||
dependencies {
|
||||
page: [ page_size ]
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
@@ -1086,4 +1091,38 @@ delete_metadata {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update_tags {
|
||||
"2.27" {
|
||||
description: Add or remove tags from multiple models
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
type: array
|
||||
description: IDs of the models to update
|
||||
items {type: string}
|
||||
}
|
||||
add_tags {
|
||||
type: array
|
||||
description: User tags to add
|
||||
items {type: string}
|
||||
}
|
||||
remove_tags {
|
||||
type: array
|
||||
description: User tags to remove
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
type: integer
|
||||
description: The number of updated models
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -203,7 +203,7 @@ get_entities_count {
|
||||
default: false
|
||||
}
|
||||
active_users {
|
||||
descritpion: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
|
||||
description: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
|
||||
type: array
|
||||
items: {type: string}
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ start_pipeline {
|
||||
type: object
|
||||
properties {
|
||||
name: { type: string }
|
||||
value: { type: [string, null] }
|
||||
value: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -79,4 +79,15 @@ start_pipeline {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.28": ${start_pipeline."2.17"} {
|
||||
request.properties.verify_watched_queue {
|
||||
description: If passed then check wheter there are any workers watiching the queue
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
response.properties.queue_watched {
|
||||
description: Returns true if there are workers or autscalers working with the queue
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
|
||||
_definitions {
|
||||
include "_common.conf"
|
||||
multi_field_pattern_data {
|
||||
type: object
|
||||
properties {
|
||||
@@ -569,7 +570,7 @@ get_all_ex {
|
||||
request {
|
||||
properties {
|
||||
active_users {
|
||||
descritpion: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
|
||||
description: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
|
||||
type: array
|
||||
items: {type: string}
|
||||
}
|
||||
@@ -660,6 +661,15 @@ get_all_ex {
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
"2.27": ${get_all_ex."2.25"} {
|
||||
request.properties {
|
||||
filters {
|
||||
type: object
|
||||
additionalProperties: ${_definitions.field_filter}
|
||||
}
|
||||
children_tags_filter: ${_definitions.field_filter}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
@@ -939,6 +949,13 @@ get_unique_metric_variants {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.28": ${get_unique_metric_variants."2.25"} {
|
||||
request.properties.ids {
|
||||
description: IDs of the tasks or models to get metrics from
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_hyperparam_values {
|
||||
"2.13" {
|
||||
@@ -1000,6 +1017,12 @@ get_hyperparam_values {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.27": ${get_hyperparam_values."2.26"} {
|
||||
request.properties.pattern {
|
||||
type: string
|
||||
description: The search pattern regex
|
||||
}
|
||||
}
|
||||
}
|
||||
get_hyper_parameters {
|
||||
"2.9" {
|
||||
@@ -1270,13 +1293,15 @@ get_task_parents {
|
||||
}
|
||||
project {
|
||||
type: object
|
||||
id {
|
||||
description: "The ID of the parent task project"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "The name of the parent task project"
|
||||
type: string
|
||||
properties {
|
||||
id {
|
||||
description: "The ID of the parent task project"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "The name of the parent task project"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,6 +159,14 @@ get_all_ex {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.27": ${get_all_ex."2.21"} {
|
||||
request.properties {
|
||||
filters {
|
||||
type: object
|
||||
additionalProperties: ${_definitions.field_filter}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.4" {
|
||||
|
||||
@@ -578,7 +578,7 @@ get_task_data {
|
||||
single_value_metrics {
|
||||
type: object
|
||||
description: If passed then task single value metrics are returned
|
||||
additonalProperties: false
|
||||
additionalProperties: false
|
||||
}
|
||||
}
|
||||
response.properties.single_value_metrics {
|
||||
@@ -694,9 +694,6 @@ get_all_ex {
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
}
|
||||
}
|
||||
dependencies {
|
||||
page: [ page_size ]
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
@@ -720,6 +717,14 @@ get_all_ex {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.27": ${get_all_ex."2.26"} {
|
||||
request.properties {
|
||||
filters {
|
||||
type: object
|
||||
additionalProperties: ${_definitions.field_filter}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_tags {
|
||||
"2.23" {
|
||||
|
||||
@@ -190,6 +190,14 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.27": ${get_all_ex."2.23"} {
|
||||
request.properties {
|
||||
filters {
|
||||
type: object
|
||||
additionalProperties: ${_definitions.field_filter}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
@@ -289,9 +297,6 @@ get_all {
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
}
|
||||
}
|
||||
dependencies {
|
||||
page: [ page_size ]
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
@@ -481,7 +486,7 @@ clone {
|
||||
new_task_container {
|
||||
description: "The docker container properties for the new task. If not provided then taken from the original task"
|
||||
type: object
|
||||
additionalProperties { type: [string, null] }
|
||||
additionalProperties { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -659,7 +664,7 @@ create {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: [string, null] }
|
||||
additionalProperties { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -748,7 +753,7 @@ validate {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: [string, null] }
|
||||
additionalProperties { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -910,7 +915,7 @@ edit {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: [string, null] }
|
||||
additionalProperties { type: string }
|
||||
}
|
||||
runtime {
|
||||
description: "Task runtime mapping"
|
||||
@@ -2050,3 +2055,37 @@ move {
|
||||
}
|
||||
}
|
||||
}
|
||||
update_tags {
|
||||
"2.27" {
|
||||
description: Add or remove tags from multiple tasks
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
type: array
|
||||
description: IDs of the tasks to update
|
||||
items {type: string}
|
||||
}
|
||||
add_tags {
|
||||
type: array
|
||||
description: User tags to add
|
||||
items {type: string}
|
||||
}
|
||||
remove_tags {
|
||||
type: array
|
||||
description: User tags to remove
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
type: integer
|
||||
description: The number of updated tasks
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -42,7 +42,10 @@ class RequestHandlers:
|
||||
response = redirect(call.result.redirect.url, call.result.redirect.code)
|
||||
else:
|
||||
headers = None
|
||||
disable_cache = False
|
||||
if call.result.filename:
|
||||
# make sure that downloaded files are not cached by the client
|
||||
disable_cache = True
|
||||
try:
|
||||
call.result.filename.encode("ascii")
|
||||
except UnicodeEncodeError:
|
||||
@@ -61,6 +64,9 @@ class RequestHandlers:
|
||||
status=call.result.code,
|
||||
headers=headers,
|
||||
)
|
||||
if disable_cache:
|
||||
response.cache_control.no_store = True
|
||||
response.cache_control.max_age = 0
|
||||
|
||||
if call.result.cookies:
|
||||
for key, value in call.result.cookies.items():
|
||||
|
||||
@@ -30,24 +30,35 @@ def get_auth_func(auth_type):
|
||||
raise errors.unauthorized.BadAuthType()
|
||||
|
||||
|
||||
def authorize_token(jwt_token, *_, **__):
|
||||
def authorize_token(jwt_token, service, action, call):
|
||||
"""Validate token against service/endpoint and requests data (dicts).
|
||||
Returns a parsed token object (auth payload)
|
||||
"""
|
||||
call_info = {"ip": call.real_ip}
|
||||
|
||||
def log_error(msg):
|
||||
info = ", ".join(f"{k}={v}" for k, v in call_info.items())
|
||||
log.error(f"{msg} Call info: {info}")
|
||||
|
||||
try:
|
||||
return Token.from_encoded_token(jwt_token)
|
||||
|
||||
except jwt.exceptions.InvalidKeyError as ex:
|
||||
log_error("Failed parsing token.")
|
||||
raise errors.unauthorized.InvalidToken(
|
||||
"jwt invalid key error", reason=ex.args[0]
|
||||
)
|
||||
except jwt.InvalidTokenError as ex:
|
||||
log_error("Failed parsing token.")
|
||||
raise errors.unauthorized.InvalidToken("invalid jwt token", reason=ex.args[0])
|
||||
except ValueError as ex:
|
||||
log.exception("Failed while processing token: %s" % ex.args[0])
|
||||
log_error(f"Failed while processing token: {str(ex.args[0])}.")
|
||||
raise errors.unauthorized.InvalidToken(
|
||||
"failed processing token", reason=ex.args[0]
|
||||
)
|
||||
except Exception:
|
||||
log_error("Failed processing token.")
|
||||
raise
|
||||
|
||||
|
||||
def authorize_credentials(auth_data, service, action, call):
|
||||
|
||||
@@ -90,7 +90,7 @@ class Token(Payload):
|
||||
return token
|
||||
except Exception as e:
|
||||
raise errors.unauthorized.InvalidToken(
|
||||
"failed parsing token, %s" % e.args[0]
|
||||
"failed parsing token", reason=e.args[0]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -39,7 +39,7 @@ class ServiceRepo(object):
|
||||
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
|
||||
maximum """
|
||||
|
||||
_max_version = PartialVersion("2.26")
|
||||
_max_version = PartialVersion("2.28")
|
||||
""" Maximum version number (the highest min_version value across all endpoints) """
|
||||
|
||||
_endpoint_exp = (
|
||||
|
||||
@@ -17,7 +17,7 @@ log = config.logger(__file__)
|
||||
def validate_data(call: APICall, endpoint: Endpoint):
|
||||
""" Perform all required call/endpoint validation, update call result appropriately """
|
||||
try:
|
||||
# todo: remove vaildate_required_fields once all endpoints have json schema
|
||||
# todo: remove validate_required_fields once all endpoints have json schema
|
||||
validate_required_fields(endpoint, call)
|
||||
|
||||
# set models. models will be validated automatically
|
||||
@@ -50,10 +50,17 @@ def validate_role(endpoint, call):
|
||||
pass
|
||||
|
||||
|
||||
def validate_auth(endpoint, call):
|
||||
""" Validate authorization for this endpoint and call.
|
||||
If authentication has occurred, the call is updated with the authentication results.
|
||||
def validate_auth(endpoint: Endpoint, call: "APICall"):
|
||||
"""
|
||||
Validate authorization for this endpoint and call.
|
||||
If authentication has occurred, the call is updated with the authentication results.
|
||||
For the endpoints with authorize==False the validation is not performed to improve performance
|
||||
For the endpoints with authorize==True the validation should pass otherwise exception will be thrown
|
||||
For the endpoints with authorize==None the validation will be tried, but it does not have to succeed
|
||||
"""
|
||||
if endpoint.authorize is not None and not endpoint.authorize:
|
||||
return
|
||||
|
||||
if not call.authorization:
|
||||
# No auth data. Invalid if we need to authorize and valid otherwise
|
||||
if endpoint.authorize:
|
||||
@@ -63,10 +70,9 @@ def validate_auth(endpoint, call):
|
||||
# prepare arguments for validation
|
||||
service, _, action = endpoint.name.partition(".")
|
||||
|
||||
# If we have auth data, we'll try to validate anyway (just so we'll have auth-based permissions whenever possible,
|
||||
# even if endpoint did not require authorization)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
auth = call.authorization or ""
|
||||
auth = call.authorization
|
||||
auth_type, _, auth_data = auth.partition(" ")
|
||||
authorize_func = get_auth_func(auth_type)
|
||||
call.auth = authorize_func(auth_data, service, action, call)
|
||||
@@ -78,7 +84,7 @@ def validate_auth(endpoint, call):
|
||||
|
||||
def validate_impersonation(endpoint, call):
|
||||
""" Validate impersonation headers and set impersonated identity and authorization data accordingly.
|
||||
:returns True if impersonating, False otherwise
|
||||
:return: True if impersonating, False otherwise
|
||||
"""
|
||||
try:
|
||||
act_as = call.act_as
|
||||
|
||||
@@ -31,6 +31,7 @@ from apiserver.apimodels.events import (
|
||||
GetMetricSamplesRequest,
|
||||
TaskMetric,
|
||||
MultiTaskPlotsRequest,
|
||||
MultiTaskMetricsRequest,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
|
||||
@@ -38,6 +39,7 @@ from apiserver.bll.event.events_iterator import Scroll
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.utils import get_task_with_write_access
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
@@ -71,7 +73,12 @@ def _assert_task_or_model_exists(
|
||||
@endpoint("events.add")
|
||||
def add(call: APICall, company_id, _):
|
||||
data = call.data.copy()
|
||||
added, err_count, err_info = event_bll.add_events(company_id, [data], call.worker)
|
||||
added, err_count, err_info = event_bll.add_events(
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
events=[data],
|
||||
worker=call.worker,
|
||||
)
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
|
||||
|
||||
@@ -82,9 +89,10 @@ def add_batch(call: APICall, company_id, _):
|
||||
raise errors.bad_request.BatchContainsNoItems()
|
||||
|
||||
added, err_count, err_info = event_bll.add_events(
|
||||
company_id,
|
||||
events,
|
||||
call.worker,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
events=events,
|
||||
worker=call.worker,
|
||||
)
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
|
||||
@@ -360,7 +368,7 @@ def get_task_events(_, company_id, request: TaskEventsRequest):
|
||||
total = event_bll.events_iterator.count_task_events(
|
||||
event_type=request.event_type,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task_id=task_id,
|
||||
task_ids=[task_id],
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
|
||||
@@ -515,6 +523,7 @@ def multi_task_scalar_metrics_iter_histogram(
|
||||
),
|
||||
samples=request.samples,
|
||||
key=request.key,
|
||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -542,7 +551,8 @@ def get_task_single_value_metrics(
|
||||
tasks=_get_single_value_metrics_response(
|
||||
companies=companies,
|
||||
value_metrics=event_bll.metrics.get_task_single_value_metrics(
|
||||
companies=companies
|
||||
companies=companies,
|
||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -558,8 +568,8 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
|
||||
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
||||
result = event_bll.get_task_events(
|
||||
list(companies),
|
||||
task_ids,
|
||||
company_id=list(companies),
|
||||
task_id=task_ids,
|
||||
event_type=EventType.metrics_plot,
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
size=10000,
|
||||
@@ -585,10 +595,11 @@ def _get_multitask_plots(
|
||||
companies: TaskCompanies,
|
||||
last_iters: int,
|
||||
last_iters_per_task_metric: bool,
|
||||
metrics: MetricVariants = None,
|
||||
request_metrics: Sequence[ApiMetrics] = None,
|
||||
scroll_id=None,
|
||||
no_scroll=True,
|
||||
) -> Tuple[dict, int, str]:
|
||||
metrics = _get_metric_variants_from_request(request_metrics)
|
||||
task_names = {
|
||||
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
|
||||
}
|
||||
@@ -623,6 +634,7 @@ def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
|
||||
scroll_id=request.scroll_id,
|
||||
no_scroll=request.no_scroll,
|
||||
last_iters_per_task_metric=request.last_iters_per_task_metric,
|
||||
request_metrics=request.metrics,
|
||||
)
|
||||
call.result.data = dict(
|
||||
plots=return_events,
|
||||
@@ -954,12 +966,38 @@ def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
|
||||
}
|
||||
|
||||
|
||||
@endpoint("events.get_multi_task_metrics")
|
||||
def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsRequest):
|
||||
companies = _get_task_or_model_index_companies(
|
||||
company_id, request.tasks, model_events=request.model_events
|
||||
)
|
||||
if not companies:
|
||||
return {"metrics": []}
|
||||
|
||||
metrics = event_bll.metrics.get_multi_task_metrics(
|
||||
companies=companies,
|
||||
event_type=request.event_type
|
||||
)
|
||||
res = [
|
||||
{
|
||||
"metric": m,
|
||||
"variants": sorted(vars_),
|
||||
}
|
||||
for m, vars_ in metrics.items()
|
||||
]
|
||||
call.result.data = {
|
||||
"metrics": sorted(res, key=itemgetter("metric"))
|
||||
}
|
||||
|
||||
|
||||
@endpoint("events.delete_for_task", required_fields=["task"])
|
||||
def delete_for_task(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
allow_locked = call.data.get("allow_locked", False)
|
||||
|
||||
task_bll.assert_exists(company_id, task_id, return_tasks=False)
|
||||
get_task_with_write_access(
|
||||
task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
|
||||
)
|
||||
call.result.data = dict(
|
||||
deleted=event_bll.delete_task_events(
|
||||
company_id, task_id, allow_locked=allow_locked
|
||||
@@ -984,7 +1022,9 @@ def delete_for_model(call: APICall, company_id: str, _):
|
||||
def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
|
||||
task_id = request.task
|
||||
|
||||
task_bll.assert_exists(company_id, task_id, return_tasks=False)
|
||||
get_task_with_write_access(
|
||||
task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
|
||||
)
|
||||
call.result.data = dict(
|
||||
deleted=event_bll.clear_task_log(
|
||||
company_id=company_id,
|
||||
@@ -1085,7 +1125,7 @@ def scalar_metrics_iter_raw(
|
||||
total = event_bll.events_iterator.count_task_events(
|
||||
event_type=EventType.metrics_scalar,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task_id=task_id,
|
||||
task_ids=[task_id],
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
|
||||
|
||||
@@ -22,11 +22,13 @@ from apiserver.apimodels.models import (
|
||||
ModelsDeleteManyRequest,
|
||||
ModelsGetRequest,
|
||||
)
|
||||
from apiserver.apimodels.tasks import UpdateTagsRequest
|
||||
from apiserver.bll.model import ModelBLL, Metadata
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.task_operations import publish_task
|
||||
from apiserver.bll.task.utils import get_task_with_write_access
|
||||
from apiserver.bll.util import run_batch_operation
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import validate_id
|
||||
@@ -45,6 +47,7 @@ from apiserver.database.utils import (
|
||||
filter_fields,
|
||||
)
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
@@ -76,7 +79,9 @@ def get_by_id(call: APICall, company_id, _):
|
||||
)
|
||||
if not models:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
"no such public or company model",
|
||||
id=model_id,
|
||||
company=company_id,
|
||||
)
|
||||
conform_model_data(call, models[0])
|
||||
call.result.data = {"model": models[0]}
|
||||
@@ -102,7 +107,9 @@ def get_by_task_id(call: APICall, company_id, _):
|
||||
).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
"no such public or company model",
|
||||
id=model_id,
|
||||
company=company_id,
|
||||
)
|
||||
model_dict = model.to_proper_dict()
|
||||
conform_model_data(call, model_dict)
|
||||
@@ -128,7 +135,10 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
|
||||
return
|
||||
|
||||
model_ids = {model["id"] for model in models}
|
||||
stats = ModelBLL.get_model_stats(company=company_id, model_ids=list(model_ids),)
|
||||
stats = ModelBLL.get_model_stats(
|
||||
company=company_id,
|
||||
model_ids=list(model_ids),
|
||||
)
|
||||
|
||||
for model in models:
|
||||
model["stats"] = stats.get(model["id"])
|
||||
@@ -183,7 +193,7 @@ create_fields = {
|
||||
"project": Project,
|
||||
"parent": Model,
|
||||
"framework": None,
|
||||
"design": None,
|
||||
"design": dict,
|
||||
"labels": dict,
|
||||
"ready": None,
|
||||
"metadata": list,
|
||||
@@ -212,7 +222,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
|
||||
org_bll.update_tags(
|
||||
company,
|
||||
Tags.Model,
|
||||
project=project,
|
||||
projects=[project],
|
||||
tags=fields.get("tags"),
|
||||
system_tags=fields.get("system_tags"),
|
||||
)
|
||||
@@ -220,7 +230,9 @@ def _update_cached_tags(company: str, project: str, fields: dict):
|
||||
|
||||
def _reset_cached_tags(company: str, projects: Sequence[str]):
|
||||
org_bll.reset_tags(
|
||||
company, Tags.Model, projects=projects,
|
||||
company,
|
||||
Tags.Model,
|
||||
projects=projects,
|
||||
)
|
||||
|
||||
|
||||
@@ -239,13 +251,12 @@ def update_for_task(call: APICall, company_id, _):
|
||||
)
|
||||
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(
|
||||
id=task_id,
|
||||
company=company_id,
|
||||
_only=["models", "execution", "name", "status", "project"],
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
only=("models", "execution", "name", "status", "project"),
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
|
||||
if task.status not in allowed_states:
|
||||
@@ -283,6 +294,8 @@ def update_for_task(call: APICall, company_id, _):
|
||||
id=database.utils.id(),
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by=call.identity.user,
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
project=task.project,
|
||||
@@ -301,6 +314,7 @@ def update_for_task(call: APICall, company_id, _):
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
last_iteration_max=iteration,
|
||||
models__output=[
|
||||
ModelItem(
|
||||
@@ -320,7 +334,6 @@ def update_for_task(call: APICall, company_id, _):
|
||||
response_data_model=CreateModelResponse,
|
||||
)
|
||||
def create(call: APICall, company_id, req_model: CreateModelRequest):
|
||||
|
||||
if req_model.public:
|
||||
company_id = ""
|
||||
|
||||
@@ -331,7 +344,7 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
|
||||
task = req_model.task
|
||||
req_data = req_model.to_struct()
|
||||
if task:
|
||||
validate_task(company_id, req_data)
|
||||
validate_task(company_id, call.identity, req_data)
|
||||
|
||||
fields = filter_fields(Model, req_data)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
@@ -345,6 +358,8 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by=call.identity.user,
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
@@ -359,7 +374,7 @@ def prepare_update_fields(call, company_id, fields: dict):
|
||||
# clear UI cache if URI is provided (model updated)
|
||||
fields["ui_cache"] = fields.pop("ui_cache", {})
|
||||
if "task" in fields:
|
||||
validate_task(company_id, fields)
|
||||
validate_task(company_id, call.identity, fields)
|
||||
|
||||
if "labels" in fields:
|
||||
labels = fields["labels"]
|
||||
@@ -389,8 +404,11 @@ def prepare_update_fields(call, company_id, fields: dict):
|
||||
return fields
|
||||
|
||||
|
||||
def validate_task(company_id, fields: dict):
|
||||
Task.get_for_writing(company=company_id, id=fields["task"], _only=["id"])
|
||||
def validate_task(company_id: str, identity: Identity, fields: dict):
|
||||
task_id = fields["task"]
|
||||
get_task_with_write_access(
|
||||
task_id=task_id, company_id=company_id, identity=identity, only=("id",)
|
||||
)
|
||||
|
||||
|
||||
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
|
||||
@@ -414,12 +432,20 @@ def edit(call: APICall, company_id, _):
|
||||
task_id = model.task or fields.get("task")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
if fields:
|
||||
now = datetime.utcnow()
|
||||
fields.update(
|
||||
last_change=now,
|
||||
last_changed_by=call.identity.user,
|
||||
)
|
||||
if any(uf in fields for uf in last_update_fields):
|
||||
fields.update(last_update=datetime.utcnow())
|
||||
fields.update(last_update=now)
|
||||
|
||||
updated = model.update(upsert=False, **fields)
|
||||
if updated:
|
||||
@@ -445,13 +471,25 @@ def _update_model(call: APICall, company_id, model_id=None):
|
||||
iteration = data.get("iteration")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
||||
now = datetime.utcnow()
|
||||
updated_count, updated_fields = Model.safe_update(
|
||||
company_id,
|
||||
model.id,
|
||||
data,
|
||||
injected_update=dict(
|
||||
last_change=now,
|
||||
last_changed_by=call.identity.user,
|
||||
),
|
||||
)
|
||||
if updated_count:
|
||||
if any(uf in updated_fields for uf in last_update_fields):
|
||||
model.update(upsert=False, last_update=datetime.utcnow())
|
||||
model.update(upsert=False, last_update=now)
|
||||
|
||||
new_project = updated_fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
@@ -480,7 +518,7 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
|
||||
updated, published_task = ModelBLL.publish_model(
|
||||
model_id=request.model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force_publish_task=request.force_publish_task,
|
||||
publish_task_func=publish_task if request.publish_task else None,
|
||||
)
|
||||
@@ -499,7 +537,7 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
|
||||
func=partial(
|
||||
ModelBLL.publish_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force_publish_task=request.force_publish_task,
|
||||
publish_task_func=publish_task if request.publish_task else None,
|
||||
),
|
||||
@@ -573,7 +611,10 @@ def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
|
||||
)
|
||||
def archive_many(call: APICall, company_id, request: BatchRequest):
|
||||
results, failures = run_batch_operation(
|
||||
func=partial(ModelBLL.archive_model, company_id=company_id), ids=request.ids,
|
||||
func=partial(
|
||||
ModelBLL.archive_model, company_id=company_id, user_id=call.identity.user
|
||||
),
|
||||
ids=request.ids,
|
||||
)
|
||||
call.result.data_model = BatchResponse(
|
||||
succeeded=[dict(id=_id, archived=bool(archived)) for _id, archived in results],
|
||||
@@ -588,7 +629,8 @@ def archive_many(call: APICall, company_id, request: BatchRequest):
|
||||
)
|
||||
def unarchive_many(call: APICall, company_id, request: BatchRequest):
|
||||
results, failures = run_batch_operation(
|
||||
func=partial(ModelBLL.unarchive_model, company_id=company_id), ids=request.ids,
|
||||
func=partial(ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user),
|
||||
ids=request.ids,
|
||||
)
|
||||
call.result.data_model = BatchResponse(
|
||||
succeeded=[
|
||||
@@ -601,7 +643,11 @@ def unarchive_many(call: APICall, company_id, request: BatchRequest):
|
||||
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
call.result.data = Model.set_public(
|
||||
company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
ids=request.ids,
|
||||
invalid_cls=InvalidModelId,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -610,7 +656,11 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
call.result.data = Model.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
ids=request.ids,
|
||||
invalid_cls=InvalidModelId,
|
||||
enabled=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -633,30 +683,51 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
}
|
||||
|
||||
|
||||
@endpoint("models.update_tags")
|
||||
def update_tags(_, company_id: str, request: UpdateTagsRequest):
|
||||
return {
|
||||
"updated": org_bll.edit_entity_tags(
|
||||
company_id=company_id,
|
||||
entity_cls=Model,
|
||||
entity_ids=request.ids,
|
||||
add_tags=request.add_tags,
|
||||
remove_tags=request.remove_tags,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("models.add_or_update_metadata", min_version="2.13")
|
||||
def add_or_update_metadata(
|
||||
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||
call: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||
):
|
||||
model_id = request.model
|
||||
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
now = datetime.utcnow()
|
||||
return {
|
||||
"updated": Metadata.edit_metadata(
|
||||
model,
|
||||
items=request.metadata,
|
||||
replace_metadata=request.replace_metadata,
|
||||
last_update=datetime.utcnow(),
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by=call.identity.user,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("models.delete_metadata", min_version="2.13")
|
||||
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||
def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||
model_id = request.model
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id, only_fields=("id",)
|
||||
)
|
||||
now = datetime.utcnow()
|
||||
return {
|
||||
"updated": Metadata.delete_metadata(
|
||||
model, keys=request.keys, last_update=datetime.utcnow()
|
||||
model,
|
||||
keys=request.keys,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
last_changed_by=call.identity.user,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -5,22 +5,24 @@ import attr
|
||||
|
||||
from apiserver.apierrors.errors.bad_request import CannotRemoveAllRuns
|
||||
from apiserver.apimodels.pipelines import (
|
||||
StartPipelineResponse,
|
||||
StartPipelineRequest,
|
||||
DeleteRunsRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.task_operations import enqueue_task, delete_task
|
||||
from apiserver.bll.util import run_batch_operation
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskType
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
org_bll = OrgBLL()
|
||||
project_bll = ProjectBLL()
|
||||
task_bll = TaskBLL()
|
||||
queue_bll = QueueBLL()
|
||||
|
||||
|
||||
def _update_task_name(task: Task):
|
||||
@@ -57,7 +59,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
|
||||
func=partial(
|
||||
delete_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
move_to_trash=False,
|
||||
force=True,
|
||||
return_file_urls=False,
|
||||
@@ -79,9 +81,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
|
||||
call.result.data = dict(succeeded=succeeded, failed=failures)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"pipelines.start_pipeline", response_data_model=StartPipelineResponse,
|
||||
)
|
||||
@endpoint("pipelines.start_pipeline")
|
||||
def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
|
||||
hyperparams = None
|
||||
if request.args:
|
||||
@@ -108,10 +108,19 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
|
||||
queued, res = enqueue_task(
|
||||
task_id=task.id,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
queue_id=request.queue,
|
||||
status_message="Starting pipeline",
|
||||
status_reason="",
|
||||
)
|
||||
extra = {}
|
||||
if request.verify_watched_queue and queued:
|
||||
res_queue = nested_get(res, ("fields", "execution.queue"))
|
||||
if res_queue:
|
||||
extra["queue_watched"] = queue_bll.check_for_workers(company_id, res_queue)
|
||||
|
||||
return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))
|
||||
call.result.data = dict(
|
||||
pipeline=task.id,
|
||||
enqueued=bool(queued),
|
||||
**extra,
|
||||
)
|
||||
|
||||
@@ -108,7 +108,13 @@ def _get_project_stats_filter(
|
||||
if request.include_stats_filter or not request.children_type:
|
||||
return request.include_stats_filter, request.search_hidden
|
||||
|
||||
stats_filter = {"tags": request.children_tags} if request.children_tags else {}
|
||||
if request.children_tags_filter:
|
||||
stats_filter = {"tags": request.children_tags_filter}
|
||||
elif request.children_tags:
|
||||
stats_filter = {"tags": request.children_tags}
|
||||
else:
|
||||
stats_filter = {}
|
||||
|
||||
if request.children_type == ProjectChildrenType.pipeline:
|
||||
return (
|
||||
{
|
||||
@@ -153,6 +159,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
allow_public=allow_public,
|
||||
children_type=request.children_type,
|
||||
children_tags=request.children_tags,
|
||||
children_tags_filter=request.children_tags_filter,
|
||||
)
|
||||
if not ids:
|
||||
return {"projects": []}
|
||||
@@ -373,6 +380,7 @@ def get_unique_metric_variants(
|
||||
company_id,
|
||||
[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
ids=request.ids,
|
||||
model_metrics=request.model_metrics,
|
||||
)
|
||||
|
||||
@@ -452,6 +460,7 @@ def get_hyperparam_values(
|
||||
name=request.name,
|
||||
include_subprojects=request.include_subprojects,
|
||||
allow_public=request.allow_public,
|
||||
pattern=request.pattern,
|
||||
page=request.page,
|
||||
page_size=request.page_size,
|
||||
)
|
||||
@@ -505,7 +514,11 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
|
||||
)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
call.result.data = Project.set_public(
|
||||
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=True
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
ids=request.ids,
|
||||
invalid_cls=InvalidProjectId,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -514,7 +527,11 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
call.result.data = Project.set_public(
|
||||
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
ids=request.ids,
|
||||
invalid_cls=InvalidProjectId,
|
||||
enabled=False,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,9 @@ from apiserver.apimodels.reports import (
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.bll.project.project_bll import reports_project_name, reports_tag
|
||||
from apiserver.bll.task.utils import get_task_with_write_access
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.services.models import conform_model_data
|
||||
from apiserver.services.utils import process_include_subprojects, sort_tags_response
|
||||
from apiserver.bll.organization import OrgBLL
|
||||
@@ -57,15 +59,15 @@ update_fields = {
|
||||
}
|
||||
|
||||
|
||||
def _assert_report(company_id, task_id, only_fields=None, requires_write_access=True):
|
||||
def _assert_report(company_id: str, task_id: str, identity: Identity, only_fields=None):
|
||||
if only_fields and "type" not in only_fields:
|
||||
only_fields += ("type",)
|
||||
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=only_fields,
|
||||
requires_write_access=requires_write_access,
|
||||
)
|
||||
if task.type != TaskType.report:
|
||||
raise errors.bad_request.OperationSupportedOnReportsOnly(id=task_id)
|
||||
@@ -78,6 +80,7 @@ def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
|
||||
task = _assert_report(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
only_fields=("status",),
|
||||
)
|
||||
|
||||
@@ -265,7 +268,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
||||
res["plots"] = _get_multitask_plots(
|
||||
companies=companies,
|
||||
last_iters=request.plots.iters,
|
||||
metrics=_get_metric_variants_from_request(request.plots.metrics),
|
||||
request_metrics=request.plots.metrics,
|
||||
last_iters_per_task_metric=request.plots.last_iters_per_task_metric,
|
||||
)[0]
|
||||
|
||||
@@ -302,6 +305,7 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
|
||||
task = _assert_report(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
identity=call.identity,
|
||||
only_fields=("project",),
|
||||
)
|
||||
user_id = call.identity.user
|
||||
@@ -337,7 +341,9 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
|
||||
response_data_model=UpdateResponse,
|
||||
)
|
||||
def publish(call: APICall, company_id, request: PublishReportRequest):
|
||||
task = _assert_report(company_id=company_id, task_id=request.task)
|
||||
task = _assert_report(
|
||||
company_id=company_id, task_id=request.task, identity=call.identity
|
||||
)
|
||||
updates = ChangeStatusRequest(
|
||||
task=task,
|
||||
new_status=TaskStatus.published,
|
||||
@@ -352,7 +358,9 @@ def publish(call: APICall, company_id, request: PublishReportRequest):
|
||||
|
||||
@endpoint("reports.archive")
|
||||
def archive(call: APICall, company_id, request: ArchiveReportRequest):
|
||||
task = _assert_report(company_id=company_id, task_id=request.task)
|
||||
task = _assert_report(
|
||||
company_id=company_id, task_id=request.task, identity=call.identity
|
||||
)
|
||||
archived = task.update(
|
||||
status_message=request.message,
|
||||
status_reason="",
|
||||
@@ -366,7 +374,9 @@ def archive(call: APICall, company_id, request: ArchiveReportRequest):
|
||||
|
||||
@endpoint("reports.unarchive")
|
||||
def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
|
||||
task = _assert_report(company_id=company_id, task_id=request.task)
|
||||
task = _assert_report(
|
||||
company_id=company_id, task_id=request.task, identity=call.identity
|
||||
)
|
||||
unarchived = task.update(
|
||||
status_message=request.message,
|
||||
status_reason="",
|
||||
@@ -394,6 +404,7 @@ def delete(call: APICall, company_id, request: DeleteReportRequest):
|
||||
task = _assert_report(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
identity=call.identity,
|
||||
only_fields=("project",),
|
||||
)
|
||||
if (
|
||||
|
||||
@@ -67,6 +67,7 @@ from apiserver.apimodels.tasks import (
|
||||
GetAllReq,
|
||||
DequeueRequest,
|
||||
DequeueManyRequest,
|
||||
UpdateTagsRequest,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.model import ModelBLL
|
||||
@@ -76,7 +77,6 @@ from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.task import (
|
||||
TaskBLL,
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
)
|
||||
from apiserver.bll.task.artifacts import (
|
||||
artifacts_prepare_for_save,
|
||||
@@ -100,10 +100,17 @@ from apiserver.bll.task.task_operations import (
|
||||
unarchive_task,
|
||||
move_tasks_to_trash,
|
||||
)
|
||||
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
|
||||
from apiserver.bll.util import run_batch_operation
|
||||
from apiserver.bll.task.utils import (
|
||||
update_task,
|
||||
get_task_for_update,
|
||||
deleted_prefix,
|
||||
get_many_tasks_for_writing,
|
||||
get_task_with_write_access,
|
||||
)
|
||||
from apiserver.bll.util import run_batch_operation, update_project_time
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.output import Output
|
||||
from apiserver.database.model.task.task import (
|
||||
Task,
|
||||
@@ -112,8 +119,13 @@ from apiserver.database.model.task.task import (
|
||||
ModelItem,
|
||||
TaskModelTypes,
|
||||
)
|
||||
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
|
||||
from apiserver.database.utils import (
|
||||
get_fields_attr,
|
||||
parse_from_call,
|
||||
get_options,
|
||||
)
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
@@ -138,14 +150,34 @@ org_bll = OrgBLL()
|
||||
project_bll = ProjectBLL()
|
||||
|
||||
|
||||
def _assert_writable_tasks(
|
||||
company_id: str, identity: Identity, ids: Sequence[str], only=("id",)
|
||||
) -> Sequence[Task]:
|
||||
tasks = get_many_tasks_for_writing(
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
query=Q(id__in=ids),
|
||||
only=only,
|
||||
)
|
||||
missing_ids = set(ids) - {t.id for t in tasks}
|
||||
if missing_ids:
|
||||
raise errors.bad_request.InvalidTaskId(ids=list(missing_ids))
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
def set_task_status_from_call(
|
||||
request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields
|
||||
request: UpdateRequest,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
new_status=None,
|
||||
**set_fields,
|
||||
) -> dict:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task = get_task_with_write_access(
|
||||
request.task,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=("id", "status", "project"),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
status_reason = request.status_reason
|
||||
@@ -157,15 +189,17 @@ def set_task_status_from_call(
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
force=force,
|
||||
user_id=user_id,
|
||||
user_id=identity.user,
|
||||
).execute(**set_fields)
|
||||
|
||||
|
||||
@endpoint("tasks.get_by_id", request_data_model=TaskRequest)
|
||||
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
req_model.task, company_id=company_id, allow_public=True
|
||||
)
|
||||
def get_by_id(call: APICall, company_id, request: TaskRequest):
|
||||
task = TaskBLL.assert_exists(
|
||||
company_id,
|
||||
task_ids=request.task,
|
||||
allow_public=True,
|
||||
)[0]
|
||||
task_dict = task.to_proper_dict()
|
||||
conform_task_data(call, task_dict)
|
||||
call.result.data = {"task": task_dict}
|
||||
@@ -223,7 +257,9 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
call_data = escape_execution_parameters(call.data)
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
company=company_id,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
conform_task_data(call, tasks)
|
||||
@@ -274,7 +310,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**stop_task(
|
||||
task_id=req_model.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
user_name=call.identity.user_name,
|
||||
status_reason=req_model.status_reason,
|
||||
force=req_model.force,
|
||||
@@ -292,7 +328,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
|
||||
func=partial(
|
||||
stop_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
user_name=call.identity.user_name,
|
||||
status_reason=request.status_reason,
|
||||
force=request.force,
|
||||
@@ -315,7 +351,7 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.stopped,
|
||||
completed=datetime.utcnow(),
|
||||
)
|
||||
@@ -332,7 +368,7 @@ def started(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.in_progress,
|
||||
min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value
|
||||
)
|
||||
@@ -349,7 +385,7 @@ def failed(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.failed,
|
||||
)
|
||||
)
|
||||
@@ -363,7 +399,7 @@ def close(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.closed,
|
||||
)
|
||||
)
|
||||
@@ -377,18 +413,19 @@ create_fields = {
|
||||
"error": None,
|
||||
"comment": None,
|
||||
"parent": Task,
|
||||
"project": None,
|
||||
"project": Project,
|
||||
"input": None,
|
||||
"models": None,
|
||||
"container": None,
|
||||
"container": dict,
|
||||
"output_dest": None,
|
||||
"execution": None,
|
||||
"hyperparams": None,
|
||||
"configuration": None,
|
||||
"hyperparams": dict,
|
||||
"configuration": dict,
|
||||
"script": None,
|
||||
"runtime": None,
|
||||
"runtime": dict,
|
||||
}
|
||||
|
||||
|
||||
dict_fields_paths = [("execution", "model_labels"), "container"]
|
||||
|
||||
|
||||
@@ -429,13 +466,17 @@ def conform_task_data(call: APICall, tasks_data: Union[Sequence[dict], dict]):
|
||||
|
||||
for data in tasks_data:
|
||||
params_unprepare_from_saved(
|
||||
fields=data, copy_to_legacy=need_legacy_params,
|
||||
fields=data,
|
||||
copy_to_legacy=need_legacy_params,
|
||||
)
|
||||
artifacts_unprepare_from_saved(fields=data)
|
||||
|
||||
|
||||
def prepare_create_fields(
|
||||
call: APICall, valid_fields=None, output=None, previous_task: Task = None,
|
||||
call: APICall,
|
||||
valid_fields=None,
|
||||
output=None,
|
||||
previous_task: Task = None,
|
||||
):
|
||||
valid_fields = valid_fields if valid_fields is not None else create_fields
|
||||
t_fields = task_fields
|
||||
@@ -493,7 +534,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
|
||||
org_bll.update_tags(
|
||||
company,
|
||||
Tags.Task,
|
||||
project=project,
|
||||
projects=[project],
|
||||
tags=fields.get("tags"),
|
||||
system_tags=fields.get("system_tags"),
|
||||
)
|
||||
@@ -562,11 +603,12 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
task_id = req_model.task
|
||||
|
||||
with translate_errors_context():
|
||||
task = Task.get_for_writing(
|
||||
id=task_id, company=company_id, _only=["id", "project"]
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
only=("id", "project"),
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
|
||||
partial_update_dict, valid_fields = prepare_update_fields(call, call.data)
|
||||
|
||||
@@ -578,7 +620,8 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
id=task_id,
|
||||
partial_update_dict=partial_update_dict,
|
||||
injected_update=dict(
|
||||
last_change=datetime.utcnow(), last_changed_by=call.identity.user,
|
||||
last_change=datetime.utcnow(),
|
||||
last_changed_by=call.identity.user,
|
||||
),
|
||||
)
|
||||
if updated_count:
|
||||
@@ -602,11 +645,11 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
def set_requirements(call: APICall, company_id, req_model: SetRequirementsRequest):
|
||||
requirements = req_model.requirements
|
||||
with translate_errors_context():
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task = get_task_with_write_access(
|
||||
req_model.task,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
only=("status", "script"),
|
||||
requires_write_access=True,
|
||||
)
|
||||
if not task.script:
|
||||
raise errors.bad_request.MissingTaskFields(
|
||||
@@ -632,8 +675,11 @@ def update_batch(call: APICall, company_id, _):
|
||||
items = {i["task"]: i for i in items}
|
||||
tasks = {
|
||||
t.id: t
|
||||
for t in Task.get_many_for_writing(
|
||||
company=company_id, query=Q(id__in=list(items))
|
||||
for t in _assert_writable_tasks(
|
||||
identity=call.identity,
|
||||
company_id=company_id,
|
||||
ids=list(items),
|
||||
only=("id", "project"),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -652,7 +698,8 @@ def update_batch(call: APICall, company_id, _):
|
||||
if not partial_update_dict:
|
||||
continue
|
||||
partial_update_dict.update(
|
||||
last_change=now, last_changed_by=call.identity.user,
|
||||
last_change=now,
|
||||
last_changed_by=call.identity.user,
|
||||
)
|
||||
update_op = UpdateOne(
|
||||
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
|
||||
@@ -686,9 +733,11 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
force = req_model.force
|
||||
|
||||
with translate_errors_context():
|
||||
task = Task.get_for_writing(id=task_id, company=company_id)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
)
|
||||
|
||||
if not force and task.status != TaskStatus.created:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
@@ -752,7 +801,8 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
|
||||
"tasks.get_hyper_params",
|
||||
request_data_model=GetHyperParamsRequest,
|
||||
)
|
||||
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
|
||||
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
|
||||
@@ -767,7 +817,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_params(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
replace_hyperparams=request.replace_hyperparams,
|
||||
@@ -781,7 +831,7 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_params(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
force=request.force,
|
||||
@@ -790,7 +840,8 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
|
||||
"tasks.get_configurations",
|
||||
request_data_model=GetConfigurationsRequest,
|
||||
)
|
||||
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
|
||||
tasks_params = HyperParams.get_configurations(
|
||||
@@ -805,7 +856,8 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_configuration_names", request_data_model=GetConfigurationNamesRequest,
|
||||
"tasks.get_configuration_names",
|
||||
request_data_model=GetConfigurationNamesRequest,
|
||||
)
|
||||
def get_configuration_names(
|
||||
call: APICall, company_id, request: GetConfigurationNamesRequest
|
||||
@@ -826,7 +878,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_configuration(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
replace_configuration=request.replace_configuration,
|
||||
@@ -842,7 +894,7 @@ def delete_configuration(
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_configuration(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
force=request.force,
|
||||
@@ -859,7 +911,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
|
||||
queued, res = enqueue_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
queue_id=request.queue,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
@@ -884,7 +936,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
|
||||
func=partial(
|
||||
enqueue_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
queue_id=request.queue,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
@@ -911,13 +963,14 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.dequeue", response_data_model=DequeueResponse,
|
||||
"tasks.dequeue",
|
||||
response_data_model=DequeueResponse,
|
||||
)
|
||||
def dequeue(call: APICall, company_id, request: DequeueRequest):
|
||||
dequeued, res = dequeue_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
remove_from_all_queues=request.remove_from_all_queues,
|
||||
@@ -927,14 +980,15 @@ def dequeue(call: APICall, company_id, request: DequeueRequest):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.dequeue_many", response_data_model=DequeueManyResponse,
|
||||
"tasks.dequeue_many",
|
||||
response_data_model=DequeueManyResponse,
|
||||
)
|
||||
def dequeue_many(call: APICall, company_id, request: DequeueManyRequest):
|
||||
results, failures = run_batch_operation(
|
||||
func=partial(
|
||||
dequeue_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
remove_from_all_queues=request.remove_from_all_queues,
|
||||
@@ -958,7 +1012,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
||||
dequeued, cleanup_res, updates = reset_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
@@ -986,7 +1040,7 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
|
||||
func=partial(
|
||||
reset_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
@@ -1023,9 +1077,11 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
|
||||
response_data_model=ArchiveResponse,
|
||||
)
|
||||
def archive(call: APICall, company_id, request: ArchiveRequest):
|
||||
tasks = TaskBLL.assert_exists(
|
||||
archived = 0
|
||||
tasks = _assert_writable_tasks(
|
||||
company_id,
|
||||
task_ids=request.tasks,
|
||||
call.identity,
|
||||
ids=request.tasks,
|
||||
only=(
|
||||
"id",
|
||||
"company",
|
||||
@@ -1036,11 +1092,10 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
|
||||
"enqueue_status",
|
||||
),
|
||||
)
|
||||
archived = 0
|
||||
for task in tasks:
|
||||
archived += archive_task(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task=task,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
@@ -1059,7 +1114,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
|
||||
func=partial(
|
||||
archive_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
),
|
||||
@@ -1081,7 +1136,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
|
||||
func=partial(
|
||||
unarchive_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
),
|
||||
@@ -1100,7 +1155,7 @@ def delete(call: APICall, company_id, request: DeleteRequest):
|
||||
deleted, task, cleanup_res = delete_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
move_to_trash=request.move_to_trash,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
@@ -1122,7 +1177,7 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
|
||||
func=partial(
|
||||
delete_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
move_to_trash=request.move_to_trash,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
@@ -1160,7 +1215,7 @@ def publish(call: APICall, company_id, request: PublishRequest):
|
||||
updates = publish_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
publish_model_func=ModelBLL.publish_model if request.publish_model else None,
|
||||
status_reason=request.status_reason,
|
||||
@@ -1179,7 +1234,7 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
|
||||
func=partial(
|
||||
publish_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
publish_model_func=ModelBLL.publish_model
|
||||
if request.publish_model
|
||||
@@ -1207,7 +1262,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
|
||||
**set_task_status_from_call(
|
||||
request,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.completed,
|
||||
completed=datetime.utcnow(),
|
||||
)
|
||||
@@ -1217,7 +1272,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
|
||||
publish_res = publish_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
publish_model_func=ModelBLL.publish_model,
|
||||
status_reason=request.status_reason,
|
||||
@@ -1232,9 +1287,12 @@ def completed(call: APICall, company_id, request: CompletedRequest):
|
||||
|
||||
|
||||
@endpoint("tasks.ping", request_data_model=PingRequest)
|
||||
def ping(_, company_id, request: PingRequest):
|
||||
def ping(call: APICall, company_id, request: PingRequest):
|
||||
TaskBLL.set_last_update(
|
||||
task_ids=[request.task], company_id=company_id, last_update=datetime.utcnow()
|
||||
task_ids=[request.task],
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
last_update=datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
@@ -1249,7 +1307,7 @@ def add_or_update_artifacts(
|
||||
call.result.data = {
|
||||
"updated": Artifacts.add_or_update_artifacts(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
artifacts=request.artifacts,
|
||||
force=True,
|
||||
@@ -1266,7 +1324,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
|
||||
call.result.data = {
|
||||
"deleted": Artifacts.delete_artifacts(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
artifact_ids=request.artifacts,
|
||||
force=True,
|
||||
@@ -1277,14 +1335,22 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
|
||||
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
ids=request.ids,
|
||||
invalid_cls=InvalidTaskId,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
|
||||
@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
ids=request.ids,
|
||||
invalid_cls=InvalidTaskId,
|
||||
enabled=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -1295,6 +1361,7 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
"project or project_name is required"
|
||||
)
|
||||
|
||||
_assert_writable_tasks(company_id, call.identity, request.ids)
|
||||
updated_projects = set(
|
||||
t.project for t in Task.objects(id__in=request.ids).only("project") if t.project
|
||||
)
|
||||
@@ -1314,9 +1381,25 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
return {"project_id": project_id}
|
||||
|
||||
|
||||
@endpoint("tasks.update_tags")
|
||||
def update_tags(call: APICall, company_id: str, request: UpdateTagsRequest):
|
||||
_assert_writable_tasks(company_id, call.identity, request.ids)
|
||||
return {
|
||||
"updated": org_bll.edit_entity_tags(
|
||||
company_id=company_id,
|
||||
entity_cls=Task,
|
||||
entity_ids=request.ids,
|
||||
add_tags=request.add_tags,
|
||||
remove_tags=request.remove_tags,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.add_or_update_model", min_version="2.13")
|
||||
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
|
||||
get_task_for_update(company_id=company_id, task_id=request.task, force=True)
|
||||
def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelRequest):
|
||||
get_task_for_update(
|
||||
company_id=company_id, task_id=request.task, force=True, identity=call.identity
|
||||
)
|
||||
|
||||
models_field = f"models__{request.type}"
|
||||
model = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())
|
||||
@@ -1326,6 +1409,7 @@ def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequ
|
||||
updated = TaskBLL.update_statistics(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
last_iteration_max=request.iteration,
|
||||
**({f"push__{models_field}": model} if not updated else {}),
|
||||
)
|
||||
@@ -1335,7 +1419,9 @@ def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequ
|
||||
|
||||
@endpoint("tasks.delete_models", min_version="2.13")
|
||||
def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
|
||||
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=request.task, force=True, identity=call.identity
|
||||
)
|
||||
|
||||
delete_names = {
|
||||
type_: [m.name for m in request.models if m.type == type_]
|
||||
@@ -1348,6 +1434,8 @@ def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
|
||||
}
|
||||
|
||||
updated = task.update(
|
||||
last_change=datetime.utcnow(), last_changed_by=call.identity.user, **commands,
|
||||
last_change=datetime.utcnow(),
|
||||
last_changed_by=call.identity.user,
|
||||
**commands,
|
||||
)
|
||||
return {"updated": updated}
|
||||
|
||||
91
apiserver/tests/automated/test_get_all_ex_filters.py
Normal file
91
apiserver/tests/automated/test_get_all_ex_filters.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestGetAllExFilters(TestService):
|
||||
def test_no_tags_filter(self):
|
||||
task = self._temp_task(tags=["test"])
|
||||
task_no_tags = self._temp_task()
|
||||
tasks = [task, task_no_tags]
|
||||
|
||||
for cond, op, tags, expected_tasks in (
|
||||
("any", "include", [None], [task_no_tags]),
|
||||
("any", "include", ["test"], [task]),
|
||||
("any", "include", ["test", None], [task, task_no_tags]),
|
||||
("any", "exclude", [None], [task]),
|
||||
("any", "exclude", ["test"], [task_no_tags]),
|
||||
("any", "exclude", ["test", None], [task, task_no_tags]),
|
||||
("all", "include", [None], [task_no_tags]),
|
||||
("all", "include", ["test"], [task]),
|
||||
("all", "include", ["test", None], []),
|
||||
("all", "exclude", [None], [task]),
|
||||
("all", "exclude", ["test"], [task_no_tags]),
|
||||
("all", "exclude", ["test", None], []),
|
||||
):
|
||||
res = self.api.tasks.get_all_ex(
|
||||
id=tasks, filters={"tags": {cond: {op: tags}}}
|
||||
).tasks
|
||||
self.assertEqual({t.id for t in res}, set(expected_tasks))
|
||||
|
||||
def test_list_filters(self):
|
||||
tags = ["a", "b", "c", "d"]
|
||||
tasks = [self._temp_task(tags=tags[:i]) for i in range(len(tags) + 1)]
|
||||
|
||||
# invalid params check
|
||||
with self.api.raises(errors.bad_request.ValidationError):
|
||||
self.api.tasks.get_all_ex(filters={"tags": {"test": ["1"]}})
|
||||
|
||||
# test any condition
|
||||
res = self.api.tasks.get_all_ex(
|
||||
id=tasks, filters={"tags": {"any": {"include": ["a", "b"]}}}
|
||||
).tasks
|
||||
self.assertEqual(set(tasks[1:]), set(t.id for t in res))
|
||||
|
||||
res = self.api.tasks.get_all_ex(
|
||||
id=tasks, filters={"tags": {"any": {"exclude": ["c", "d"]}}}
|
||||
).tasks
|
||||
self.assertEqual(set(tasks[:-1]), set(t.id for t in res))
|
||||
|
||||
res = self.api.tasks.get_all_ex(
|
||||
id=tasks,
|
||||
filters={"tags": {"any": {"include": ["a", "b"], "exclude": ["c", "d"]}}},
|
||||
).tasks
|
||||
self.assertEqual(set(tasks), set(t.id for t in res))
|
||||
|
||||
# test all condition
|
||||
res = self.api.tasks.get_all_ex(
|
||||
id=tasks, filters={"tags": {"all": {"include": ["a", "b"]}}}
|
||||
).tasks
|
||||
self.assertEqual(set(tasks[2:]), set(t.id for t in res))
|
||||
|
||||
res = self.api.tasks.get_all_ex(
|
||||
id=tasks, filters={"tags": {"all": {"exclude": ["c", "d"]}}}
|
||||
).tasks
|
||||
self.assertEqual(set(tasks[:-2]), set(t.id for t in res))
|
||||
|
||||
res = self.api.tasks.get_all_ex(
|
||||
id=tasks,
|
||||
filters={"tags": {"all": {"include": ["a", "b"], "exclude": ["c", "d"]}}},
|
||||
).tasks
|
||||
self.assertEqual([tasks[2]], [t.id for t in res])
|
||||
|
||||
# test combination
|
||||
res = self.api.tasks.get_all_ex(
|
||||
id=tasks,
|
||||
filters={
|
||||
"tags": {"any": {"include": ["a", "b"]}, "all": {"exclude": ["c", "d"]}}
|
||||
},
|
||||
).tasks
|
||||
self.assertEqual(set(tasks[1:-2]), set(t.id for t in res))
|
||||
|
||||
def _temp_task(self, **kwargs):
|
||||
self.update_missing(
|
||||
kwargs,
|
||||
name="test get_all_ex filters",
|
||||
type="training",
|
||||
)
|
||||
return self.create_temp(
|
||||
"tasks",
|
||||
**kwargs,
|
||||
delete_paramse=dict(can_fail=True, force=True),
|
||||
)
|
||||
@@ -37,29 +37,44 @@ class TestPipelines(TestService):
|
||||
|
||||
res = self.api.pipelines.start_pipeline(task=task, queue=queue, args=args)
|
||||
pipeline_task = res.pipeline
|
||||
try:
|
||||
self.assertTrue(res.enqueued)
|
||||
pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0]
|
||||
self.assertTrue(pipeline.name.startswith(task_name))
|
||||
self.assertEqual(pipeline.status, "queued")
|
||||
self.assertEqual(pipeline.project.id, project)
|
||||
self.assertEqual(
|
||||
pipeline.hyperparams.Args,
|
||||
{
|
||||
a["name"]: {
|
||||
"section": "Args",
|
||||
"name": a["name"],
|
||||
"value": a["value"],
|
||||
}
|
||||
for a in args
|
||||
},
|
||||
)
|
||||
finally:
|
||||
self.api.tasks.delete(task=pipeline_task, force=True)
|
||||
self.assertTrue(res.enqueued)
|
||||
pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0]
|
||||
self.assertTrue(pipeline.name.startswith(task_name))
|
||||
self.assertEqual(pipeline.status, "queued")
|
||||
self.assertEqual(pipeline.project.id, project)
|
||||
self.assertEqual(
|
||||
pipeline.hyperparams.Args,
|
||||
{
|
||||
a["name"]: {
|
||||
"section": "Args",
|
||||
"name": a["name"],
|
||||
"value": a["value"],
|
||||
}
|
||||
for a in args
|
||||
},
|
||||
)
|
||||
|
||||
# watched queue
|
||||
queue = self._temp_queue("test pipelines")
|
||||
project, task = self._temp_project_and_task(name="pipelines test1")
|
||||
res = self.api.pipelines.start_pipeline(
|
||||
task=task, queue=queue, verify_watched_queue=True
|
||||
)
|
||||
self.assertEqual(res.queue_watched, False)
|
||||
|
||||
self.api.workers.register(worker="test pipelines", queues=[queue])
|
||||
project, task = self._temp_project_and_task(name="pipelines test2")
|
||||
res = self.api.pipelines.start_pipeline(
|
||||
task=task, queue=queue, verify_watched_queue=True
|
||||
)
|
||||
self.assertEqual(res.queue_watched, True)
|
||||
|
||||
def _temp_project_and_task(self, name) -> Tuple[str, str]:
|
||||
project = self.create_temp(
|
||||
"projects", name=name, description="test", delete_params=dict(force=True),
|
||||
"projects",
|
||||
name=name,
|
||||
description="test",
|
||||
delete_params=dict(force=True, delete_contents=True),
|
||||
)
|
||||
|
||||
return (
|
||||
@@ -72,3 +87,6 @@ class TestPipelines(TestService):
|
||||
system_tags=["pipeline"],
|
||||
),
|
||||
)
|
||||
|
||||
def _temp_queue(self, queue_name, **kwargs):
|
||||
return self.create_temp("queues", name=queue_name, **kwargs)
|
||||
|
||||
@@ -92,6 +92,29 @@ class TestProjectTags(TestService):
|
||||
self.assertFalse(tag1 in data.tags)
|
||||
self.assertTrue(tag2 in data.tags)
|
||||
|
||||
def test_tags_api(self):
|
||||
p = self.create_temp("projects", name="Test tags api", description="test")
|
||||
|
||||
# task
|
||||
initial_tags = ["Task tag"]
|
||||
task = self.new_task(project=p, tags=initial_tags)
|
||||
data = self.api.projects.get_task_tags(projects=[p])
|
||||
self.assertEqual(data.tags, initial_tags)
|
||||
new_tags = ["New task tag"]
|
||||
self.api.tasks.update_tags(ids=[task], add_tags=new_tags, remove_tags=initial_tags)
|
||||
data = self.api.projects.get_task_tags(projects=[p])
|
||||
self.assertEqual(data.tags, new_tags)
|
||||
|
||||
# model
|
||||
initial_tags = ["Model tag"]
|
||||
model = self.new_model(project=p, tags=initial_tags)
|
||||
data = self.api.projects.get_model_tags(projects=[p])
|
||||
self.assertEqual(data.tags, initial_tags)
|
||||
new_tags = ["New model tag"]
|
||||
self.api.models.update_tags(ids=[model], add_tags=new_tags)
|
||||
data = self.api.projects.get_model_tags(projects=[p])
|
||||
self.assertEqual(set(data.tags), set([*new_tags, *initial_tags]))
|
||||
|
||||
def new_task(self, **kwargs):
|
||||
self.update_missing(
|
||||
kwargs, type="testing", name="test project tags"
|
||||
|
||||
@@ -64,6 +64,20 @@ class TestSubProjects(TestService):
|
||||
self.assertEqual(p.basename, "project2")
|
||||
self.assertEqual(p.stats.active.total_tasks, 2)
|
||||
|
||||
# new filter
|
||||
projects = self.api.projects.get_all_ex(
|
||||
parent=[test_root],
|
||||
children_type="report",
|
||||
children_tags_filter={"any": {"include": ["test1", "test2"]}},
|
||||
shallow_search=True,
|
||||
include_stats=True,
|
||||
check_own_contents=True,
|
||||
).projects
|
||||
self.assertEqual(len(projects), 1)
|
||||
p = projects[0]
|
||||
self.assertEqual(p.basename, "project2")
|
||||
self.assertEqual(p.stats.active.total_tasks, 2)
|
||||
|
||||
projects = self.api.projects.get_all_ex(
|
||||
parent=[test_root],
|
||||
children_type="report",
|
||||
@@ -77,6 +91,20 @@ class TestSubProjects(TestService):
|
||||
self.assertEqual(p.basename, "project2")
|
||||
self.assertEqual(p.stats.active.total_tasks, 1)
|
||||
|
||||
# new filter
|
||||
projects = self.api.projects.get_all_ex(
|
||||
parent=[test_root],
|
||||
children_type="report",
|
||||
children_tags_filter={"all": {"include": ["test1", "test2"]}},
|
||||
shallow_search=True,
|
||||
include_stats=True,
|
||||
check_own_contents=True,
|
||||
).projects
|
||||
self.assertEqual(len(projects), 1)
|
||||
p = projects[0]
|
||||
self.assertEqual(p.basename, "project2")
|
||||
self.assertEqual(p.stats.active.total_tasks, 1)
|
||||
|
||||
projects = self.api.projects.get_all_ex(
|
||||
parent=[test_root],
|
||||
children_type="report",
|
||||
@@ -102,6 +130,20 @@ class TestSubProjects(TestService):
|
||||
for p in projects:
|
||||
self.assertEqual(p.stats.active.total_tasks, 1)
|
||||
|
||||
# new filter
|
||||
projects = self.api.projects.get_all_ex(
|
||||
parent=[test_root],
|
||||
children_type="report",
|
||||
children_tags_filter={"all": {"exclude": ["test1", "test2"]}},
|
||||
shallow_search=True,
|
||||
include_stats=True,
|
||||
check_own_contents=True,
|
||||
).projects
|
||||
self.assertEqual(len(projects), 1)
|
||||
p = projects[0]
|
||||
self.assertEqual(p.basename, "project1")
|
||||
self.assertEqual(p.stats.active.total_tasks, 1)
|
||||
|
||||
def test_query_children(self):
|
||||
test_root_name = "TestQueryChildren"
|
||||
test_root = self._temp_project(name=test_root_name)
|
||||
|
||||
@@ -16,10 +16,18 @@ class TestTaskEvents(TestService):
|
||||
delete_params = dict(can_fail=True, force=True)
|
||||
default_task_name = "test task events"
|
||||
|
||||
def _temp_task(self, name=default_task_name):
|
||||
task_input = dict(name=name, type="training",)
|
||||
def _temp_project(self, name=default_task_name):
|
||||
return self.create_temp(
|
||||
"tasks", delete_paramse=self.delete_params, **task_input
|
||||
"projects",
|
||||
name=name,
|
||||
description="test",
|
||||
delete_params=self.delete_params,
|
||||
)
|
||||
|
||||
def _temp_task(self, name=default_task_name, **kwargs):
|
||||
self.update_missing(kwargs, name=name, type="training")
|
||||
return self.create_temp(
|
||||
"tasks", delete_paramse=self.delete_params, **kwargs
|
||||
)
|
||||
|
||||
def _temp_model(self, name="test model events", **kwargs):
|
||||
@@ -62,6 +70,26 @@ class TestTaskEvents(TestService):
|
||||
self._assert_task_metrics(tasks, "log")
|
||||
self._assert_task_metrics(tasks, "training_stats_scalar")
|
||||
|
||||
self._assert_multitask_metrics(
|
||||
tasks=list(tasks), metrics=["Metric1", "Metric2", "Metric3"]
|
||||
)
|
||||
self._assert_multitask_metrics(
|
||||
tasks=list(tasks),
|
||||
event_type="training_debug_image",
|
||||
metrics=["Metric1", "Metric2", "Metric3"],
|
||||
)
|
||||
self._assert_multitask_metrics(tasks=list(tasks), event_type="plot", metrics=[])
|
||||
|
||||
def _assert_multitask_metrics(
|
||||
self, tasks: Sequence[str], metrics: Sequence[str], event_type: str = None
|
||||
):
|
||||
res = self.api.events.get_multi_task_metrics(
|
||||
tasks=tasks,
|
||||
**({"event_type": event_type} if event_type else {}),
|
||||
).metrics
|
||||
self.assertEqual([r.metric for r in res], metrics)
|
||||
self.assertTrue(all(r.variants == ["Test variant"] for r in res))
|
||||
|
||||
def _assert_task_metrics(self, tasks: dict, event_type: str):
|
||||
res = self.api.events.get_task_metrics(tasks=list(tasks), event_type=event_type)
|
||||
for task, metrics in tasks.items():
|
||||
@@ -122,6 +150,15 @@ class TestTaskEvents(TestService):
|
||||
self.assertEqual(value.metric, metric)
|
||||
self.assertEqual(value.variant, variant)
|
||||
self.assertEqual(value.value, 0)
|
||||
# test metrics parameter
|
||||
res = self.api.events.get_task_single_value_metrics(
|
||||
tasks=[task], metrics=[{"metric": metric, "variants": [variant]}]
|
||||
).tasks
|
||||
self.assertEqual(len(res), 1)
|
||||
res = self.api.events.get_task_single_value_metrics(
|
||||
tasks=[task], metrics=[{"metric": "non_existing", "variants": [variant]}]
|
||||
).tasks
|
||||
self.assertEqual(len(res), 0)
|
||||
|
||||
# update is working
|
||||
task_data = self.api.tasks.get_by_id(task=task).task
|
||||
@@ -248,6 +285,15 @@ class TestTaskEvents(TestService):
|
||||
|
||||
self._assert_log_events(task=task, expected_total=1)
|
||||
|
||||
metrics = self.api.events.get_multi_task_metrics(
|
||||
tasks=[model],
|
||||
event_type="training_stats_scalar",
|
||||
model_events=True,
|
||||
).metrics
|
||||
self.assertEqual([m.metric for m in metrics], [f"Metric{i}" for i in range(5)])
|
||||
variants = [f"Variant{i}" for i in range(5)]
|
||||
self.assertTrue(all(m.variants == variants for m in metrics))
|
||||
|
||||
def test_error_events(self):
|
||||
task = self._temp_task()
|
||||
events = [
|
||||
@@ -340,6 +386,30 @@ class TestTaskEvents(TestService):
|
||||
else (None, None)
|
||||
)
|
||||
|
||||
def test_task_unique_metric_variants(self):
|
||||
project = self._temp_project()
|
||||
task1 = self._temp_task(project=project)
|
||||
task2 = self._temp_task(project=project)
|
||||
metric1 = "Metric1"
|
||||
metric2 = "Metric2"
|
||||
events = [
|
||||
{
|
||||
**self._create_task_event("training_stats_scalar", task, 0),
|
||||
"metric": metric,
|
||||
"variant": "Variant",
|
||||
"value": 10,
|
||||
}
|
||||
for task, metric in ((task1, metric1), (task2, metric2))
|
||||
]
|
||||
self.send_batch(events)
|
||||
|
||||
metrics = self.api.projects.get_unique_metric_variants(project=project).metrics
|
||||
self.assertEqual({m.metric for m in metrics}, {metric1, metric2})
|
||||
metrics = self.api.projects.get_unique_metric_variants(ids=[task1, task2]).metrics
|
||||
self.assertEqual({m.metric for m in metrics}, {metric1, metric2})
|
||||
metrics = self.api.projects.get_unique_metric_variants(ids=[task1]).metrics
|
||||
self.assertEqual([m.metric for m in metrics], [metric1])
|
||||
|
||||
def test_task_metric_value_intervals_keys(self):
|
||||
metric = "Metric1"
|
||||
variant = "Variant1"
|
||||
@@ -395,6 +465,25 @@ class TestTaskEvents(TestService):
|
||||
iterations=iter_count,
|
||||
)
|
||||
|
||||
# test metrics
|
||||
data = self.api.events.multi_task_scalar_metrics_iter_histogram(
|
||||
tasks=tasks,
|
||||
metrics=[
|
||||
{
|
||||
"metric": f"Metric{m_idx}",
|
||||
"variants": [f"Variant{v_idx}" for v_idx in range(4)],
|
||||
}
|
||||
for m_idx in range(2)
|
||||
],
|
||||
)
|
||||
self._assert_metrics_and_variants(
|
||||
data.metrics,
|
||||
metrics=2,
|
||||
variants=4,
|
||||
tasks=tasks,
|
||||
iterations=iter_count,
|
||||
)
|
||||
|
||||
def _assert_metrics_and_variants(
|
||||
self, data: dict, metrics: int, variants: int, tasks: Sequence, iterations: int
|
||||
):
|
||||
@@ -515,6 +604,13 @@ class TestTaskEvents(TestService):
|
||||
self.assertEqual(plots.C.CX[task1]["3"]["plots"][0]["plot_str"], "Task1_3_C_CX")
|
||||
self.assertEqual(plots.C.CX[task2]["1"]["plots"][0]["plot_str"], "Task2_1_C_CX")
|
||||
|
||||
# test metrics
|
||||
plots = self.api.events.get_multi_task_plots(
|
||||
tasks=[task1, task2], metrics=[{"metric": "A"}]
|
||||
).plots
|
||||
self.assertEqual(len(plots), 1)
|
||||
self.assertEqual(len(plots.A), 2)
|
||||
|
||||
def test_task_plots(self):
|
||||
task = self._temp_task()
|
||||
event = self._create_task_event("plot", task, 0)
|
||||
|
||||
@@ -12,17 +12,17 @@ class TestTasksFiltering(TestService):
|
||||
param1 = ("Se$tion1", "pa__ram1", True)
|
||||
param2 = ("Section2", "param2", False)
|
||||
task_count = 5
|
||||
for p in (param1, param2):
|
||||
for (section, name, unique_value) in (param1, param2):
|
||||
for idx in range(task_count):
|
||||
t = self.temp_task(project=project)
|
||||
self.api.tasks.edit_hyper_params(
|
||||
task=t,
|
||||
hyperparams=[
|
||||
{
|
||||
"section": p[0],
|
||||
"name": p[1],
|
||||
"section": section,
|
||||
"name": name,
|
||||
"type": "str",
|
||||
"value": str(idx) if p[2] else "Constant",
|
||||
"value": str(idx) if unique_value else "Constant",
|
||||
}
|
||||
],
|
||||
)
|
||||
@@ -42,6 +42,18 @@ class TestTasksFiltering(TestService):
|
||||
self.assertEqual(res.total, 0)
|
||||
self.assertEqual(res["values"], [])
|
||||
|
||||
# search pattern
|
||||
res = self.api.projects.get_hyperparam_values(
|
||||
projects=[project], section=param1[0], name=param1[1], pattern="^1"
|
||||
)
|
||||
self.assertEqual(res.total, 1)
|
||||
self.assertEqual(res["values"], ["1"])
|
||||
|
||||
res = self.api.projects.get_hyperparam_values(
|
||||
projects=[project], section=param1[0], name=param1[1], pattern="11"
|
||||
)
|
||||
self.assertEqual(res.total, 0)
|
||||
|
||||
def test_datetime_queries(self):
|
||||
tasks = [self.temp_task() for _ in range(5)]
|
||||
now = datetime.utcnow()
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import time
|
||||
from uuid import uuid4
|
||||
from datetime import timedelta
|
||||
from operator import attrgetter
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.apierrors.errors import bad_request
|
||||
from apiserver.tests.automated import TestService, utc_now_tz_aware
|
||||
from apiserver.tests.automated import TestService
|
||||
from apiserver.config_repo import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
@@ -72,7 +70,9 @@ class TestWorkersService(TestService):
|
||||
self.assertEqual(worker.tags, [tag])
|
||||
self.assertEqual(worker.system_tags, [system_tag])
|
||||
|
||||
workers = self.api.workers.get_all(tags=[tag], system_tags=[f"-{system_tag}"]).workers
|
||||
workers = self.api.workers.get_all(
|
||||
tags=[tag], system_tags=[f"-{system_tag}"]
|
||||
).workers
|
||||
self.assertFalse(workers)
|
||||
|
||||
def test_filters(self):
|
||||
@@ -83,7 +83,7 @@ class TestWorkersService(TestService):
|
||||
self._check_exists(test_worker, False, tags=["test"])
|
||||
self._check_exists(test_worker, False, tags=["-application"])
|
||||
|
||||
def _simulate_workers(self) -> Sequence[str]:
|
||||
def _simulate_workers(self, start: int) -> Sequence[str]:
|
||||
"""
|
||||
Two workers writing the same metrics. One for 4 seconds. Another one for 2
|
||||
The first worker reports a task
|
||||
@@ -105,25 +105,23 @@ class TestWorkersService(TestService):
|
||||
(workers[0],),
|
||||
(workers[0],),
|
||||
]
|
||||
timestamp = start * 1000
|
||||
for ws, stats in zip(workers_activity, workers_stats):
|
||||
for w, s in zip(ws, stats):
|
||||
data = dict(
|
||||
worker=w,
|
||||
timestamp=int(utc_now_tz_aware().timestamp() * 1000),
|
||||
timestamp=timestamp,
|
||||
machine_stats=s,
|
||||
)
|
||||
if w == workers[0]:
|
||||
data["task"] = task_id
|
||||
self.api.workers.status_report(**data)
|
||||
time.sleep(1)
|
||||
timestamp += 1000
|
||||
|
||||
res = self.api.workers.get_all(last_seen=100)
|
||||
return [w.key for w in res.workers]
|
||||
return workers
|
||||
|
||||
def _create_running_task(self, task_name):
|
||||
task_input = dict(
|
||||
name=task_name, type="testing"
|
||||
)
|
||||
task_input = dict(name=task_name, type="testing")
|
||||
|
||||
task_id = self.create_temp("tasks", **task_input)
|
||||
|
||||
@@ -131,7 +129,8 @@ class TestWorkersService(TestService):
|
||||
return task_id
|
||||
|
||||
def test_get_keys(self):
|
||||
workers = self._simulate_workers()
|
||||
workers = self._simulate_workers(int(time.time()))
|
||||
time.sleep(5) # give to es time to refresh
|
||||
res = self.api.workers.get_metric_keys(worker_ids=workers)
|
||||
assert {"cpu", "memory"} == set(c.name for c in res["categories"])
|
||||
assert all(
|
||||
@@ -147,11 +146,12 @@ class TestWorkersService(TestService):
|
||||
self.api.workers.get_metric_keys(worker_ids=["Non existing worker id"])
|
||||
|
||||
def test_get_stats(self):
|
||||
workers = self._simulate_workers()
|
||||
|
||||
to_date = utc_now_tz_aware() + timedelta(seconds=10)
|
||||
from_date = to_date - timedelta(days=1)
|
||||
start = int(time.time())
|
||||
workers = self._simulate_workers(start)
|
||||
|
||||
time.sleep(5) # give to ES time to refresh
|
||||
from_date = start
|
||||
to_date = start + 10
|
||||
# no variants
|
||||
res = self.api.workers.get_stats(
|
||||
items=[
|
||||
@@ -160,68 +160,58 @@ class TestWorkersService(TestService):
|
||||
dict(key="memory_used", aggregation="max"),
|
||||
dict(key="memory_used", aggregation="min"),
|
||||
],
|
||||
from_date=from_date.timestamp(),
|
||||
to_date=to_date.timestamp(),
|
||||
from_date=from_date,
|
||||
to_date=to_date,
|
||||
# split_by_variant=True,
|
||||
interval=1,
|
||||
worker_ids=workers,
|
||||
)
|
||||
self.assertWorkersInStats(workers, res["workers"])
|
||||
assert all(
|
||||
{"cpu_usage", "memory_used"}
|
||||
== set(map(attrgetter("metric"), worker["metrics"]))
|
||||
for worker in res["workers"]
|
||||
)
|
||||
|
||||
def _check_dates_and_stats(metric, stats, worker_id) -> bool:
|
||||
return set(
|
||||
map(attrgetter("aggregation"), metric["stats"])
|
||||
) == stats and len(metric["dates"]) == (4 if worker_id == workers[0] else 2)
|
||||
|
||||
assert all(
|
||||
_check_dates_and_stats(metric, metric_stats, worker["worker"])
|
||||
for worker in res["workers"]
|
||||
for metric, metric_stats in zip(
|
||||
worker["metrics"], ({"avg", "max"}, {"max", "min"})
|
||||
self.assertWorkersInStats(workers, res.workers)
|
||||
for worker in res.workers:
|
||||
self.assertEqual(
|
||||
set(metric.metric for metric in worker.metrics),
|
||||
{"cpu_usage", "memory_used"},
|
||||
)
|
||||
)
|
||||
|
||||
for worker in res.workers:
|
||||
for metric, metric_stats in zip(
|
||||
worker.metrics, ({"avg", "max"}, {"max", "min"})
|
||||
):
|
||||
self.assertEqual(
|
||||
set(stat.aggregation for stat in metric.stats), metric_stats
|
||||
)
|
||||
self.assertEqual(len(metric.dates), 4 if worker.worker == workers[0] else 2)
|
||||
|
||||
# split by variants
|
||||
res = self.api.workers.get_stats(
|
||||
items=[dict(key="cpu_usage", aggregation="avg")],
|
||||
from_date=from_date.timestamp(),
|
||||
to_date=to_date.timestamp(),
|
||||
from_date=from_date,
|
||||
to_date=to_date,
|
||||
split_by_variant=True,
|
||||
interval=1,
|
||||
worker_ids=workers,
|
||||
)
|
||||
self.assertWorkersInStats(workers, res["workers"])
|
||||
self.assertWorkersInStats(workers, res.workers)
|
||||
|
||||
def _check_metric_and_variants(worker):
|
||||
return (
|
||||
all(
|
||||
_check_dates_and_stats(metric, {"avg"}, worker["worker"])
|
||||
for metric in worker["metrics"]
|
||||
for worker in res.workers:
|
||||
for metric in worker.metrics:
|
||||
self.assertEqual(
|
||||
set(metric.variant for metric in worker.metrics),
|
||||
{"0", "1"} if worker.worker == workers[0] else {"0"},
|
||||
)
|
||||
and set(map(attrgetter("variant"), worker["metrics"])) == {"0", "1"}
|
||||
if worker["worker"] == workers[0]
|
||||
else {"0"}
|
||||
)
|
||||
|
||||
assert all(_check_metric_and_variants(worker) for worker in res["workers"])
|
||||
self.assertEqual(len(metric.dates), 4 if worker.worker == workers[0] else 2)
|
||||
|
||||
res = self.api.workers.get_stats(
|
||||
items=[dict(key="cpu_usage", aggregation="avg")],
|
||||
from_date=from_date.timestamp(),
|
||||
to_date=to_date.timestamp(),
|
||||
from_date=from_date,
|
||||
to_date=to_date,
|
||||
interval=1,
|
||||
worker_ids=["Non existing worker id"],
|
||||
)
|
||||
assert not res["workers"]
|
||||
assert not res.workers
|
||||
|
||||
@staticmethod
|
||||
def assertWorkersInStats(workers: Sequence[str], stats: dict):
|
||||
assert set(workers) == set(map(attrgetter("worker"), stats))
|
||||
def assertWorkersInStats(self, workers: Sequence[str], stats: Sequence):
|
||||
self.assertEqual(set(workers), set(item.worker for item in stats))
|
||||
|
||||
def test_get_activity_report(self):
|
||||
# test no workers data
|
||||
@@ -232,28 +222,19 @@ class TestWorkersService(TestService):
|
||||
# to_timestamp=to_timestamp.timestamp(),
|
||||
# interval=20,
|
||||
# )
|
||||
start = int(time.time())
|
||||
self._simulate_workers(int(time.time()))
|
||||
|
||||
self._simulate_workers()
|
||||
|
||||
to_date = utc_now_tz_aware() + timedelta(seconds=10)
|
||||
from_date = to_date - timedelta(minutes=1)
|
||||
|
||||
time.sleep(5) # give to es time to refresh
|
||||
# no variants
|
||||
res = self.api.workers.get_activity_report(
|
||||
from_date=from_date.timestamp(), to_date=to_date.timestamp(), interval=20
|
||||
from_date=start, to_date=start + 10, interval=2
|
||||
)
|
||||
self.assertWorkerSeries(res["total"], 2)
|
||||
self.assertWorkerSeries(res["active"], 1)
|
||||
self.assertTotalSeriesGreaterThenActive(res["total"], res["active"])
|
||||
self.assertWorkerSeries(res["total"], 2, 5)
|
||||
self.assertWorkerSeries(res["active"], 1, 5)
|
||||
|
||||
@staticmethod
|
||||
def assertTotalSeriesGreaterThenActive(total_data: dict, active_data: dict):
|
||||
assert total_data["dates"][-1] == active_data["dates"][-1]
|
||||
assert total_data["counts"][-1] > active_data["counts"][-1]
|
||||
|
||||
@staticmethod
|
||||
def assertWorkerSeries(series_data: dict, min_count: int):
|
||||
assert len(series_data["dates"]) == len(series_data["counts"])
|
||||
# check the last 20s aggregation
|
||||
# there may be more workers that we created since we are not filtering by test workers here
|
||||
assert series_data["counts"][-1] >= min_count
|
||||
def assertWorkerSeries(self, series_data: dict, count: int, size: int):
|
||||
self.assertEqual(len(series_data["dates"]), size)
|
||||
self.assertEqual(len(series_data["counts"]), size)
|
||||
self.assertTrue(any(c == count for c in series_data["counts"]))
|
||||
self.assertTrue(all(c <= count for c in series_data["counts"]))
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from boltons.dictutils import OneToOne
|
||||
from mongoengine.queryset.transform import MATCH_OPERATORS
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
|
||||
|
||||
class ParameterKeyEscaper:
|
||||
"""
|
||||
@@ -15,8 +17,13 @@ class ParameterKeyEscaper:
|
||||
@classmethod
|
||||
def escape(cls, value: str):
|
||||
""" Quote a parameter key """
|
||||
value = value.strip().replace("%", "%%")
|
||||
value = value.strip()
|
||||
if not value:
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"Empty key is not allowed"
|
||||
)
|
||||
|
||||
value = value.replace("%", "%%")
|
||||
for c, r in cls._mapping.items():
|
||||
value = value.replace(c, r)
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "1.12.0"
|
||||
__version__ = "1.14.0"
|
||||
|
||||
@@ -155,6 +155,7 @@ services:
|
||||
- http://fileserver:8081
|
||||
volumes:
|
||||
- c:/opt/clearml/logs:/var/log/clearml
|
||||
- c:/opt/clearml/config:/opt/clearml/config
|
||||
|
||||
networks:
|
||||
backend:
|
||||
|
||||
@@ -154,6 +154,7 @@ services:
|
||||
- http://fileserver:8081
|
||||
volumes:
|
||||
- /opt/clearml/logs:/var/log/clearml
|
||||
- /opt/clearml/config:/opt/clearml/config
|
||||
|
||||
agent-services:
|
||||
networks:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
boltons>=19.1.0
|
||||
flask-compress>=1.4.0
|
||||
flask-cors>=3.0.5
|
||||
flask>=2.3.2
|
||||
flask>=2.3.3
|
||||
gunicorn>=20.1.0
|
||||
pyhocon>=0.3.35
|
||||
setuptools>=65.5.1
|
||||
urllib3>=1.26.18
|
||||
werkzeug>=3.0.1
|
||||
Reference in New Issue
Block a user