Compare commits

21 Commits

Author SHA1 Message Date
allegroai
4684fd5b74 Version bump to v1.13.0 2023-11-17 09:49:26 +02:00
allegroai
e08123fcc0 Fix workers.activity_report should return 0s for the time when no workers reported 2023-11-17 09:49:18 +02:00
allegroai
e713e876eb Upgrade urllib3 requirement 2023-11-17 09:48:19 +02:00
allegroai
c2cc788319 Added supported API versions doc 2023-11-17 09:47:44 +02:00
allegroai
da8315d0db Allow queries on the list of execution queue ids in tasks.get_all/get_all_ex 2023-11-17 09:47:19 +02:00
allegroai
4ac6f88278 Optimize Workers retrieval
Store worker statistics under worker id and not internal redis key
Fix unit tests
2023-11-17 09:46:44 +02:00
allegroai
a7865ccbec Turn on async task events deletion in case there are more than 100_000 events 2023-11-17 09:45:55 +02:00
allegroai
ec14f327c6 Optimize endpoints that do not require authorization by not validating JWT token 2023-11-17 09:45:22 +02:00
allegroai
a03b24d6b6 Add log info on caller IP if token validation fails 2023-11-17 09:43:59 +02:00
allegroai
cb71ef8e47 Fix missing scroll_id in events.get_scalar_metric_data 2023-11-17 09:43:11 +02:00
allegroai
8678fbc995 Fix properly unset Task fields on task reset 2023-11-17 09:42:39 +02:00
allegroai
58df8f201a Update API to 2.27 2023-11-17 09:40:34 +02:00
allegroai
f4bf16c156 Fix schema for swagger compatibility 2023-11-17 09:39:52 +02:00
allegroai
942f996237 Fix async_delete cannot be configured using configuration files 2023-11-17 09:39:22 +02:00
allegroai
c1e7f8f9c1 Optimize deletion of projects with many tasks 2023-11-17 09:38:32 +02:00
allegroai
274c487b37 Add update_tags api to tasks and models 2023-11-17 09:37:25 +02:00
allegroai
cc0129a800 Add filters parameter for passing user defined list filters for all get_all_ex apis 2023-11-17 09:36:58 +02:00
allegroai
388dd1b01f Fix regression issue with archive tasks display 2023-11-17 09:35:55 +02:00
allegroai
d62ecb5e6e Add last_change and last_change_by DB Model 2023-11-17 09:35:22 +02:00
allegroai
6d507616b3 Add pattern parameter to projects.get_hyperparam_values 2023-11-17 09:34:13 +02:00
allegroai
d0252a6dd9 Make sure that hyperparam/configuration/metadata keys that contain only whitespace are rejected 2023-11-17 09:32:22 +02:00
56 changed files with 1274 additions and 434 deletions

View File

@@ -72,6 +72,7 @@ class MultiProjectPagedRequest(MultiProjectRequest):
class ProjectHyperparamValuesRequest(MultiProjectPagedRequest):
section = fields.StringField(required=True)
name = fields.StringField(required=True)
pattern = fields.StringField()
class ProjectModelMetadataValuesRequest(MultiProjectPagedRequest):
@@ -98,3 +99,4 @@ class ProjectsGetRequest(models.Base):
allow_public = fields.BoolField(default=True)
children_type = ActualEnumField(ProjectChildrenType)
children_tags = fields.ListField(str)
children_tags_filter = DictField()

View File

@@ -333,3 +333,8 @@ class DeleteModelsRequest(TaskRequest):
class GetAllReq(models.Base):
allow_public = BoolField(default=True)
search_hidden = BoolField(default=False)
class UpdateTagsRequest(BatchRequest):
add_tags = ListField([str])
remove_tags = ListField([str])

View File

@@ -13,8 +13,7 @@ from jsonmodels.fields import (
from jsonmodels.models import Base
from apiserver.apimodels import ListField, EnumField, JsonSerializableMixin
DEFAULT_TIMEOUT = 10 * 60
from apiserver.config_repo import config
class WorkerRequest(Base):
@@ -24,7 +23,10 @@ class WorkerRequest(Base):
class RegisterRequest(WorkerRequest):
timeout = IntField(default=0) # registration timeout in seconds (if not specified, default is 10min)
timeout = IntField(
default=int(config.get("services.workers.default_worker_timeout_sec", 10 * 60))
)
""" registration timeout in seconds (default is 10min) """
queues = ListField(six.string_types) # list of queues this worker listens to

View File

@@ -5,7 +5,6 @@ import zlib
from collections import defaultdict
from contextlib import closing
from datetime import datetime
from operator import attrgetter
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
import elasticsearch
@@ -24,6 +23,7 @@ from apiserver.bll.event.event_common import (
get_metric_variants_condition,
uncompress_plot,
get_max_metric_and_variant_counts,
PlotFields,
)
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIterator
@@ -47,21 +47,15 @@ from apiserver.utilities.dicts import nested_get
from apiserver.utilities.json import loads
# noinspection PyTypeChecker
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
EVENT_TYPES: Set[str] = set(et.value for et in EventType if et != EventType.all)
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
MAX_LONG = 2 ** 63 - 1
MIN_LONG = -(2 ** 63)
MAX_LONG = 2**63 - 1
MIN_LONG = -(2**63)
log = config.logger(__file__)
class PlotFields:
valid_plot = "valid_plot"
plot_len = "plot_len"
plot_str = "plot_str"
plot_data = "plot_data"
source_urls = "source_urls"
async_task_events_delete = config.get("services.tasks.async_events_delete", False)
async_delete_threshold = config.get("services.tasks.async_events_delete_threshold", 100_000)
class EventBLL(object):
@@ -130,7 +124,11 @@ class EventBLL(object):
return res
def add_events(
self, company_id, events, worker
self,
company_id: str,
user_id: str,
events: Sequence[dict],
worker: str,
) -> Tuple[int, int, dict]:
task_ids = {}
model_ids = {}
@@ -268,11 +266,13 @@ class EventBLL(object):
else:
used_task_ids.add(task_or_model_id)
self._update_last_metric_events_for_task(
last_events=task_last_events[task_or_model_id], event=event,
last_events=task_last_events[task_or_model_id],
event=event,
)
if event_type == EventType.metrics_scalar.value:
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_or_model_id], event=event,
last_events=task_last_scalar_events[task_or_model_id],
event=event,
)
actions.append(es_action)
@@ -311,20 +311,23 @@ class EventBLL(object):
else:
errors_per_type["Error when indexing events batch"] += 1
now = datetime.utcnow()
for model_id in used_model_ids:
ModelBLL.update_statistics(
company_id=company_id,
user_id=user_id,
model_id=model_id,
last_update=now,
last_iteration_max=task_iteration.get(model_id),
last_scalar_events=task_last_scalar_events.get(model_id),
)
remaining_tasks = set()
now = datetime.utcnow()
for task_id in used_task_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those whose events were successful
updated = self._update_task(
company_id=company_id,
user_id=user_id,
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
@@ -336,7 +339,12 @@ class EventBLL(object):
continue
if remaining_tasks:
TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
TaskBLL.set_last_update(
remaining_tasks,
company_id=company_id,
user_id=user_id,
last_update=now,
)
# this is for backwards compatibility with streaming bulk throwing exception on those
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
@@ -466,9 +474,10 @@ class EventBLL(object):
def _update_task(
self,
company_id,
task_id,
now,
company_id: str,
user_id: str,
task_id: str,
now: datetime,
iter_max=None,
last_scalar_events=None,
last_events=None,
@@ -484,8 +493,9 @@ class EventBLL(object):
return False
return TaskBLL.update_statistics(
task_id,
company_id,
task_id=task_id,
company_id=company_id,
user_id=user_id,
last_update=now,
last_iteration_max=iter_max,
last_scalar_events=last_scalar_events,
@@ -569,7 +579,8 @@ class EventBLL(object):
query = {"bool": {"must": must}}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
max_variants = int(max_variants // last_iterations_per_plot)
@@ -636,9 +647,11 @@ class EventBLL(object):
return events, total_events, next_scroll_id
def get_debug_image_urls(
self, company_id: str, task_id: str, after_key: dict = None
self, company_id: str, task_ids: Sequence[str], after_key: dict = None
) -> Tuple[Sequence[str], Optional[dict]]:
if check_empty_data(self.es, company_id, EventType.metrics_image):
if not task_ids or check_empty_data(
self.es, company_id, EventType.metrics_image
):
return [], None
es_req = {
@@ -654,7 +667,10 @@ class EventBLL(object):
},
"query": {
"bool": {
"must": [{"term": {"task": task_id}}, {"exists": {"field": "url"}}]
"must": [
{"terms": {"task": task_ids}},
{"exists": {"field": "url"}},
]
}
},
}
@@ -672,9 +688,13 @@ class EventBLL(object):
return [bucket["key"]["url"] for bucket in res["buckets"]], res.get("after_key")
def get_plot_image_urls(
self, company_id: str, task_id: str, scroll_id: Optional[str]
self, company_id: str, task_ids: Sequence[str], scroll_id: Optional[str]
) -> Tuple[Sequence[dict], Optional[str]]:
if scroll_id == self.empty_scroll:
if (
scroll_id == self.empty_scroll
or not task_ids
or check_empty_data(self.es, company_id, EventType.metrics_plot)
):
return [], None
if scroll_id:
@@ -689,7 +709,7 @@ class EventBLL(object):
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"terms": {"task": task_ids}},
{"exists": {"field": PlotFields.source_urls}},
]
}
@@ -825,7 +845,8 @@ class EventBLL(object):
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
es_req = {
"size": 0,
@@ -879,7 +900,8 @@ class EventBLL(object):
}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
query=query,
**search_args,
)
max_variants = int(max_variants // 2)
es_req = {
@@ -1023,9 +1045,9 @@ class EventBLL(object):
"order": {"_key": "desc"},
}
}
}
},
}
}
},
}
},
"query": {"bool": {"must": must}},
@@ -1091,7 +1113,10 @@ class EventBLL(object):
with translate_errors_context():
es_res = search_company_events(
self.es, company_id=company_ids, event_type=event_type, body=es_req,
self.es,
company_id=company_ids,
event_type=event_type,
body=es_req,
)
if "aggregations" not in es_res:
@@ -1142,18 +1167,26 @@ class EventBLL(object):
return {"refresh": True}
def delete_task_events(
self, company_id, task_id, allow_locked=False, model=False, async_delete=False,
):
def delete_task_events(self, company_id, task_id, allow_locked=False, model=False):
if model:
self._validate_model_state(
company_id=company_id, model_id=task_id, allow_locked=allow_locked,
company_id=company_id,
model_id=task_id,
allow_locked=allow_locked,
)
else:
self._validate_task_state(
company_id=company_id, task_id=task_id, allow_locked=allow_locked
)
async_delete = async_task_events_delete
if async_delete:
total = self.events_iterator.count_task_events(
event_type=EventType.all,
company_id=company_id,
task_ids=[task_id],
)
if total <= async_delete_threshold:
async_delete = False
es_req = {"query": {"term": {"task": task_id}}}
with translate_errors_context():
es_res = delete_company_events(
@@ -1211,14 +1244,23 @@ class EventBLL(object):
return es_res.get("deleted", 0)
def delete_multi_task_events(
self, company_id: str, task_ids: Sequence[str], async_delete=False
self, company_id: str, task_ids: Sequence[str], model=False
):
"""
Delete mutliple task events. No check is done for tasks write access
Delete multiple task events. No check is done for tasks write access
so it should be checked by the calling code
"""
deleted = 0
with translate_errors_context():
async_delete = async_task_events_delete
if async_delete and len(task_ids) < 100:
total = self.events_iterator.count_task_events(
event_type=EventType.all,
company_id=company_id,
task_ids=task_ids,
)
if total <= async_delete_threshold:
async_delete = False
for tasks in chunked_iter(task_ids, 100):
es_req = {"query": {"terms": {"task": tasks}}}
es_res = delete_company_events(
@@ -1232,7 +1274,7 @@ class EventBLL(object):
deleted += es_res.get("deleted", 0)
if not async_delete:
return es_res.get("deleted", 0)
return deleted
def clear_scroll(self, scroll_id: str):
if scroll_id == self.empty_scroll:

View File

@@ -64,13 +64,13 @@ class EventsIterator:
self,
event_type: EventType,
company_id: str,
task_id: str,
task_ids: Sequence[str],
metric_variants: MetricVariants = None,
) -> int:
if check_empty_data(self.es, company_id, event_type):
return 0
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
query, _ = self._get_initial_query_and_must(task_ids, metric_variants)
es_req = {
"query": query,
}
@@ -100,7 +100,7 @@ class EventsIterator:
For the last key-field value all the events are brought (even if the resulting size exceeds batch_size)
so that events with this value will not be lost between the calls.
"""
query, must = self._get_initial_query_and_must(task_id, metric_variants)
query, must = self._get_initial_query_and_must([task_id], metric_variants)
# retrieve the next batch of events
es_req = {
@@ -158,14 +158,14 @@ class EventsIterator:
@staticmethod
def _get_initial_query_and_must(
task_id: str, metric_variants: MetricVariants = None
task_ids: Sequence[str], metric_variants: MetricVariants = None
) -> Tuple[dict, list]:
if not metric_variants:
must = [{"term": {"task": task_id}}]
query = {"term": {"task": task_id}}
query = {"terms": {"task": task_ids}}
must = [query]
else:
must = [
{"term": {"task": task_id}},
{"terms": {"task": task_ids}},
get_metric_variants_condition(metric_variants),
]
query = {"bool": {"must": must}}

View File

@@ -80,7 +80,14 @@ class ModelBLL:
id=model.task, data=task_publish_res
)
updated = model.update(upsert=False, ready=True, last_update=datetime.utcnow())
now = datetime.utcnow()
updated = model.update(
upsert=False,
ready=True,
last_update=now,
last_change=now,
last_changed_by=user_id,
)
return updated, published_task
@classmethod
@@ -125,6 +132,7 @@ class ModelBLL:
"models.output.$[elem].model": deleted_model_id,
"output.error": f"model deleted on {now.isoformat()}",
"last_change": now,
"last_changed_by": user_id,
},
},
array_filters=[{"elem.model": model_id}],
@@ -132,7 +140,9 @@ class ModelBLL:
)
else:
task.update(
pull__models__output__model=model_id, set__last_change=now
pull__models__output__model=model_id,
set__last_change=now,
set__last_changed_by=user_id,
)
delete_external_artifacts = delete_external_artifacts and config.get(
@@ -167,25 +177,29 @@ class ModelBLL:
return del_count, model
@classmethod
def archive_model(cls, model_id: str, company_id: str):
def archive_model(cls, model_id: str, company_id: str, user_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
now = datetime.utcnow()
archived = Model.objects(company=company_id, id=model_id).update(
add_to_set__system_tags=EntityVisibility.archived.value,
last_update=datetime.utcnow(),
last_change=now,
last_changed_by=user_id,
)
return archived
@classmethod
def unarchive_model(cls, model_id: str, company_id: str):
def unarchive_model(cls, model_id: str, company_id: str, user_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
now = datetime.utcnow()
unarchived = Model.objects(company=company_id, id=model_id).update(
pull__system_tags=EntityVisibility.archived.value,
last_update=datetime.utcnow(),
last_change=now,
last_changed_by=user_id,
)
return unarchived
@@ -218,11 +232,18 @@ class ModelBLL:
@staticmethod
def update_statistics(
company_id: str,
user_id: str,
model_id: str,
last_update: datetime = None,
last_iteration_max: int = None,
last_scalar_events: Dict[str, Dict[str, dict]] = None,
):
updates = {"last_update": datetime.utcnow()}
last_update = last_update or datetime.utcnow()
updates = {
"last_update": datetime.utcnow(),
"last_change": last_update,
"last_changed_by": user_id,
}
if last_iteration_max is not None:
updates.update(max__last_iteration=last_iteration_max)

View File

@@ -1,8 +1,11 @@
from collections import defaultdict
from enum import Enum
from typing import Sequence, Dict
from typing import Sequence, Dict, Type
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model.model import AttributedDocument
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
@@ -22,6 +25,51 @@ class OrgBLL:
self._task_tags = _TagsCache(Task, self.redis)
self._model_tags = _TagsCache(Model, self.redis)
def edit_entity_tags(
self,
company_id,
entity_cls: Type[AttributedDocument],
entity_ids: Sequence[str],
add_tags: Sequence[str],
remove_tags: Sequence[str],
) -> int:
if entity_cls not in (Task, Model):
raise errors.bad_request.ValidationError(
"Tags editing can be called on tasks or models only"
)
if not entity_ids:
raise errors.bad_request.ValidationError(
"No entity ids provided for editing tags"
)
if not (add_tags or remove_tags):
raise errors.bad_request.ValidationError(
"Either add tags or remove tags should be provided"
)
updated = 0
if add_tags:
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
add_to_set__tags=add_tags
)
if remove_tags:
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
pull_all__tags=remove_tags
)
if not updated:
return 0
projects = entity_cls.objects(company=company_id, id__in=entity_ids).distinct(
"project"
)
update_project_time(project_ids=projects)
self.update_tags(
company_id,
entity=Tags.Task if entity_cls is Task else Tags.Model,
projects=projects,
tags=add_tags or remove_tags
)
return updated
def get_tags(
self,
company_id: str,
@@ -50,10 +98,10 @@ class OrgBLL:
return ret
def update_tags(
self, company_id: str, entity: Tags, project: str, tags=None, system_tags=None,
self, company_id: str, entity: Tags, projects: Sequence[str], tags=None, system_tags=None,
):
tags_cache = self._get_tags_cache_for_entity(entity)
tags_cache.update_tags(company_id, project, tags, system_tags)
tags_cache.update_tags(company_id, projects, tags, system_tags)
def reset_tags(self, company_id: str, entity: Tags, projects: Sequence[str]):
tags_cache = self._get_tags_cache_for_entity(entity)

View File

@@ -107,7 +107,7 @@ class _TagsCache:
return ret
def update_tags(self, company_id: str, project: str, tags=None, system_tags=None):
def update_tags(self, company_id: str, projects: Sequence[str], tags=None, system_tags=None):
"""
Updates tags. If reset is set then both tags and system_tags
are recalculated. Otherwise only those that are not 'None'
@@ -123,7 +123,7 @@ class _TagsCache:
if not fields:
return
self._delete_redis_keys(company_id, projects=[project], fields=fields)
self._delete_redis_keys(company_id, projects=projects, fields=fields)
def reset_tags(self, company_id: str, projects: Sequence[str]):
self._delete_redis_keys(

View File

@@ -315,11 +315,12 @@ class ProjectBLL:
description="",
)
extra = (
{"set__last_change": datetime.utcnow()}
if hasattr(entity_cls, "last_change")
else {}
)
extra = {}
if hasattr(entity_cls, "last_change"):
extra["set__last_change"] = datetime.utcnow()
if hasattr(entity_cls, "last_changed_by"):
extra["set__last_changed_by"] = user
entity_cls.objects(company=company, id__in=ids).update(
set__project=project, **extra
)
@@ -550,7 +551,10 @@ class ProjectBLL:
@classmethod
def get_dataset_stats(
cls, company: str, project_ids: Sequence[str], users: Sequence[str] = None,
cls,
company: str,
project_ids: Sequence[str],
users: Sequence[str] = None,
) -> Dict[str, dict]:
if not project_ids:
return {}
@@ -584,7 +588,9 @@ class ProjectBLL:
@staticmethod
def _get_projects_children(
project_ids: Sequence[str], search_hidden: bool, allowed_ids: Sequence[str],
project_ids: Sequence[str],
search_hidden: bool,
allowed_ids: Sequence[str],
) -> Tuple[ProjectsChildren, Set[str]]:
child_projects = _get_sub_projects(
project_ids,
@@ -628,7 +634,9 @@ class ProjectBLL:
project_ids_with_children = set(project_ids)
if include_children:
child_projects, children_ids = cls._get_projects_children(
project_ids, search_hidden=True, allowed_ids=selected_project_ids,
project_ids,
search_hidden=True,
allowed_ids=selected_project_ids,
)
project_ids_with_children |= children_ids
@@ -902,6 +910,7 @@ class ProjectBLL:
allow_public: bool = True,
children_type: ProjectChildrenType = None,
children_tags: Sequence[str] = None,
children_tags_filter: dict = None,
) -> Tuple[Sequence[str], Sequence[str]]:
"""
Get the projects ids matching children_condition (if passed) or where the passed user created any tasks
@@ -922,11 +931,15 @@ class ProjectBLL:
query &= Q(user__in=users)
project_query = None
child_query = (
query & GetMixin.get_list_field_query("tags", children_tags)
if children_tags
else query
)
if children_tags_filter:
child_query = query & GetMixin.get_list_filter_query(
"tags", children_tags_filter
)
elif children_tags:
child_query = query & GetMixin.get_list_field_query("tags", children_tags)
else:
child_query = query
if children_type == ProjectChildrenType.dataset:
child_queries = {
Project: child_query
@@ -1086,39 +1099,54 @@ class ProjectBLL:
or_conditions = []
for field, field_filter in filter_.items():
if not (
field_filter
and isinstance(field_filter, list)
and all(isinstance(t, str) for t in field_filter)
):
if not (field_filter and isinstance(field_filter, (list, dict))):
raise errors.bad_request.ValidationError(
f"List of strings expected for the field: {field}"
f"Non empty list or dictionary expected for the field: {field}"
)
helper = GetMixin.NewListFieldBucketHelper(
field, data=field_filter, legacy=True
)
field_conditions = {}
for action, values in helper.actions.items():
value = list(set(values))
for key in reversed(action.split("__")):
value = {f"${key}": value}
field_conditions.update(value)
if (
helper.explicit_operator
and helper.global_operator == Q.OR
and len(field_conditions) > 1
):
or_conditions.append(
[{field: {op: cond}} for op, cond in field_conditions.items()]
if isinstance(field_filter, list):
if not all(isinstance(t, str) for t in field_filter):
raise errors.bad_request.ValidationError(
f"Only string values are allowed in the list filter: {field}"
)
helper = GetMixin.NewListFieldBucketHelper(
field, data=field_filter, legacy=True
)
op = (
Q.OR
if helper.explicit_operator and helper.global_operator == Q.OR
else Q.AND
)
db_query = {op: helper.actions}
else:
conditions[field] = field_conditions
helper = GetMixin.ListQueryFilter.from_data(field, field_filter)
db_query = helper.db_query
for op, actions in db_query.items():
field_conditions = {}
for action, values in actions.items():
value = list(set(values))
for key in reversed(action.split("__")):
value = {f"${key}": value}
field_conditions.update(value)
if op == Q.OR and len(field_conditions) > 1:
or_conditions.append(
{
"$or": [
{field: {db_modifier: cond}}
for db_modifier, cond in field_conditions.items()
]
}
)
else:
conditions[field] = field_conditions
if or_conditions:
if len(or_conditions) == 1:
conditions["$or"] = next(iter(or_conditions))
conditions = next(iter(or_conditions))
else:
conditions["$and"] = [{"$or": c} for c in or_conditions]
conditions["$and"] = [c for c in or_conditions]
return conditions

View File

@@ -30,7 +30,6 @@ from .sub_projects import _ids_with_children
log = config.logger(__file__)
event_bll = EventBLL()
async_events_delete = config.get("services.tasks.async_events_delete", False)
@attr.s(auto_attribs=True)
@@ -83,7 +82,8 @@ def validate_project_delete(company: str, project_id: str):
ret["pipelines"] = 0
if dataset_ids:
datasets_with_data = Task.objects(
project__in=dataset_ids, system_tags__nin=[EntityVisibility.archived.value],
project__in=dataset_ids,
system_tags__nin=[EntityVisibility.archived.value],
).distinct("project")
ret["datasets"] = len(datasets_with_data)
else:
@@ -185,10 +185,10 @@ def delete_project(
res = DeleteProjectResult(disassociated_tasks=disassociated[Task])
else:
deleted_models, model_event_urls, model_urls = _delete_models(
company=company, projects=project_ids
company=company, user=user, projects=project_ids
)
deleted_tasks, task_event_urls, artifact_urls = _delete_tasks(
company=company, projects=project_ids
company=company, user=user, projects=project_ids
)
event_urls = task_event_urls | model_event_urls
if delete_external_artifacts:
@@ -217,7 +217,9 @@ def delete_project(
return res, affected
def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]:
def _delete_tasks(
company: str, user: str, projects: Sequence[str]
) -> Tuple[int, Set, Set]:
"""
Delete only the tasks themselves and their non published version.
Child models under the same project are deleted separately.
@@ -228,14 +230,24 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
if not tasks:
return 0, set(), set()
task_ids = {t.id for t in tasks}
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
task_ids = list({t.id for t in tasks})
now = datetime.utcnow()
Task.objects(parent__in=task_ids, project__nin=projects).update(
parent=None,
last_change=now,
last_changed_by=user,
)
Model.objects(task__in=task_ids, project__nin=projects).update(
task=None,
last_change=now,
last_changed_by=user,
)
event_urls, artifact_urls = set(), set()
event_urls = collect_debug_image_urls(company, task_ids) | collect_plot_image_urls(
company, task_ids
)
artifact_urls = set()
for task in tasks:
event_urls.update(collect_debug_image_urls(company, task.id))
event_urls.update(collect_plot_image_urls(company, task.id))
if task.execution and task.execution.artifacts:
artifact_urls.update(
{
@@ -245,15 +257,13 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
}
)
event_bll.delete_multi_task_events(
company, list(task_ids), async_delete=async_events_delete
)
event_bll.delete_multi_task_events(company, task_ids)
deleted = tasks.delete()
return deleted, event_urls, artifact_urls
def _delete_models(
company: str, projects: Sequence[str]
company: str, user: str, projects: Sequence[str]
) -> Tuple[int, Set[str], Set[str]]:
"""
Delete project models and update the tasks from other projects
@@ -287,25 +297,31 @@ def _delete_models(
"status": TaskStatus.published,
},
update={
"$set": {"models.output.$[elem].model": deleted, "last_change": now,}
"$set": {
"models.output.$[elem].model": deleted,
"last_change": now,
"last_changed_by": user,
}
},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
# update unpublished tasks
Task.objects(
id__in=model_tasks, project__nin=projects, status__ne=TaskStatus.published,
).update(pull__models__output__model__in=model_ids, set__last_change=now)
id__in=model_tasks,
project__nin=projects,
status__ne=TaskStatus.published,
).update(
pull__models__output__model__in=model_ids,
set__last_change=now,
set__last_changed_by=user,
)
event_urls, model_urls = set(), set()
for m in models:
event_urls.update(collect_debug_image_urls(company, m.id))
event_urls.update(collect_plot_image_urls(company, m.id))
if m.uri:
model_urls.add(m.uri)
event_bll.delete_multi_task_events(
company, model_ids, async_delete=async_events_delete
event_urls = collect_debug_image_urls(company, model_ids) | collect_plot_image_urls(
company, model_ids
)
model_urls = {m.uri for m in models if m.uri}
event_bll.delete_multi_task_events(company, model_ids, model=True)
deleted = models.delete()
return deleted, event_urls, model_urls

View File

@@ -140,6 +140,7 @@ class ProjectQueries:
name: str,
include_subprojects: bool,
allow_public: bool = True,
pattern: str = None,
page: int = 0,
page_size: int = 500,
) -> ParamValues:
@@ -164,7 +165,20 @@ class ProjectQueries:
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}_{page}_{page_size}"
redis_key = "_".join(
str(part)
for part in (
"hyperparam_values",
company_id,
"_".join(project_ids),
section,
name,
allow_public,
pattern,
page,
page_size,
)
)
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_param_values(
key=redis_key,
@@ -176,14 +190,22 @@ class ProjectQueries:
if cached_res:
return cached_res
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
match_condition = {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
if pattern:
match_condition["$expr"] = {
"$regexMatch": {
"input": f"${key_path}.value",
"regex": pattern,
"options": "i",
}
},
}
pipeline = [
{"$match": match_condition},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},

View File

@@ -1,6 +1,5 @@
from .task_bll import TaskBLL
from .utils import (
ChangeStatusRequest,
update_project_time,
validate_status_change,
)

View File

@@ -1,7 +1,7 @@
from datetime import timedelta, datetime
from time import sleep
from apiserver.bll.task import update_project_time
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model.task.task import TaskStatus, Task
from apiserver.utilities.threads_manager import ThreadsManager
@@ -85,6 +85,7 @@ class NonResponsiveTasksWatchdog:
status_changed=now,
last_update=now,
last_change=now,
last_changed_by="__apiserver__",
)
if updated:
project_ids.add(task.project)

View File

@@ -12,6 +12,7 @@ from apiserver.apimodels.tasks import TaskInputModel
from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
@@ -31,7 +32,10 @@ from apiserver.database.model.task.task import (
)
from apiserver.database.model import EntityVisibility
from apiserver.database.model.queue import Queue
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.database.utils import (
get_company_or_none_constraint,
id as create_id,
)
from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
@@ -39,7 +43,6 @@ from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
ChangeStatusRequest,
update_project_time,
deleted_prefix,
get_last_metric_updates,
)
@@ -78,7 +81,11 @@ class TaskBLL:
@staticmethod
def get_by_id(
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
company_id,
task_id,
required_status=None,
only_fields=None,
allow_public=False,
):
if only_fields:
if isinstance(only_fields, string_types):
@@ -313,7 +320,7 @@ class TaskBLL:
org_bll.update_tags(
company_id,
Tags.Task,
project=new_task.project,
projects=[new_task.project],
tags=updated_tags,
system_tags=updated_system_tags,
)
@@ -356,6 +363,7 @@ class TaskBLL:
def set_last_update(
task_ids: Collection[str],
company_id: str,
user_id: str,
last_update: datetime,
**extra_updates,
):
@@ -376,6 +384,7 @@ class TaskBLL:
upsert=False,
last_update=last_update,
last_change=last_update,
last_changed_by=user_id,
**updates,
)
return count
@@ -384,6 +393,7 @@ class TaskBLL:
def update_statistics(
task_id: str,
company_id: str,
user_id: str,
last_update: datetime = None,
last_iteration: int = None,
last_iteration_max: int = None,
@@ -440,6 +450,7 @@ class TaskBLL:
ret = TaskBLL.set_last_update(
task_ids=[task_id],
company_id=company_id,
user_id=user_id,
last_update=last_update,
**extra_updates,
)

View File

@@ -1,10 +1,10 @@
from datetime import datetime
from itertools import chain
from operator import attrgetter
from typing import Sequence, Set, Tuple
from typing import Sequence, Set, Tuple, Union
import attr
from boltons.iterutils import partition, bucketize, first
from boltons.iterutils import partition, bucketize, first, chunked_iter
from furl import furl
from mongoengine import NotUniqueError
from pymongo.errors import DuplicateKeyError
@@ -26,7 +26,6 @@ from apiserver.database.utils import id as db_id
log = config.logger(__file__)
event_bll = EventBLL()
async_events_delete = config.get("services.tasks.async_events_delete", False)
@attr.s(auto_attribs=True)
@@ -69,37 +68,47 @@ class CleanupResult:
)
def collect_plot_image_urls(company: str, task_or_model: str) -> Set[str]:
def collect_plot_image_urls(
company: str, task_or_model: Union[str, Sequence[str]]
) -> Set[str]:
urls = set()
next_scroll_id = None
while True:
events, next_scroll_id = event_bll.get_plot_image_urls(
company_id=company, task_id=task_or_model, scroll_id=next_scroll_id
)
if not events:
break
for event in events:
event_urls = event.get(PlotFields.source_urls)
if event_urls:
urls.update(set(event_urls))
task_ids = task_or_model if isinstance(task_or_model, list) else [task_or_model]
for tasks in chunked_iter(task_ids, 100):
next_scroll_id = None
while True:
events, next_scroll_id = event_bll.get_plot_image_urls(
company_id=company, task_ids=tasks, scroll_id=next_scroll_id
)
if not events:
break
for event in events:
event_urls = event.get(PlotFields.source_urls)
if event_urls:
urls.update(set(event_urls))
return urls
def collect_debug_image_urls(company: str, task_or_model: str) -> Set[str]:
def collect_debug_image_urls(
company: str, task_or_model: Union[str, Sequence[str]]
) -> Set[str]:
"""
Return the set of unique image urls
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
"""
after_key = None
urls = set()
while True:
res, after_key = event_bll.get_debug_image_urls(
company_id=company, task_id=task_or_model, after_key=after_key,
)
urls.update(res)
if not after_key:
break
task_ids = task_or_model if isinstance(task_or_model, list) else [task_or_model]
for tasks in chunked_iter(task_ids, 100):
after_key = None
while True:
res, after_key = event_bll.get_debug_image_urls(
company_id=company,
task_ids=tasks,
after_key=after_key,
)
urls.update(res)
if not after_key:
break
return urls
@@ -122,7 +131,11 @@ supported_storage_types.update(
def _schedule_for_delete(
company: str, user: str, task_id: str, urls: Set[str], can_delete_folders: bool,
company: str,
user: str,
task_id: str,
urls: Set[str],
can_delete_folders: bool,
) -> Set[str]:
urls_per_storage = bucketize(
urls,
@@ -222,8 +235,13 @@ def cleanup_task(
deleted_task_id = f"{deleted_prefix}{task.id}"
updated_children = 0
now = datetime.utcnow()
if update_children:
updated_children = Task.objects(parent=task.id).update(parent=deleted_task_id)
updated_children = Task.objects(parent=task.id).update(
parent=deleted_task_id,
last_change=now,
last_changed_by=user,
)
deleted_models = 0
updated_models = 0
@@ -231,37 +249,41 @@ def cleanup_task(
if not models:
continue
if delete_output_models and allow_delete:
model_ids = set(m.id for m in models if m.id not in in_use_model_ids)
for m_id in model_ids:
model_ids = list({m.id for m in models if m.id not in in_use_model_ids})
if model_ids:
if return_file_urls or delete_external_artifacts:
event_urls.update(collect_debug_image_urls(task.company, m_id))
event_urls.update(collect_plot_image_urls(task.company, m_id))
try:
event_bll.delete_task_events(
task.company,
m_id,
allow_locked=True,
model=True,
async_delete=async_events_delete,
)
except errors.bad_request.InvalidModelId as ex:
log.info(f"Error deleting events for the model {m_id}: {str(ex)}")
event_urls.update(collect_debug_image_urls(task.company, model_ids))
event_urls.update(collect_plot_image_urls(task.company, model_ids))
event_bll.delete_multi_task_events(
task.company,
model_ids,
model=True,
)
deleted_models += Model.objects(id__in=model_ids).delete()
deleted_models += Model.objects(id__in=list(model_ids)).delete()
if in_use_model_ids:
Model.objects(id__in=list(in_use_model_ids)).update(unset__task=1)
Model.objects(id__in=list(in_use_model_ids)).update(
unset__task=1,
set__last_change=now,
set__last_changed_by=user,
)
continue
if update_children:
updated_models += Model.objects(id__in=[m.id for m in models]).update(
task=deleted_task_id
task=deleted_task_id,
last_change=now,
last_changed_by=user,
)
else:
Model.objects(id__in=[m.id for m in models]).update(unset__task=1)
Model.objects(id__in=[m.id for m in models]).update(
unset__task=1,
set__last_change=now,
set__last_changed_by=user,
)
event_bll.delete_task_events(
task.company, task.id, allow_locked=force, async_delete=async_events_delete
)
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
if delete_external_artifacts:
scheduled = _schedule_for_delete(
@@ -304,7 +326,8 @@ def verify_task_children_and_ouptuts(
model_fields = ["id", "ready", "uri"]
published_models, draft_models = partition(
Model.objects(task=task.id).only(*model_fields), key=attrgetter("ready"),
Model.objects(task=task.id).only(*model_fields),
key=attrgetter("ready"),
)
if not force and published_models:
raise errors.bad_request.TaskCannotBeDeleted(

View File

@@ -7,9 +7,9 @@ from apiserver.bll.task import (
TaskBLL,
validate_status_change,
ChangeStatusRequest,
update_project_time,
)
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
@@ -79,13 +79,20 @@ def archive_task(
def unarchive_task(
task: str, company_id: str, user_id: str, status_message: str, status_reason: str,
task: str,
company_id: str,
user_id: str,
status_message: str,
status_reason: str,
) -> int:
"""
Unarchive task. Return 1 if successful
"""
task = TaskBLL.get_task_with_access(
task, company_id=company_id, only=("id",), requires_write_access=True,
task,
company_id=company_id,
only=("id",),
requires_write_access=True,
)
return task.update(
status_message=status_message,
@@ -345,11 +352,17 @@ def reset_task(
unset__output__error=1,
unset__last_worker=1,
unset__last_worker_report=1,
unset__started=1,
unset__completed=1,
unset__published=1,
unset__active_duration=1,
unset__enqueue_status=1,
)
if clear_all:
updates.update(
set__execution=Execution(), unset__script=1,
set__execution=Execution(),
unset__script=1,
)
else:
updates.update(unset__execution__queue=1)
@@ -370,11 +383,6 @@ def reset_task(
status_message="reset",
user_id=user_id,
).execute(
started=None,
completed=None,
published=None,
active_duration=None,
enqueue_status=None,
**updates,
)

View File

@@ -1,14 +1,13 @@
from datetime import datetime
from typing import Sequence, Union
import attr
import six
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
from apiserver.utilities.attrs import typed_attrs
@@ -158,16 +157,6 @@ def get_possible_status_changes(current_status):
return possible
def update_project_time(project_ids: Union[str, Sequence[str]]):
if not project_ids:
return
if isinstance(project_ids, str):
project_ids = [project_ids]
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task:

View File

@@ -1,6 +1,7 @@
import functools
import itertools
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from typing import (
Optional,
Callable,
@@ -8,11 +9,13 @@ from typing import (
Tuple,
Sequence,
TypeVar,
Union,
)
from boltons import iterutils
from apiserver.apierrors import APIError
from apiserver.database.model.project import Project
from apiserver.database.model.settings import Settings
@@ -77,3 +80,13 @@ def run_batch_operation(
}
)
return results, failures
def update_project_time(project_ids: Union[str, Sequence[str]]):
    """
    Set last_update of the given project(s) to the current UTC time.

    :param project_ids: a single project id or a sequence of project ids
    :return: the result of the mongoengine update call,
        or None when no project ids were passed
    """
    if not project_ids:
        return None

    # a bare string is a single id, not a sequence of one-character ids
    ids = [project_ids] if isinstance(project_ids, str) else project_ids
    timestamp = datetime.utcnow()
    return Project.objects(id__in=ids).update(last_update=timestamp)

View File

@@ -5,13 +5,13 @@ from typing import Sequence, Set, Optional
import attr
import elasticsearch.helpers
from boltons.iterutils import partition
from boltons.iterutils import partition, chunked_iter
from pyhocon import ConfigTree
from apiserver.es_factory import es_factory
from apiserver.apierrors import APIError
from apiserver.apierrors.errors import bad_request, server_error
from apiserver.apimodels.workers import (
DEFAULT_TIMEOUT,
IdNameEntry,
WorkerEntry,
StatusReportRequest,
@@ -30,12 +30,14 @@ from apiserver.redis_manager import redman
from apiserver.tools import safe_get
from .stats import WorkerStats
log = config.logger(__file__)
class WorkerBLL:
def __init__(self, es=None, redis=None):
self.es_client = es or es_factory.connect("workers")
self.config = config.get("services.workers", ConfigTree())
self.redis = redis or redman.connection("workers")
self._stats = WorkerStats(self.es_client)
@@ -68,7 +70,7 @@ class WorkerBLL:
"""
key = WorkerBLL._get_worker_key(company_id, user_id, worker)
timeout = timeout or DEFAULT_TIMEOUT
timeout = timeout or int(self.config.get("default_worker_timeout_sec", 10 * 60))
queues = queues or []
with translate_errors_context():
@@ -141,8 +143,6 @@ class WorkerBLL:
try:
entry.ip = ip
now = datetime.utcnow()
entry.last_activity_time = now
if tags is not None:
entry.tags = tags
@@ -150,15 +150,16 @@ class WorkerBLL:
entry.system_tags = system_tags
if report.machine_stats:
self._log_stats_to_es(
self.log_stats_to_es(
company_id=company_id,
company_name=entry.company.name,
worker=entry.key,
worker_id=report.worker,
timestamp=report.timestamp,
task=report.task,
machine_stats=report.machine_stats,
)
now = datetime.utcnow()
entry.last_activity_time = now
entry.queue = report.queue
if report.queues:
@@ -175,6 +176,7 @@ class WorkerBLL:
last_worker_report=now,
last_update=now,
last_change=now,
last_changed_by=user_id,
)
# modify(new=True, ...) returns the modified object
task = Task.objects(**query).modify(new=True, **update)
@@ -253,18 +255,15 @@ class WorkerBLL:
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerResponseEntry]:
helpers = list(
map(
WorkerConversionHelper.from_worker_entry,
self.get_all(
company_id=company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
),
helpers = [
WorkerConversionHelper.from_worker_entry(entry)
for entry in self.get_all(
company_id=company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
)
)
]
task_ids = set(filter(None, (helper.task_id for helper in helpers)))
all_queues = set(
@@ -283,9 +282,7 @@ class WorkerBLL:
}
},
]
queues_info = {
res["_id"]: res for res in Queue.objects.aggregate(projection)
}
queues_info = {res["_id"]: res for res in Queue.aggregate(projection)}
task_ids = task_ids.union(
filter(
None,
@@ -495,12 +492,15 @@ class WorkerBLL:
"""Get worker entries matching the company and user, worker patterns"""
entries = []
for key in self._get_keys(
company, user=user, user_tags=user_tags, system_tags=system_tags
for keys in chunked_iter(
self._get_keys(
company, user=user, user_tags=user_tags, system_tags=system_tags
),
1000,
):
data = self.redis.get(key)
data = self.redis.mget(keys)
if data:
entries.append(WorkerEntry.from_json(data))
entries.extend(WorkerEntry.from_json(d) for d in data if d)
return entries
@@ -509,18 +509,17 @@ class WorkerBLL:
"""Get the index name suffix for storing current month data"""
return datetime.utcnow().strftime("%Y-%m")
def _log_stats_to_es(
def log_stats_to_es(
self,
company_id: str,
company_name: str,
worker: str,
worker_id: str,
timestamp: int,
task: str,
machine_stats: MachineStats,
) -> bool:
) -> int:
"""
Actually writing the worker statistics to Elastic
:return: True if successful, False otherwise
:return: The amount of logged documents
"""
es_index = (
f"{self._stats.worker_stats_prefix_for_company(company_id)}"
@@ -532,8 +531,7 @@ class WorkerBLL:
_index=es_index,
_source=dict(
timestamp=timestamp,
worker=worker,
company=company_name,
worker=worker_id,
task=task,
category=category,
metric=metric,
@@ -558,7 +556,7 @@ class WorkerBLL:
es_res = elasticsearch.helpers.bulk(self.es_client, actions)
added, errors = es_res[:2]
return (added == len(actions)) and not errors
return added
@attr.s(auto_attribs=True)

View File

@@ -215,6 +215,10 @@ class WorkerStats:
"date_histogram": {
"field": "timestamp",
"fixed_interval": f"{interval}s",
"extended_bounds": {
"min": int(from_date) * 1000,
"max": int(to_date) * 1000,
}
},
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
}

View File

@@ -23,4 +23,6 @@ hyperparam_values {
max_last_metrics: 2000
# if set then call to tasks.delete/cleanup does not wait for ES events deletion
async_events_delete: false
async_events_delete: true
# do not use async_delete if the deleted task has fewer events than this threshold
async_events_delete_threshold: 100000

View File

@@ -1,5 +1,6 @@
import re
from collections import namedtuple, defaultdict
from datetime import datetime
from functools import reduce, partial
from typing import (
Collection,
@@ -196,9 +197,7 @@ class GetMixin(PropsMixin):
if self.global_operator is None:
self.global_operator = self.default_operator
def _get_next_term(
self, data: Sequence[str]
) -> Generator[Term, None, None]:
def _get_next_term(self, data: Sequence[str]) -> Generator[Term, None, None]:
unary_operator = None
for value in data:
if value is None:
@@ -232,12 +231,18 @@ class GetMixin(PropsMixin):
operator = self._operators.get(value)
if operator is None:
raise FieldsValueError(
"Unsupported operator", field=self._field, operator=value,
"Unsupported operator",
field=self._field,
operator=value,
)
yield self.Term(operator=operator)
continue
if not unary_operator and self._support_legacy and value.startswith("-"):
if (
not unary_operator
and self._support_legacy
and value.startswith("-")
):
value = value[1:]
if not value:
raise FieldsValueError(
@@ -402,12 +407,25 @@ class GetMixin(PropsMixin):
parameters = {
k: cls._get_fixed_field_value(k, v) for k, v in parameters.items()
}
filters = parameters.pop("filters", {})
if not isinstance(filters, dict):
raise FieldsValueError(
"invalid value type, string expected",
field=filters,
value=str(filters),
)
opts = parameters_options
for field in opts.pattern_fields:
pattern = parameters.pop(field, None)
if pattern:
dict_query[field] = RegexWrapper(pattern)
for field, data in cls._pop_matching_params(
patterns=opts.list_fields, parameters=filters
).items():
query &= cls.get_list_filter_query(field, data)
parameters.pop(field, None)
for field, data in cls._pop_matching_params(
patterns=opts.list_fields, parameters=parameters
).items():
@@ -531,6 +549,135 @@ class GetMixin(PropsMixin):
return q
@attr.s(auto_attribs=True)
class ListQueryFilter:
    """
    Deserialize list filter data and build the db_query mapping that represents it
    with the corresponding mongoengine operations

    The filter has two optional parts, "any" and "all". Each part has include and
    exclude lists that map to mongoengine operations as following:
    "any"
    - include -> 'in'
    - exclude -> 'not__all'
    - combined by 'or' operation
    "all"
    - include -> 'all'
    - exclude -> 'nin'
    - combined by 'and' operation
    "op" optional parameter for combining the "any" and "all" parts.
    Can be "and" or "or". The default is "and"
    """

    _and_op = "and"
    _or_op = "or"
    _allowed_op = [_and_op, _or_op]
    # Maps (part operation, is_include) to the mongoengine query modifier
    # NOTE(review): because this name is annotated, attrs collects it as an
    # instance attribute (init argument `db_modifiers`) rather than leaving it
    # a plain class constant - confirm this is intended
    _db_modifiers: Mapping = {
        (Q.OR, True): "in",
        (Q.OR, False): "not__all",
        (Q.AND, True): "all",
        (Q.AND, False): "nin",
    }

    @attr.s(auto_attribs=True)
    class ListFilter:
        """The include and exclude value lists of a single filter part"""

        # attr.Factory creates a fresh list per instance; a literal [] default
        # would be one list object shared by all the instances
        include: Sequence[str] = attr.Factory(list)
        exclude: Sequence[str] = attr.Factory(list)

        @classmethod
        def from_dict(cls, d: Mapping):
            """Build the filter part from its mapping. None is returned as is"""
            if d is None:
                return None
            return cls(**d)

    any: ListFilter = attr.ib(converter=ListFilter.from_dict, default=None)
    all: ListFilter = attr.ib(converter=ListFilter.from_dict, default=None)
    op: str = attr.ib(default="and")
    # calculated in __attrs_post_init__, not accepted by the constructor
    db_query: dict = attr.ib(init=False)

    # noinspection PyUnresolvedReferences
    @op.validator
    def op_validator(self, _, value):
        """Raise ValueError if the operator is not one of the allowed ones"""
        if value not in self._allowed_op:
            raise ValueError(
                f"Invalid list query filter operator: {value}. "
                f"Should be one of {str(self._allowed_op)}"
            )

    @property
    def and_op(self) -> bool:
        """True if the "any" and "all" parts should be combined with 'and'"""
        return self.op == self._and_op

    def __attrs_post_init__(self):
        """
        Fill db_query: for each part that has values, a mapping from the
        mongoengine modifier to the list of unique values, keyed by the part operation
        """
        self.db_query = {}
        for op, conditions in ((Q.OR, self.any), (Q.AND, self.all)):
            if not conditions:
                continue

            operations = {}
            for vals, include in (
                (conditions.include, True),
                (conditions.exclude, False),
            ):
                if not vals:
                    continue

                # duplicates are dropped; the order of the values is not kept
                operations[self._db_modifiers[(op, include)]] = list(set(vals))

            self.db_query[op] = operations

    @classmethod
    def from_data(cls, field, data: Mapping):
        """
        Build the filter from the data passed for the field

        :param field: the field name, used for error reporting only
        :param data: the filter definition with optional "any", "all" and "op" keys
        :raise errors.bad_request.ValidationError: if data is not a dictionary
            or the filter cannot be built from it
        """
        if not isinstance(data, dict):
            raise errors.bad_request.ValidationError(
                "invalid filter for field, dictionary expected",
                field=field,
                value=str(data),
            )

        try:
            return cls(**data)
        except Exception as ex:
            # chain the original exception so that the root cause is kept
            raise errors.bad_request.ValidationError(
                field=field,
                value=str(ex),
            ) from ex
@classmethod
def get_list_filter_query(
    cls, field: str, data: Mapping
) -> Union[RegexQ, RegexQCombination]:
    """
    Build the query for the field from the passed list filter definition

    :param field: the field name. Dots are replaced with "__" for mongoengine
    :param data: the filter definition, parsed by ListQueryFilter.from_data
    :return: an empty RegexQ if there are no conditions, the condition itself if
        there is only one, otherwise the combination of the conditions
    """
    if not data:
        return RegexQ()

    parsed = cls.ListQueryFilter.from_data(field, data)
    db_field = field.replace(".", "__")

    parts = []
    for part_op, modifiers in parsed.db_query.items():
        conditions = [
            RegexQ(**{f"{db_field}__{modifier}": values})
            for modifier, values in (modifiers or {}).items()
            if values
        ]
        if len(conditions) > 1:
            parts.append(RegexQCombination(operation=part_op, children=conditions))
        else:
            # zero or one condition: nothing to combine inside this part
            parts.extend(conditions)

    if not parts:
        return RegexQ()
    if len(parts) == 1:
        return parts[0]

    return RegexQCombination(
        operation=Q.AND if parsed.and_op else Q.OR, children=parts
    )
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
"""
@@ -639,7 +786,7 @@ class GetMixin(PropsMixin):
@classmethod
def get_projection(cls, parameters, override_projection=None, **__):
""" Extract a projection list from the provided dictionary. Supports an override projection. """
"""Extract a projection list from the provided dictionary. Supports an override projection."""
if override_projection is not None:
return override_projection
if not parameters:
@@ -653,7 +800,8 @@ class GetMixin(PropsMixin):
"""Return include and exclude lists based on passed projection and class definition"""
if projection:
include, exclude = partition(
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
projection,
key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
)
else:
include, exclude = [], []
@@ -900,7 +1048,9 @@ class GetMixin(PropsMixin):
projection_fields=projection_fields,
)
return cls.get_data_with_scroll_support(
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
query_dict=query_dict,
data_getter=data_getter,
ret_params=ret_params,
)
return cls._get_many_no_company(
@@ -913,7 +1063,9 @@ class GetMixin(PropsMixin):
@classmethod
def get_many_public(
cls, query: Q = None, projection: Collection[str] = None,
cls,
query: Q = None,
projection: Collection[str] = None,
):
"""
Fetch all public documents matching a provided query.
@@ -1206,7 +1358,7 @@ class UpdateMixin(object):
class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
""" Provide convenience methods for a subclass of mongoengine.Document """
"""Provide convenience methods for a subclass of mongoengine.Document"""
@classmethod
def aggregate(
@@ -1234,25 +1386,31 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
def set_public(
cls: Type[Document],
company_id: str,
user_id: str,
ids: Sequence[str],
invalid_cls: Type[BaseError],
enabled: bool = True,
):
if enabled:
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
update = dict(set__company_origin=company_id, set__company="")
update: dict = dict(set__company_origin=company_id, set__company="")
else:
items = list(
cls.objects(
id__in=ids, company__in=(None, ""), company_origin=company_id
).only("id")
)
update = dict(set__company=company_id, unset__company_origin=1)
update: dict = dict(set__company=company_id, unset__company_origin=1)
if len(items) < len(ids):
missing = tuple(set(ids).difference(i.id for i in items))
raise invalid_cls(ids=missing)
if hasattr(cls, "last_change"):
update["set__last_change"] = datetime.utcnow()
if hasattr(cls, "last_changed_by"):
update["set__last_changed_by"] = user_id
return {"updated": cls.objects(id__in=ids).update(**update)}

View File

@@ -90,6 +90,8 @@ class Model(AttributedDocument):
labels = ModelLabels()
ready = BooleanField(required=True)
last_update = DateTimeField()
last_change = DateTimeField()
last_changed_by = StringField()
ui_cache = SafeDictField(
default=dict, user_set_allowed=True, exclude_by_default=True
)

View File

@@ -230,11 +230,12 @@ class Task(AttributedDocument):
"project",
"parent",
"hyperparams.*",
"execution.queue",
),
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"),
pattern_fields=("name", "comment", "report"),
fields=("execution.queue", "runtime.*", "models.input.model"),
fields=("runtime.*", "models.input.model"),
)
id = StringField(primary_key=True)

View File

@@ -0,0 +1,19 @@
### Supported api versions
| Release | ApiVersion |
|---------|------------|
| v1.13 | 2.27 |
| v1.12 | 2.26 |
| v1.11 | 2.25 |
| v1.10 | 2.24 |
| v1.9 | 2.23 |
| v1.8 | 2.22 |
| v1.7 | 2.21 |
| v1.6 | 2.20 |
| v1.5 | 2.19 |
| v1.4 | 2.18 |
| v1.3 | 2.17 |
| v1.2 | 2.16 |
| v1.1 | 2.15 |
| v1.0 | 2.14 |
| v0.17 | 2.13 |

View File

@@ -960,7 +960,7 @@ class PrePopulate:
return tasks
@classmethod
def _import_events(cls, f: IO[bytes], company_id: str, _, task_id: str):
def _import_events(cls, f: IO[bytes], company_id: str, user_id: str, task_id: str):
print(f"Writing events for task {task_id} into database")
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
events = [json.loads(item) for item in events_chunk]
@@ -969,5 +969,8 @@ class PrePopulate:
ev["company_id"] = company_id
ev["allow_locked"] = True
cls.event_bll.add_events(
company_id, events=events, worker=""
company_id=company_id,
user_id=user_id,
events=events,
worker="",
)

View File

@@ -33,4 +33,4 @@ semantic_version>=2.8.3,<3
setuptools>=65.5.1
six
validators>=0.12.4
urllib3>=1.26.16
urllib3>=1.26.18

View File

@@ -1,3 +1,43 @@
field_filter {
type: object
description: Filter on a field that includes combination of 'any' or 'all' included and excluded terms
properties {
any {
type: object
description: All the terms in 'any' condition are combined with 'or' operation
properties {
"include" {
type: array
items {type: string}
}
exclude {
type: array
items {type: string}
}
}
}
all {
type: object
description: All the terms in 'all' condition are combined with 'and' operation
properties {
"include" {
type: array
items {type: string}
}
exclude {
type: array
items {type: string}
}
}
}
op {
type: string
description: The operation between 'any' and 'all' parts of the filter if both are provided
default: and
enum: [and, or]
}
}
}
metadata_item {
type: object
properties {

View File

@@ -414,7 +414,7 @@ task {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: [string, null] }
additionalProperties { type: string }
}
models {
description: "Task models"

View File

@@ -11,8 +11,8 @@ _definitions {
type: number
}
type {
description: "training_stats_vector"
const: "training_stats_scalar"
description: "'training_stats_scalar'"
type: string
}
task {
description: "Task ID (required)"
@@ -46,8 +46,8 @@ _definitions {
type: number
}
type {
description: "training_stats_vector"
const: "training_stats_vector"
description: "'training_stats_vector'"
type: string
}
task {
description: "Task ID (required)"
@@ -82,8 +82,8 @@ _definitions {
type: number
}
type {
description: ""
const: "training_debug_image"
description: "'training_debug_image'"
type: string
}
task {
description: "Task ID (required)"
@@ -123,7 +123,7 @@ _definitions {
}
type {
description: "'plot'"
const: "plot"
type: string
}
task {
description: "Task ID (required)"
@@ -221,7 +221,7 @@ _definitions {
}
type {
description: "'log'"
const: "log"
type: string
}
task {
description: "Task ID (required)"
@@ -1470,6 +1470,10 @@ get_scalar_metric_data {
type: string
description: type of metric
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
}
}
}
response {
@@ -1492,7 +1496,7 @@ get_scalar_metric_data {
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
description: "Scroll ID for getting more results"
}
}
}

View File

@@ -6,7 +6,7 @@ _default {
}
supported_modes {
authorize: false
authorize: null
"2.9" {
description: """ Return supported login modes."""
request {
@@ -59,7 +59,7 @@ supported_modes {
description: "SSO authentication providers"
type: object
additionalProperties {
desctiprion: "Provider redirect URL"
description: "Provider redirect URL"
type: string
}
}
@@ -95,7 +95,7 @@ supported_modes {
}
logout {
authorize: false
authorize: null
allow_roles = [ "*" ]
"2.13" {
description: """ Logout (including SSO, if used) """

View File

@@ -261,6 +261,14 @@ get_all_ex {
}
}
}
"2.27": ${get_all_ex."2.23"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
}
}
}
get_all {
"2.1" {
@@ -357,9 +365,6 @@ get_all {
"$ref": "#/definitions/multi_field_pattern_data"
}
}
dependencies {
page: [ page_size ]
}
}
response {
type: object
@@ -1086,4 +1091,38 @@ delete_metadata {
}
}
}
}
update_tags {
"2.27" {
description: Add or remove tags from multiple models
request {
type: object
properties {
ids {
type: array
description: IDs of the models to update
items {type: string}
}
add_tags {
type: array
description: User tags to add
items {type: string}
}
remove_tags {
type: array
description: User tags to remove
items {type: string}
}
}
}
response {
type: object
properties {
updated {
type: integer
description: The number of updated models
}
}
}
}
}

View File

@@ -203,7 +203,7 @@ get_entities_count {
default: false
}
active_users {
descritpion: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
description: "The list of users that were active in the project. If passed then the resulting projects are filtered to the ones that have tasks created by these users"
type: array
items: {type: string}
}

View File

@@ -59,7 +59,7 @@ start_pipeline {
type: object
properties {
name: { type: string }
value: { type: [string, null] }
value: { type: string }
}
}
}

View File

@@ -1,5 +1,6 @@
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
_definitions {
include "_common.conf"
multi_field_pattern_data {
type: object
properties {
@@ -569,7 +570,7 @@ get_all_ex {
request {
properties {
active_users {
descritpion: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
description: "The list of users that were active in the project. If passed then the resulting projects are filtered to the ones that have tasks created by these users"
type: array
items: {type: string}
}
@@ -660,6 +661,15 @@ get_all_ex {
items {type: string}
}
}
"2.27": ${get_all_ex."2.25"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
children_tags_filter: ${_definitions.field_filter}
}
}
}
update {
"2.1" {
@@ -1000,6 +1010,12 @@ get_hyperparam_values {
}
}
}
"2.27": ${get_hyperparam_values."2.26"} {
request.properties.pattern {
type: string
description: The search pattern regex
}
}
}
get_hyper_parameters {
"2.9" {
@@ -1270,13 +1286,15 @@ get_task_parents {
}
project {
type: object
id {
description: "The ID of the parent task project"
type: string
}
name {
description: "The name of the parent task project"
type: string
properties {
id {
description: "The ID of the parent task project"
type: string
}
name {
description: "The name of the parent task project"
type: string
}
}
}
}

View File

@@ -159,6 +159,14 @@ get_all_ex {
default: false
}
}
"2.27": ${get_all_ex."2.21"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
}
}
}
get_all {
"2.4" {

View File

@@ -578,7 +578,7 @@ get_task_data {
single_value_metrics {
type: object
description: If passed then task single value metrics are returned
additonalProperties: false
additionalProperties: false
}
}
response.properties.single_value_metrics {
@@ -694,9 +694,6 @@ get_all_ex {
"$ref": "#/definitions/multi_field_pattern_data"
}
}
dependencies {
page: [ page_size ]
}
}
response {
type: object
@@ -720,6 +717,14 @@ get_all_ex {
default: false
}
}
"2.27": ${get_all_ex."2.26"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
}
}
}
get_tags {
"2.23" {

View File

@@ -190,6 +190,14 @@ get_all_ex {
}
}
}
"2.27": ${get_all_ex."2.23"} {
request.properties {
filters {
type: object
additionalProperties: ${_definitions.field_filter}
}
}
}
}
get_all {
"2.1" {
@@ -289,9 +297,6 @@ get_all {
"$ref": "#/definitions/multi_field_pattern_data"
}
}
dependencies {
page: [ page_size ]
}
}
response {
type: object
@@ -481,7 +486,7 @@ clone {
new_task_container {
description: "The docker container properties for the new task. If not provided then taken from the original task"
type: object
additionalProperties { type: [string, null] }
additionalProperties { type: string }
}
}
}
@@ -659,7 +664,7 @@ create {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: [string, null] }
additionalProperties { type: string }
}
}
}
@@ -748,7 +753,7 @@ validate {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: [string, null] }
additionalProperties { type: string }
}
}
}
@@ -910,7 +915,7 @@ edit {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: [string, null] }
additionalProperties { type: string }
}
runtime {
description: "Task runtime mapping"
@@ -2050,3 +2055,37 @@ move {
}
}
}
update_tags {
"2.27" {
description: Add or remove tags from multiple tasks
request {
type: object
properties {
ids {
type: array
description: IDs of the tasks to update
items {type: string}
}
add_tags {
type: array
description: User tags to add
items {type: string}
}
remove_tags {
type: array
description: User tags to remove
items {type: string}
}
}
}
response {
type: object
properties {
updated {
type: integer
description: The number of updated tasks
}
}
}
}
}

View File

@@ -30,24 +30,35 @@ def get_auth_func(auth_type):
raise errors.unauthorized.BadAuthType()
def authorize_token(jwt_token, *_, **__):
def authorize_token(jwt_token, service, action, call):
"""Validate token against service/endpoint and requests data (dicts).
Returns a parsed token object (auth payload)
"""
call_info = {"ip": call.real_ip}
def log_error(msg):
info = ", ".join(f"{k}={v}" for k, v in call_info.items())
log.error(f"{msg} Call info: {info}")
try:
return Token.from_encoded_token(jwt_token)
except jwt.exceptions.InvalidKeyError as ex:
log_error("Failed parsing token.")
raise errors.unauthorized.InvalidToken(
"jwt invalid key error", reason=ex.args[0]
)
except jwt.InvalidTokenError as ex:
log_error("Failed parsing token.")
raise errors.unauthorized.InvalidToken("invalid jwt token", reason=ex.args[0])
except ValueError as ex:
log.exception("Failed while processing token: %s" % ex.args[0])
log_error(f"Failed while processing token: {str(ex.args[0])}.")
raise errors.unauthorized.InvalidToken(
"failed processing token", reason=ex.args[0]
)
except Exception:
log_error("Failed processing token.")
raise
def authorize_credentials(auth_data, service, action, call):

View File

@@ -90,7 +90,7 @@ class Token(Payload):
return token
except Exception as e:
raise errors.unauthorized.InvalidToken(
"failed parsing token, %s" % e.args[0]
"failed parsing token", reason=e.args[0]
)
@classmethod

View File

@@ -39,7 +39,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.26")
_max_version = PartialVersion("2.27")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (

View File

@@ -17,7 +17,7 @@ log = config.logger(__file__)
def validate_data(call: APICall, endpoint: Endpoint):
""" Perform all required call/endpoint validation, update call result appropriately """
try:
# todo: remove vaildate_required_fields once all endpoints have json schema
# todo: remove validate_required_fields once all endpoints have json schema
validate_required_fields(endpoint, call)
# set models. models will be validated automatically
@@ -50,10 +50,17 @@ def validate_role(endpoint, call):
pass
def validate_auth(endpoint, call):
""" Validate authorization for this endpoint and call.
If authentication has occurred, the call is updated with the authentication results.
def validate_auth(endpoint: Endpoint, call: "APICall"):
"""
Validate authorization for this endpoint and call.
If authentication has occurred, the call is updated with the authentication results.
For the endpoints with authorize==False the validation is not performed to improve performance
For the endpoints with authorize==True the validation should pass otherwise exception will be thrown
For the endpoints with authorize==None the validation will be tried, but it does not have to succeed
"""
if endpoint.authorize is not None and not endpoint.authorize:
return
if not call.authorization:
# No auth data. Invalid if we need to authorize and valid otherwise
if endpoint.authorize:
@@ -63,10 +70,9 @@ def validate_auth(endpoint, call):
# prepare arguments for validation
service, _, action = endpoint.name.partition(".")
# If we have auth data, we'll try to validate anyway (just so we'll have auth-based permissions whenever possible,
# even if endpoint did not require authorization)
# noinspection PyBroadException
try:
auth = call.authorization or ""
auth = call.authorization
auth_type, _, auth_data = auth.partition(" ")
authorize_func = get_auth_func(auth_type)
call.auth = authorize_func(auth_data, service, action, call)
@@ -78,7 +84,7 @@ def validate_auth(endpoint, call):
def validate_impersonation(endpoint, call):
""" Validate impersonation headers and set impersonated identity and authorization data accordingly.
:returns True if impersonating, False otherwise
:return: True if impersonating, False otherwise
"""
try:
act_as = call.act_as

View File

@@ -71,7 +71,12 @@ def _assert_task_or_model_exists(
@endpoint("events.add")
def add(call: APICall, company_id, _):
data = call.data.copy()
added, err_count, err_info = event_bll.add_events(company_id, [data], call.worker)
added, err_count, err_info = event_bll.add_events(
company_id=company_id,
user_id=call.identity.user,
events=[data],
worker=call.worker,
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@@ -82,9 +87,10 @@ def add_batch(call: APICall, company_id, _):
raise errors.bad_request.BatchContainsNoItems()
added, err_count, err_info = event_bll.add_events(
company_id,
events,
call.worker,
company_id=company_id,
user_id=call.identity.user,
events=events,
worker=call.worker,
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@@ -360,7 +366,7 @@ def get_task_events(_, company_id, request: TaskEventsRequest):
total = event_bll.events_iterator.count_task_events(
event_type=request.event_type,
company_id=task_or_model.get_index_company(),
task_id=task_id,
task_ids=[task_id],
metric_variants=metric_variants,
)
@@ -558,8 +564,8 @@ def get_multi_task_plots_v1_7(call, company_id, _):
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
list(companies),
task_ids,
company_id=list(companies),
task_id=task_ids,
event_type=EventType.metrics_plot,
sort=[{"iter": {"order": "desc"}}],
size=10000,
@@ -1085,7 +1091,7 @@ def scalar_metrics_iter_raw(
total = event_bll.events_iterator.count_task_events(
event_type=EventType.metrics_scalar,
company_id=task_or_model.get_index_company(),
task_id=task_id,
task_ids=[task_id],
metric_variants=metric_variants,
)

View File

@@ -22,6 +22,7 @@ from apiserver.apimodels.models import (
ModelsDeleteManyRequest,
ModelsGetRequest,
)
from apiserver.apimodels.tasks import UpdateTagsRequest
from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
@@ -76,7 +77,9 @@ def get_by_id(call: APICall, company_id, _):
)
if not models:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
"no such public or company model",
id=model_id,
company=company_id,
)
conform_model_data(call, models[0])
call.result.data = {"model": models[0]}
@@ -102,7 +105,9 @@ def get_by_task_id(call: APICall, company_id, _):
).first()
if not model:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
"no such public or company model",
id=model_id,
company=company_id,
)
model_dict = model.to_proper_dict()
conform_model_data(call, model_dict)
@@ -128,7 +133,10 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
return
model_ids = {model["id"] for model in models}
stats = ModelBLL.get_model_stats(company=company_id, model_ids=list(model_ids),)
stats = ModelBLL.get_model_stats(
company=company_id,
model_ids=list(model_ids),
)
for model in models:
model["stats"] = stats.get(model["id"])
@@ -212,7 +220,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
org_bll.update_tags(
company,
Tags.Model,
project=project,
projects=[project],
tags=fields.get("tags"),
system_tags=fields.get("system_tags"),
)
@@ -220,7 +228,9 @@ def _update_cached_tags(company: str, project: str, fields: dict):
def _reset_cached_tags(company: str, projects: Sequence[str]):
org_bll.reset_tags(
company, Tags.Model, projects=projects,
company,
Tags.Model,
projects=projects,
)
@@ -283,6 +293,8 @@ def update_for_task(call: APICall, company_id, _):
id=database.utils.id(),
created=now,
last_update=now,
last_change=now,
last_changed_by=call.identity.user,
user=call.identity.user,
company=company_id,
project=task.project,
@@ -301,6 +313,7 @@ def update_for_task(call: APICall, company_id, _):
TaskBLL.update_statistics(
task_id=task_id,
company_id=company_id,
user_id=call.identity.user,
last_iteration_max=iteration,
models__output=[
ModelItem(
@@ -320,7 +333,6 @@ def update_for_task(call: APICall, company_id, _):
response_data_model=CreateModelResponse,
)
def create(call: APICall, company_id, req_model: CreateModelRequest):
if req_model.public:
company_id = ""
@@ -345,6 +357,8 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
company=company_id,
created=now,
last_update=now,
last_change=now,
last_changed_by=call.identity.user,
**fields,
)
model.save()
@@ -414,12 +428,20 @@ def edit(call: APICall, company_id, _):
task_id = model.task or fields.get("task")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
task_id=task_id,
company_id=company_id,
user_id=call.identity.user,
last_iteration_max=iteration,
)
if fields:
now = datetime.utcnow()
fields.update(
last_change=now,
last_changed_by=call.identity.user,
)
if any(uf in fields for uf in last_update_fields):
fields.update(last_update=datetime.utcnow())
fields.update(last_update=now)
updated = model.update(upsert=False, **fields)
if updated:
@@ -445,13 +467,25 @@ def _update_model(call: APICall, company_id, model_id=None):
iteration = data.get("iteration")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
task_id=task_id,
company_id=company_id,
user_id=call.identity.user,
last_iteration_max=iteration,
)
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
now = datetime.utcnow()
updated_count, updated_fields = Model.safe_update(
company_id,
model.id,
data,
injected_update=dict(
last_change=now,
last_changed_by=call.identity.user,
),
)
if updated_count:
if any(uf in updated_fields for uf in last_update_fields):
model.update(upsert=False, last_update=datetime.utcnow())
model.update(upsert=False, last_update=now)
new_project = updated_fields.get("project", model.project)
if new_project != model.project:
@@ -573,7 +607,10 @@ def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
)
def archive_many(call: APICall, company_id, request: BatchRequest):
results, failures = run_batch_operation(
func=partial(ModelBLL.archive_model, company_id=company_id), ids=request.ids,
func=partial(
ModelBLL.archive_model, company_id=company_id, user_id=call.identity.user
),
ids=request.ids,
)
call.result.data_model = BatchResponse(
succeeded=[dict(id=_id, archived=bool(archived)) for _id, archived in results],
@@ -588,7 +625,8 @@ def archive_many(call: APICall, company_id, request: BatchRequest):
)
def unarchive_many(call: APICall, company_id, request: BatchRequest):
results, failures = run_batch_operation(
func=partial(ModelBLL.unarchive_model, company_id=company_id), ids=request.ids,
func=partial(ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user),
ids=request.ids,
)
call.result.data_model = BatchResponse(
succeeded=[
@@ -601,7 +639,11 @@ def unarchive_many(call: APICall, company_id, request: BatchRequest):
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Model.set_public(
company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidModelId,
enabled=True,
)
@@ -610,7 +652,11 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Model.set_public(
company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidModelId,
enabled=False,
)
@@ -633,30 +679,51 @@ def move(call: APICall, company_id: str, request: MoveRequest):
}
@endpoint("models.update_tags")
def update_tags(_, company_id: str, request: UpdateTagsRequest):
    """Add and/or remove user tags on the models listed in the request.

    The work is delegated to ``org_bll.edit_entity_tags``; whatever it
    returns is reported back under the ``"updated"`` key.
    """
    updated = org_bll.edit_entity_tags(
        company_id=company_id,
        entity_cls=Model,
        entity_ids=request.ids,
        add_tags=request.add_tags,
        remove_tags=request.remove_tags,
    )
    return {"updated": updated}
@endpoint("models.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
call: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
model_id = request.model
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
now = datetime.utcnow()
return {
"updated": Metadata.edit_metadata(
model,
items=request.metadata,
replace_metadata=request.replace_metadata,
last_update=datetime.utcnow(),
last_update=now,
last_change=now,
last_changed_by=call.identity.user,
)
}
@endpoint("models.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataRequest):
model_id = request.model
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
now = datetime.utcnow()
return {
"updated": Metadata.delete_metadata(
model, keys=request.keys, last_update=datetime.utcnow()
model,
keys=request.keys,
last_update=now,
last_change=now,
last_changed_by=call.identity.user,
)
}

View File

@@ -108,7 +108,13 @@ def _get_project_stats_filter(
if request.include_stats_filter or not request.children_type:
return request.include_stats_filter, request.search_hidden
stats_filter = {"tags": request.children_tags} if request.children_tags else {}
if request.children_tags_filter:
stats_filter = {"tags": request.children_tags_filter}
elif request.children_tags:
stats_filter = {"tags": request.children_tags}
else:
stats_filter = {}
if request.children_type == ProjectChildrenType.pipeline:
return (
{
@@ -153,6 +159,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
allow_public=allow_public,
children_type=request.children_type,
children_tags=request.children_tags,
children_tags_filter=request.children_tags_filter,
)
if not ids:
return {"projects": []}
@@ -452,6 +459,7 @@ def get_hyperparam_values(
name=request.name,
include_subprojects=request.include_subprojects,
allow_public=request.allow_public,
pattern=request.pattern,
page=request.page,
page_size=request.page_size,
)
@@ -505,7 +513,11 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Project.set_public(
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=True
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidProjectId,
enabled=True,
)
@@ -514,7 +526,11 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Project.set_public(
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidProjectId,
enabled=False,
)

View File

@@ -67,6 +67,7 @@ from apiserver.apimodels.tasks import (
GetAllReq,
DequeueRequest,
DequeueManyRequest,
UpdateTagsRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
@@ -76,7 +77,6 @@ from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import (
TaskBLL,
ChangeStatusRequest,
update_project_time,
)
from apiserver.bll.task.artifacts import (
artifacts_prepare_for_save,
@@ -101,7 +101,7 @@ from apiserver.bll.task.task_operations import (
move_tasks_to_trash,
)
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
from apiserver.bll.util import run_batch_operation
from apiserver.bll.util import run_batch_operation, update_project_time
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.task.output import Output
@@ -112,7 +112,11 @@ from apiserver.database.model.task.task import (
ModelItem,
TaskModelTypes,
)
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
from apiserver.database.utils import (
get_fields_attr,
parse_from_call,
get_options,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
conform_tag_fields,
@@ -493,7 +497,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
org_bll.update_tags(
company,
Tags.Task,
project=project,
projects=[project],
tags=fields.get("tags"),
system_tags=fields.get("system_tags"),
)
@@ -1232,9 +1236,12 @@ def completed(call: APICall, company_id, request: CompletedRequest):
@endpoint("tasks.ping", request_data_model=PingRequest)
def ping(_, company_id, request: PingRequest):
def ping(call: APICall, company_id, request: PingRequest):
TaskBLL.set_last_update(
task_ids=[request.task], company_id=company_id, last_update=datetime.utcnow()
task_ids=[request.task],
company_id=company_id,
user_id=call.identity.user,
last_update=datetime.utcnow(),
)
@@ -1277,14 +1284,22 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidTaskId,
enabled=True,
)
@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidTaskId,
enabled=False,
)
@@ -1314,8 +1329,21 @@ def move(call: APICall, company_id: str, request: MoveRequest):
return {"project_id": project_id}
@endpoint("tasks.update_tags")
def update_tags(_, company_id: str, request: UpdateTagsRequest):
    """Add and/or remove user tags on the tasks listed in the request.

    The work is delegated to ``org_bll.edit_entity_tags``; whatever it
    returns is reported back under the ``"updated"`` key.
    """
    updated = org_bll.edit_entity_tags(
        company_id=company_id,
        entity_cls=Task,
        entity_ids=request.ids,
        add_tags=request.add_tags,
        remove_tags=request.remove_tags,
    )
    return {"updated": updated}
@endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelRequest):
get_task_for_update(company_id=company_id, task_id=request.task, force=True)
models_field = f"models__{request.type}"
@@ -1326,6 +1354,7 @@ def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequ
updated = TaskBLL.update_statistics(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
last_iteration_max=request.iteration,
**({f"push__{models_field}": model} if not updated else {}),
)

View File

@@ -0,0 +1,67 @@
from apiserver.apierrors import errors
from apiserver.tests.automated import TestService
class TestGetAllExFilters(TestService):
    """Tests for the ``filters`` parameter of ``tasks.get_all_ex``."""

    def test_list_filters(self):
        """Check any/all and include/exclude tag filters, alone and combined."""
        tags = ["a", "b", "c", "d"]
        # tasks[i] is created with the first i tags:
        # [], [a], [a, b], [a, b, c], [a, b, c, d]
        tasks = [self._temp_task(tags=tags[:i]) for i in range(len(tags) + 1)]

        # invalid params check
        with self.api.raises(errors.bad_request.ValidationError):
            self.api.tasks.get_all_ex(filters={"tags": {"test": ["1"]}})

        # test any condition
        res = self.api.tasks.get_all_ex(
            id=tasks, filters={"tags": {"any": {"include": ["a", "b"]}}}
        ).tasks
        self.assertEqual(set(tasks[1:]), set(t.id for t in res))

        res = self.api.tasks.get_all_ex(
            id=tasks, filters={"tags": {"any": {"exclude": ["c", "d"]}}}
        ).tasks
        self.assertEqual(set(tasks[:-1]), set(t.id for t in res))

        res = self.api.tasks.get_all_ex(
            id=tasks,
            filters={"tags": {"any": {"include": ["a", "b"], "exclude": ["c", "d"]}}},
        ).tasks
        self.assertEqual(set(tasks), set(t.id for t in res))

        # test all condition
        res = self.api.tasks.get_all_ex(
            id=tasks, filters={"tags": {"all": {"include": ["a", "b"]}}}
        ).tasks
        self.assertEqual(set(tasks[2:]), set(t.id for t in res))

        res = self.api.tasks.get_all_ex(
            id=tasks, filters={"tags": {"all": {"exclude": ["c", "d"]}}}
        ).tasks
        self.assertEqual(set(tasks[:-2]), set(t.id for t in res))

        res = self.api.tasks.get_all_ex(
            id=tasks,
            filters={"tags": {"all": {"include": ["a", "b"], "exclude": ["c", "d"]}}},
        ).tasks
        self.assertEqual([tasks[2]], [t.id for t in res])

        # test combination
        res = self.api.tasks.get_all_ex(
            id=tasks,
            filters={
                "tags": {"any": {"include": ["a", "b"]}, "all": {"exclude": ["c", "d"]}}
            },
        ).tasks
        self.assertEqual(set(tasks[1:-2]), set(t.id for t in res))

    def _temp_task(self, **kwargs):
        """Create a temporary training task, filling in a default name and type.

        Returns whatever ``create_temp`` returns; the test above uses these
        values as task ids.
        """
        self.update_missing(
            kwargs,
            name="test get_all_ex filters",
            type="training",
        )
        return self.create_temp(
            "tasks",
            **kwargs,
            # Was misspelled "delete_paramse", so these cleanup options were
            # sent as a task creation field instead of reaching the delete step.
            # NOTE(review): assumes create_temp accepts delete_params - confirm.
            delete_params=dict(can_fail=True, force=True),
        )

View File

@@ -92,6 +92,29 @@ class TestProjectTags(TestService):
self.assertFalse(tag1 in data.tags)
self.assertTrue(tag2 in data.tags)
def test_tags_api(self):
    """Verify that the update_tags endpoints change the tags reported per project."""
    project = self.create_temp("projects", name="Test tags api", description="test")

    # task: the initial tag is swapped for a new one
    task_tags = ["Task tag"]
    task = self.new_task(project=project, tags=task_tags)
    reported = self.api.projects.get_task_tags(projects=[project])
    self.assertEqual(reported.tags, task_tags)
    replacement_tags = ["New task tag"]
    self.api.tasks.update_tags(
        ids=[task], add_tags=replacement_tags, remove_tags=task_tags
    )
    reported = self.api.projects.get_task_tags(projects=[project])
    self.assertEqual(reported.tags, replacement_tags)

    # model: a tag is added and the initial one is kept
    model_tags = ["Model tag"]
    model = self.new_model(project=project, tags=model_tags)
    reported = self.api.projects.get_model_tags(projects=[project])
    self.assertEqual(reported.tags, model_tags)
    extra_tags = ["New model tag"]
    self.api.models.update_tags(ids=[model], add_tags=extra_tags)
    reported = self.api.projects.get_model_tags(projects=[project])
    self.assertEqual(set(reported.tags), {*extra_tags, *model_tags})
def new_task(self, **kwargs):
self.update_missing(
kwargs, type="testing", name="test project tags"

View File

@@ -64,6 +64,20 @@ class TestSubProjects(TestService):
self.assertEqual(p.basename, "project2")
self.assertEqual(p.stats.active.total_tasks, 2)
# new filter
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type="report",
children_tags_filter={"any": {"include": ["test1", "test2"]}},
shallow_search=True,
include_stats=True,
check_own_contents=True,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]
self.assertEqual(p.basename, "project2")
self.assertEqual(p.stats.active.total_tasks, 2)
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type="report",
@@ -77,6 +91,20 @@ class TestSubProjects(TestService):
self.assertEqual(p.basename, "project2")
self.assertEqual(p.stats.active.total_tasks, 1)
# new filter
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type="report",
children_tags_filter={"all": {"include": ["test1", "test2"]}},
shallow_search=True,
include_stats=True,
check_own_contents=True,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]
self.assertEqual(p.basename, "project2")
self.assertEqual(p.stats.active.total_tasks, 1)
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type="report",
@@ -102,6 +130,20 @@ class TestSubProjects(TestService):
for p in projects:
self.assertEqual(p.stats.active.total_tasks, 1)
# new filter
projects = self.api.projects.get_all_ex(
parent=[test_root],
children_type="report",
children_tags_filter={"all": {"exclude": ["test1", "test2"]}},
shallow_search=True,
include_stats=True,
check_own_contents=True,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]
self.assertEqual(p.basename, "project1")
self.assertEqual(p.stats.active.total_tasks, 1)
def test_query_children(self):
test_root_name = "TestQueryChildren"
test_root = self._temp_project(name=test_root_name)

View File

@@ -12,17 +12,17 @@ class TestTasksFiltering(TestService):
param1 = ("Se$tion1", "pa__ram1", True)
param2 = ("Section2", "param2", False)
task_count = 5
for p in (param1, param2):
for (section, name, unique_value) in (param1, param2):
for idx in range(task_count):
t = self.temp_task(project=project)
self.api.tasks.edit_hyper_params(
task=t,
hyperparams=[
{
"section": p[0],
"name": p[1],
"section": section,
"name": name,
"type": "str",
"value": str(idx) if p[2] else "Constant",
"value": str(idx) if unique_value else "Constant",
}
],
)
@@ -42,6 +42,18 @@ class TestTasksFiltering(TestService):
self.assertEqual(res.total, 0)
self.assertEqual(res["values"], [])
# search pattern
res = self.api.projects.get_hyperparam_values(
projects=[project], section=param1[0], name=param1[1], pattern="^1"
)
self.assertEqual(res.total, 1)
self.assertEqual(res["values"], ["1"])
res = self.api.projects.get_hyperparam_values(
projects=[project], section=param1[0], name=param1[1], pattern="11"
)
self.assertEqual(res.total, 0)
def test_datetime_queries(self):
tasks = [self.temp_task() for _ in range(5)]
now = datetime.utcnow()

View File

@@ -1,11 +1,9 @@
import time
from uuid import uuid4
from datetime import timedelta
from operator import attrgetter
from typing import Sequence
from apiserver.apierrors.errors import bad_request
from apiserver.tests.automated import TestService, utc_now_tz_aware
from apiserver.tests.automated import TestService
from apiserver.config_repo import config
log = config.logger(__file__)
@@ -72,7 +70,9 @@ class TestWorkersService(TestService):
self.assertEqual(worker.tags, [tag])
self.assertEqual(worker.system_tags, [system_tag])
workers = self.api.workers.get_all(tags=[tag], system_tags=[f"-{system_tag}"]).workers
workers = self.api.workers.get_all(
tags=[tag], system_tags=[f"-{system_tag}"]
).workers
self.assertFalse(workers)
def test_filters(self):
@@ -83,7 +83,7 @@ class TestWorkersService(TestService):
self._check_exists(test_worker, False, tags=["test"])
self._check_exists(test_worker, False, tags=["-application"])
def _simulate_workers(self) -> Sequence[str]:
def _simulate_workers(self, start: int) -> Sequence[str]:
"""
Two workers writing the same metrics. One for 4 seconds. Another one for 2
The first worker reports a task
@@ -105,25 +105,23 @@ class TestWorkersService(TestService):
(workers[0],),
(workers[0],),
]
timestamp = start * 1000
for ws, stats in zip(workers_activity, workers_stats):
for w, s in zip(ws, stats):
data = dict(
worker=w,
timestamp=int(utc_now_tz_aware().timestamp() * 1000),
timestamp=timestamp,
machine_stats=s,
)
if w == workers[0]:
data["task"] = task_id
self.api.workers.status_report(**data)
time.sleep(1)
timestamp += 1000
res = self.api.workers.get_all(last_seen=100)
return [w.key for w in res.workers]
return workers
def _create_running_task(self, task_name):
task_input = dict(
name=task_name, type="testing"
)
task_input = dict(name=task_name, type="testing")
task_id = self.create_temp("tasks", **task_input)
@@ -131,7 +129,8 @@ class TestWorkersService(TestService):
return task_id
def test_get_keys(self):
workers = self._simulate_workers()
workers = self._simulate_workers(int(time.time()))
time.sleep(5) # give to es time to refresh
res = self.api.workers.get_metric_keys(worker_ids=workers)
assert {"cpu", "memory"} == set(c.name for c in res["categories"])
assert all(
@@ -147,11 +146,12 @@ class TestWorkersService(TestService):
self.api.workers.get_metric_keys(worker_ids=["Non existing worker id"])
def test_get_stats(self):
workers = self._simulate_workers()
to_date = utc_now_tz_aware() + timedelta(seconds=10)
from_date = to_date - timedelta(days=1)
start = int(time.time())
workers = self._simulate_workers(start)
time.sleep(5) # give to ES time to refresh
from_date = start
to_date = start + 10
# no variants
res = self.api.workers.get_stats(
items=[
@@ -160,68 +160,58 @@ class TestWorkersService(TestService):
dict(key="memory_used", aggregation="max"),
dict(key="memory_used", aggregation="min"),
],
from_date=from_date.timestamp(),
to_date=to_date.timestamp(),
from_date=from_date,
to_date=to_date,
# split_by_variant=True,
interval=1,
worker_ids=workers,
)
self.assertWorkersInStats(workers, res["workers"])
assert all(
{"cpu_usage", "memory_used"}
== set(map(attrgetter("metric"), worker["metrics"]))
for worker in res["workers"]
)
def _check_dates_and_stats(metric, stats, worker_id) -> bool:
return set(
map(attrgetter("aggregation"), metric["stats"])
) == stats and len(metric["dates"]) == (4 if worker_id == workers[0] else 2)
assert all(
_check_dates_and_stats(metric, metric_stats, worker["worker"])
for worker in res["workers"]
for metric, metric_stats in zip(
worker["metrics"], ({"avg", "max"}, {"max", "min"})
self.assertWorkersInStats(workers, res.workers)
for worker in res.workers:
self.assertEqual(
set(metric.metric for metric in worker.metrics),
{"cpu_usage", "memory_used"},
)
)
for worker in res.workers:
for metric, metric_stats in zip(
worker.metrics, ({"avg", "max"}, {"max", "min"})
):
self.assertEqual(
set(stat.aggregation for stat in metric.stats), metric_stats
)
self.assertEqual(len(metric.dates), 4 if worker.worker == workers[0] else 2)
# split by variants
res = self.api.workers.get_stats(
items=[dict(key="cpu_usage", aggregation="avg")],
from_date=from_date.timestamp(),
to_date=to_date.timestamp(),
from_date=from_date,
to_date=to_date,
split_by_variant=True,
interval=1,
worker_ids=workers,
)
self.assertWorkersInStats(workers, res["workers"])
self.assertWorkersInStats(workers, res.workers)
def _check_metric_and_variants(worker):
return (
all(
_check_dates_and_stats(metric, {"avg"}, worker["worker"])
for metric in worker["metrics"]
for worker in res.workers:
for metric in worker.metrics:
self.assertEqual(
set(metric.variant for metric in worker.metrics),
{"0", "1"} if worker.worker == workers[0] else {"0"},
)
and set(map(attrgetter("variant"), worker["metrics"])) == {"0", "1"}
if worker["worker"] == workers[0]
else {"0"}
)
assert all(_check_metric_and_variants(worker) for worker in res["workers"])
self.assertEqual(len(metric.dates), 4 if worker.worker == workers[0] else 2)
res = self.api.workers.get_stats(
items=[dict(key="cpu_usage", aggregation="avg")],
from_date=from_date.timestamp(),
to_date=to_date.timestamp(),
from_date=from_date,
to_date=to_date,
interval=1,
worker_ids=["Non existing worker id"],
)
assert not res["workers"]
assert not res.workers
@staticmethod
def assertWorkersInStats(workers: Sequence[str], stats: dict):
assert set(workers) == set(map(attrgetter("worker"), stats))
def assertWorkersInStats(self, workers: Sequence[str], stats: Sequence):
self.assertEqual(set(workers), set(item.worker for item in stats))
def test_get_activity_report(self):
# test no workers data
@@ -232,28 +222,19 @@ class TestWorkersService(TestService):
# to_timestamp=to_timestamp.timestamp(),
# interval=20,
# )
start = int(time.time())
self._simulate_workers(int(time.time()))
self._simulate_workers()
to_date = utc_now_tz_aware() + timedelta(seconds=10)
from_date = to_date - timedelta(minutes=1)
time.sleep(5) # give to es time to refresh
# no variants
res = self.api.workers.get_activity_report(
from_date=from_date.timestamp(), to_date=to_date.timestamp(), interval=20
from_date=start, to_date=start + 10, interval=2
)
self.assertWorkerSeries(res["total"], 2)
self.assertWorkerSeries(res["active"], 1)
self.assertTotalSeriesGreaterThenActive(res["total"], res["active"])
self.assertWorkerSeries(res["total"], 2, 5)
self.assertWorkerSeries(res["active"], 1, 5)
@staticmethod
def assertTotalSeriesGreaterThenActive(total_data: dict, active_data: dict):
assert total_data["dates"][-1] == active_data["dates"][-1]
assert total_data["counts"][-1] > active_data["counts"][-1]
@staticmethod
def assertWorkerSeries(series_data: dict, min_count: int):
assert len(series_data["dates"]) == len(series_data["counts"])
# check the last 20s aggregation
# there may be more workers that we created since we are not filtering by test workers here
assert series_data["counts"][-1] >= min_count
def assertWorkerSeries(self, series_data: dict, count: int, size: int):
self.assertEqual(len(series_data["dates"]), size)
self.assertEqual(len(series_data["counts"]), size)
self.assertTrue(any(c == count for c in series_data["counts"]))
self.assertTrue(all(c <= count for c in series_data["counts"]))

View File

@@ -1,6 +1,8 @@
from boltons.dictutils import OneToOne
from mongoengine.queryset.transform import MATCH_OPERATORS
from apiserver.apierrors import errors
class ParameterKeyEscaper:
"""
@@ -15,8 +17,13 @@ class ParameterKeyEscaper:
@classmethod
def escape(cls, value: str):
""" Quote a parameter key """
value = value.strip().replace("%", "%%")
value = value.strip()
if not value:
raise errors.bad_request.ValidationError(
f"Empty key is not allowed"
)
value = value.replace("%", "%%")
for c, r in cls._mapping.items():
value = value.replace(c, r)

View File

@@ -1 +1 @@
__version__ = "1.12.0"
__version__ = "1.13.0"

View File

@@ -155,6 +155,7 @@ services:
- http://fileserver:8081
volumes:
- c:/opt/clearml/logs:/var/log/clearml
- c:/opt/clearml/config:/opt/clearml/config
networks:
backend:

View File

@@ -154,6 +154,7 @@ services:
- http://fileserver:8081
volumes:
- /opt/clearml/logs:/var/log/clearml
- /opt/clearml/config:/opt/clearml/config
agent-services:
networks:

View File

@@ -5,3 +5,4 @@ flask>=2.3.2
gunicorn>=20.1.0
pyhocon>=0.3.35
setuptools>=65.5.1
urllib3>=1.26.18