diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py
index c6042f6..9a45183 100644
--- a/apiserver/bll/event/event_bll.py
+++ b/apiserver/bll/event/event_bll.py
@@ -39,7 +39,6 @@ from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
 from apiserver.database.model.task.task import Task, TaskStatus
 from apiserver.redis_manager import redman
-from apiserver.timing_context import TimingContext
 from apiserver.tools import safe_get
 from apiserver.utilities.dicts import nested_get
 from apiserver.utilities.json import loads
@@ -97,7 +96,7 @@ class EventBLL(object):
         if not task_ids:
             return set()
 
-        with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
+        with translate_errors_context():
             query = Q(id__in=task_ids, company=company_id)
             if not allow_locked_tasks:
                 query &= Q(status__nin=LOCKED_TASK_STATUSES)
@@ -228,45 +227,44 @@ class EventBLL(object):
         with translate_errors_context():
             if actions:
                 chunk_size = 500
-                with TimingContext("es", "events_add_batch"):
-                    # TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
-                    with closing(
-                        elasticsearch.helpers.streaming_bulk(
-                            self.es,
-                            actions,
-                            chunk_size=chunk_size,
-                            # thread_count=8,
-                            refresh=True,
-                        )
-                    ) as it:
-                        for success, info in it:
-                            if success:
-                                added += 1
-                            else:
-                                errors_per_type["Error when indexing events batch"] += 1
+                # TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
+                with closing(
+                    elasticsearch.helpers.streaming_bulk(
+                        self.es,
+                        actions,
+                        chunk_size=chunk_size,
+                        # thread_count=8,
+                        refresh=True,
+                    )
+                ) as it:
+                    for success, info in it:
+                        if success:
+                            added += 1
+                        else:
+                            errors_per_type["Error when indexing events batch"] += 1
 
-                    remaining_tasks = set()
-                    now = datetime.utcnow()
-                    for task_id in task_ids:
-                        # Update related tasks. For reasons of performance, we prefer to update
-                        # all of them and not only those who's events were successful
-                        updated = self._update_task(
-                            company_id=company_id,
-                            task_id=task_id,
-                            now=now,
-                            iter_max=task_iteration.get(task_id),
-                            last_scalar_events=task_last_scalar_events.get(task_id),
-                            last_events=task_last_events.get(task_id),
-                        )
+                remaining_tasks = set()
+                now = datetime.utcnow()
+                for task_id in task_ids:
+                    # Update related tasks. For reasons of performance, we prefer to update
+                    # all of them and not only those whose events were successful
+                    updated = self._update_task(
+                        company_id=company_id,
+                        task_id=task_id,
+                        now=now,
+                        iter_max=task_iteration.get(task_id),
+                        last_scalar_events=task_last_scalar_events.get(task_id),
+                        last_events=task_last_events.get(task_id),
+                    )
 
-                        if not updated:
-                            remaining_tasks.add(task_id)
-                            continue
+                    if not updated:
+                        remaining_tasks.add(task_id)
+                        continue
 
-                    if remaining_tasks:
-                        TaskBLL.set_last_update(
-                            remaining_tasks, company_id, last_update=now
-                        )
+                if remaining_tasks:
+                    TaskBLL.set_last_update(
+                        remaining_tasks, company_id, last_update=now
+                    )
 
             # this is for backwards compatibility with streaming bulk throwing exception on those
             invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
@@ -425,7 +423,7 @@ class EventBLL(object):
             return [], scroll_id, 0
 
         if scroll_id:
-            with translate_errors_context(), TimingContext("es", "task_log_events"):
+            with translate_errors_context():
                 es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
             size = min(batch_size, 10000)
@@ -438,7 +436,7 @@ class EventBLL(object):
                 "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
             }
 
-            with translate_errors_context(), TimingContext("es", "scroll_task_events"):
+            with translate_errors_context():
                 es_res = search_company_events(
                     self.es,
                     company_id=company_id,
@@ -468,9 +466,7 @@ class EventBLL(object):
         if metric_variants:
             must.append(get_metric_variants_condition(metric_variants))
         query = {"bool": {"must": must}}
-        search_args = dict(
-            es=self.es, company_id=company_id, event_type=event_type
-        )
+        search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
         max_metrics, max_variants = get_max_metric_and_variant_counts(
             query=query, **search_args,
         )
@@ -508,9 +504,7 @@ class EventBLL(object):
             "query": query,
         }
 
-        with translate_errors_context(), TimingContext(
-            "es", "task_last_iter_metric_variant"
-        ):
+        with translate_errors_context():
             es_res = search_company_events(body=es_req, **search_args)
 
         if "aggregations" not in es_res:
@@ -538,7 +532,7 @@ class EventBLL(object):
             return TaskEventsResult()
 
         if scroll_id:
-            with translate_errors_context(), TimingContext("es", "get_task_events"):
+            with translate_errors_context():
                 es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
             event_type = EventType.metrics_plot
@@ -602,7 +596,7 @@ class EventBLL(object):
                 "query": {"bool": {"must": must}},
             }
 
-            with translate_errors_context(), TimingContext("es", "get_task_plots"):
+            with translate_errors_context():
                 es_res = search_company_events(
                     self.es,
                     company_id=company_id,
@@ -720,7 +714,7 @@ class EventBLL(object):
             return TaskEventsResult()
 
         if scroll_id:
-            with translate_errors_context(), TimingContext("es", "get_task_events"):
+            with translate_errors_context():
                 es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
             if check_empty_data(self.es, company_id=company_id, event_type=event_type):
@@ -768,7 +762,7 @@ class EventBLL(object):
                 "query": {"bool": {"must": must}},
             }
 
-            with translate_errors_context(), TimingContext("es", "get_task_events"):
+            with translate_errors_context():
                 es_res = search_company_events(
                     self.es,
                     company_id=company_id,
@@ -793,9 +787,7 @@ class EventBLL(object):
             return {}
 
         query = {"bool": {"must": [{"term": {"task": task_id}}]}}
-        search_args = dict(
-            es=self.es, company_id=company_id, event_type=event_type
-        )
+        search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
         max_metrics, max_variants = get_max_metric_and_variant_counts(
             query=query, **search_args,
         )
@@ -822,9 +814,7 @@ class EventBLL(object):
             "query": query,
         }
 
-        with translate_errors_context(), TimingContext(
-            "es", "events_get_metrics_and_variants"
-        ):
+        with translate_errors_context():
             es_res = search_company_events(body=es_req, **search_args)
 
         metrics = {}
@@ -851,9 +841,7 @@ class EventBLL(object):
                 ]
             }
         }
-        search_args = dict(
-            es=self.es, company_id=company_id, event_type=event_type
-        )
+        search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
         max_metrics, max_variants = get_max_metric_and_variant_counts(
             query=query, **search_args,
         )
@@ -899,9 +887,7 @@ class EventBLL(object):
             },
             "_source": {"excludes": []},
         }
-        with translate_errors_context(), TimingContext(
-            "es", "events_get_metrics_and_variants"
-        ):
+        with translate_errors_context():
             es_res = search_company_events(body=es_req, **search_args)
 
         metrics = []
@@ -947,7 +933,7 @@ class EventBLL(object):
             "_source": ["iter", "value"],
             "sort": ["iter"],
         }
-        with translate_errors_context(), TimingContext("es", "task_stats_vector"):
+        with translate_errors_context():
            es_res = search_company_events(
                 self.es, company_id=company_id, event_type=event_type, body=es_req
             )
@@ -990,7 +976,7 @@ class EventBLL(object):
             "query": {"bool": {"must": [{"terms": {"task": task_ids}}]}},
         }
 
-        with translate_errors_context(), TimingContext("es", "task_last_iter"):
+        with translate_errors_context():
             es_res = search_company_events(
                 self.es, company_id=company_id, event_type=event_type, body=es_req,
             )
@@ -1022,7 +1008,7 @@ class EventBLL(object):
         )
 
         es_req = {"query": {"term": {"task": task_id}}}
-        with translate_errors_context(), TimingContext("es", "delete_task_events"):
+        with translate_errors_context():
             es_res = delete_company_events(
                 es=self.es,
                 company_id=company_id,
@@ -1048,7 +1034,7 @@ class EventBLL(object):
         ):
             return 0
 
-        with translate_errors_context(), TimingContext("es", "clear_task_log"):
+        with translate_errors_context():
             must = [{"term": {"task": task_id}}]
             sort = None
             if threshold_sec:
@@ -1082,9 +1068,7 @@ class EventBLL(object):
         so it should be checked by the calling code
         """
         es_req = {"query": {"terms": {"task": task_ids}}}
-        with translate_errors_context(), TimingContext(
-            "es", "delete_multi_tasks_events"
-        ):
+        with translate_errors_context():
             es_res = delete_company_events(
                 es=self.es,
                 company_id=company_id,
diff --git a/apiserver/bll/event/event_common.py b/apiserver/bll/event/event_common.py
index f4b8b17..66849a2 100644
--- a/apiserver/bll/event/event_common.py
+++ b/apiserver/bll/event/event_common.py
@@ -8,7 +8,6 @@ from elasticsearch import Elasticsearch
 
 from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
-from apiserver.timing_context import TimingContext
 from apiserver.tools import safe_get
 
 
@@ -21,7 +20,7 @@ class EventType(Enum):
     all = "*"
 
 
-SINGLE_SCALAR_ITERATION = -2**31
+SINGLE_SCALAR_ITERATION = -(2 ** 31)
 MetricVariants = Mapping[str, Sequence[str]]
 
 
@@ -80,9 +79,7 @@ def delete_company_events(
     es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
 ) -> dict:
     es_index = get_index_name(company_id, event_type.value)
-    return es.delete_by_query(
-        index=es_index, body=body, conflicts="proceed", **kwargs
-    )
+    return es.delete_by_query(index=es_index, body=body, conflicts="proceed", **kwargs)
 
 
 def count_company_events(
@@ -116,9 +113,7 @@ def get_max_metric_and_variant_counts(
         "query": query,
         "aggs": {"metrics_count": {"cardinality": {"field": "metric"}}},
     }
-    with translate_errors_context(), TimingContext(
-        "es", "get_max_metric_and_variant_counts"
-    ):
+    with translate_errors_context():
         es_res = search_company_events(
             es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
         )
diff --git a/apiserver/bll/event/event_metrics.py b/apiserver/bll/event/event_metrics.py
index f8b4f26..7da4bc5 100644
--- a/apiserver/bll/event/event_metrics.py
+++ b/apiserver/bll/event/event_metrics.py
@@ -25,7 +25,6 @@ from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
 from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
 from apiserver.database.model.task.task import Task
-from apiserver.timing_context import TimingContext
 from apiserver.tools import safe_get
 
 log = config.logger(__file__)
@@ -180,8 +179,7 @@ class EventMetrics:
         ):
             return {}
 
-        with TimingContext("es", "get_task_single_value_metrics"):
-            task_events = self._get_task_single_value_metrics(company_id, task_ids)
+        task_events = self._get_task_single_value_metrics(company_id, task_ids)
 
         def _get_value(event: dict):
             return {
@@ -277,9 +275,7 @@ class EventMetrics:
         if metric_variants:
             must.append(get_metric_variants_condition(metric_variants))
         query = {"bool": {"must": must}}
-        search_args = dict(
-            es=self.es, company_id=company_id, event_type=event_type
-        )
+        search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
         max_metrics, max_variants = get_max_metric_and_variant_counts(
             query=query, **search_args,
         )
@@ -312,8 +308,7 @@ class EventMetrics:
             },
         }
 
-        with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
-            es_res = search_company_events(body=es_req, **search_args)
+        es_res = search_company_events(body=es_req, **search_args)
 
         aggs_result = es_res.get("aggregations")
         if not aggs_result:
@@ -366,9 +361,7 @@ class EventMetrics:
         interval, metrics = metrics_interval
         aggregation = self._add_aggregation_average(key.get_aggregation(interval))
         query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
-        search_args = dict(
-            es=self.es, company_id=company_id, event_type=event_type
-        )
+        search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
         max_metrics, max_variants = get_max_metric_and_variant_counts(
             query=query, **search_args,
         )
@@ -493,10 +486,9 @@ class EventMetrics:
             },
         }
 
-        with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
-            es_res = search_company_events(
-                self.es, company_id=company_id, event_type=event_type, body=es_req
-            )
+        es_res = search_company_events(
+            self.es, company_id=company_id, event_type=event_type, body=es_req
+        )
 
         return [
             metric["key"]
diff --git a/apiserver/bll/event/events_iterator.py b/apiserver/bll/event/events_iterator.py
index 69bb616..291cddf 100644
--- a/apiserver/bll/event/events_iterator.py
+++ b/apiserver/bll/event/events_iterator.py
@@ -17,7 +17,6 @@ from apiserver.bll.event.event_common import (
 from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
 from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
-from apiserver.timing_context import TimingContext
 
 
 @attr.s(auto_attribs=True)
@@ -76,7 +75,7 @@ class EventsIterator:
             "query": query,
         }
 
-        with translate_errors_context(), TimingContext("es", "count_task_events"):
+        with translate_errors_context():
             es_result = count_company_events(
                 self.es, company_id=company_id, event_type=event_type, body=es_req,
             )
@@ -113,7 +112,7 @@ class EventsIterator:
         if from_key_value:
             es_req["search_after"] = [from_key_value]
-        with translate_errors_context(), TimingContext("es", "get_task_events"):
+        with translate_errors_context():
             es_result = search_company_events(
                 self.es, company_id=company_id, event_type=event_type, body=es_req,
             )
diff --git a/apiserver/bll/event/history_sample_iterator.py b/apiserver/bll/event/history_sample_iterator.py
index cf90b37..ccf15d7 100644
--- a/apiserver/bll/event/history_sample_iterator.py
+++ b/apiserver/bll/event/history_sample_iterator.py
@@ -21,7 +21,6 @@ from apiserver.bll.event.event_common import (
 )
 from apiserver.bll.redis_cache_manager import RedisCacheManager
 from apiserver.database.errors import translate_errors_context
-from apiserver.timing_context import TimingContext
 from apiserver.utilities.dicts import nested_get
 
 
@@ -174,14 +173,9 @@ class HistorySampleIterator(abc.ABC):
             "query": {"bool": {"must": must_conditions}},
         }
 
-        with translate_errors_context(), TimingContext(
-            "es", "get_next_for_current_iteration"
-        ):
+        with translate_errors_context():
             es_res = search_company_events(
-                self.es,
-                company_id=company_id,
-                event_type=self.event_type,
-                body=es_req,
+                self.es, company_id=company_id, event_type=self.event_type, body=es_req,
             )
 
         hits = nested_get(es_res, ("hits", "hits"))
@@ -235,14 +229,9 @@ class HistorySampleIterator(abc.ABC):
             "sort": [{"iter": order}, {"metric": order}, {"variant": order}],
             "query": {"bool": {"must": must_conditions}},
         }
-        with translate_errors_context(), TimingContext(
-            "es", "get_next_for_another_iteration"
-        ):
+        with translate_errors_context():
             es_res = search_company_events(
-                self.es,
-                company_id=company_id,
-                event_type=self.event_type,
-                body=es_req,
+                self.es, company_id=company_id, event_type=self.event_type, body=es_req,
             )
 
         hits = nested_get(es_res, ("hits", "hits"))
@@ -335,9 +324,7 @@ class HistorySampleIterator(abc.ABC):
             "query": {"bool": {"must": must_conditions}},
         }
 
-        with translate_errors_context(), TimingContext(
-            "es", "get_history_sample_for_variant"
-        ):
+        with translate_errors_context():
             es_res = search_company_events(
                 self.es,
                 company_id=company_id,
@@ -421,9 +408,7 @@ class HistorySampleIterator(abc.ABC):
             },
         }
 
-        with translate_errors_context(), TimingContext(
-            "es", "get_history_sample_iterations"
-        ):
+        with translate_errors_context():
             es_res = search_company_events(body=es_req, **search_args,)
 
         def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
diff --git a/apiserver/bll/event/metric_events_iterator.py b/apiserver/bll/event/metric_events_iterator.py
index 6b2e739..f7d3beb 100644
--- a/apiserver/bll/event/metric_events_iterator.py
+++ b/apiserver/bll/event/metric_events_iterator.py
@@ -19,14 +19,14 @@ from apiserver.bll.event.event_common import (
     check_empty_data,
     search_company_events,
     EventType,
-    get_metric_variants_condition, get_max_metric_and_variant_counts,
+    get_metric_variants_condition,
+    get_max_metric_and_variant_counts,
 )
 from apiserver.bll.redis_cache_manager import RedisCacheManager
 from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
 from apiserver.database.model.task.metrics import MetricEventStats
 from apiserver.database.model.task.task import Task
-from apiserver.timing_context import TimingContext
 
 
 class VariantState(Base):
@@ -226,7 +226,9 @@ class MetricEventsIterator:
         pass
 
     @abc.abstractmethod
-    def _get_variant_state_aggs(self) -> Tuple[dict, Callable[[dict, VariantState], None]]:
+    def _get_variant_state_aggs(
+        self,
+    ) -> Tuple[dict, Callable[[dict, VariantState], None]]:
         pass
 
     def _init_metric_states_for_task(
@@ -268,14 +270,18 @@ class MetricEventsIterator:
                                 "size": max_variants,
                                 "order": {"_key": "asc"},
                             },
-                            **({"aggs": variant_state_aggs} if variant_state_aggs else {}),
+                            **(
+                                {"aggs": variant_state_aggs}
+                                if variant_state_aggs
+                                else {}
+                            ),
                         },
                     },
                 }
             },
         }
 
-        with translate_errors_context(), TimingContext("es", "_init_metric_states"):
+        with translate_errors_context():
             es_res = search_company_events(body=es_req, **search_args)
         if "aggregations" not in es_res:
             return []
@@ -383,12 +389,9 @@ class MetricEventsIterator:
                 }
             },
         }
-        with translate_errors_context(), TimingContext("es", "_get_task_metric_events"):
+        with translate_errors_context():
             es_res = search_company_events(
-                self.es,
-                company_id=company_id,
-                event_type=self.event_type,
-                body=es_req,
+                self.es, company_id=company_id, event_type=self.event_type, body=es_req,
             )
         if "aggregations" not in es_res:
             return task_state.task, []
diff --git a/apiserver/bll/model/metadata.py b/apiserver/bll/model/metadata.py
index 7994f6c..7c40576 100644
--- a/apiserver/bll/model/metadata.py
+++ b/apiserver/bll/model/metadata.py
@@ -11,7 +11,6 @@ from apiserver.utilities.parameter_key_escaper import (
     mongoengine_safe,
 )
 from apiserver.config_repo import config
-from apiserver.timing_context import TimingContext
 
 log = config.logger(__file__)
 
@@ -42,27 +41,25 @@ class Metadata:
         replace_metadata: bool,
         **more_updates,
     ) -> int:
-        with TimingContext("mongo", "edit_metadata"):
-            update_cmds = dict()
-            metadata = cls.metadata_from_api(items)
-            if replace_metadata:
-                update_cmds["set__metadata"] = metadata
-            else:
-                for key, value in metadata.items():
-                    update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
+        update_cmds = dict()
+        metadata = cls.metadata_from_api(items)
+        if replace_metadata:
+            update_cmds["set__metadata"] = metadata
+        else:
+            for key, value in metadata.items():
+                update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
 
-            return obj.update(**update_cmds, **more_updates)
+        return obj.update(**update_cmds, **more_updates)
 
     @classmethod
     def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
-        with TimingContext("mongo", "delete_metadata"):
-            return obj.update(
-                **{
-                    f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
-                    for key in set(keys)
-                },
-                **more_updates,
-            )
+        return obj.update(
+            **{
+                f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
+                for key in set(keys)
+            },
+            **more_updates,
+        )
 
     @staticmethod
     def _process_path(path: str):
diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py
index d8e71e6..b1d037f 100644
--- a/apiserver/bll/project/project_bll.py
+++ b/apiserver/bll/project/project_bll.py
@@ -29,7 +29,6 @@ from apiserver.database.model.model import Model
 from apiserver.database.model.project import Project
 from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
 from apiserver.database.utils import get_options, get_company_or_none_constraint
-from apiserver.timing_context import TimingContext
 from apiserver.utilities.dicts import nested_get
 from .sub_projects import (
     _reposition_project_with_children,
@@ -57,54 +56,53 @@ class ProjectBLL:
         Remove the source project
         Return the amounts of moved entities and subprojects + set of all the affected project ids
         """
-        with TimingContext("mongo", "move_project"):
-            if source_id == destination_id:
-                raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
-                    source=source_id
+        if source_id == destination_id:
+            raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
+                source=source_id
+            )
+        source = Project.get(company, source_id)
+        if destination_id:
+            destination = Project.get(company, destination_id)
+            if source_id in destination.path:
+                raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
+                    source=source_id, destination=destination_id
                 )
-            source = Project.get(company, source_id)
-            if destination_id:
-                destination = Project.get(company, destination_id)
-                if source_id in destination.path:
-                    raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
-                        source=source_id, destination=destination_id
-                    )
-            else:
-                destination = None
+        else:
+            destination = None
 
-            children = _get_sub_projects(
-                [source.id], _only=("id", "name", "parent", "path")
-            )[source.id]
-            if destination:
-                cls.validate_projects_depth(
-                    projects=children,
-                    old_parent_depth=len(source.path) + 1,
-                    new_parent_depth=len(destination.path) + 1,
-                )
+        children = _get_sub_projects(
+            [source.id], _only=("id", "name", "parent", "path")
+        )[source.id]
+        if destination:
+            cls.validate_projects_depth(
+                projects=children,
+                old_parent_depth=len(source.path) + 1,
+                new_parent_depth=len(destination.path) + 1,
+            )
 
-            moved_entities = 0
-            for entity_type in (Task, Model):
-                moved_entities += entity_type.objects(
-                    company=company,
-                    project=source_id,
-                    system_tags__nin=[EntityVisibility.archived.value],
-                ).update(upsert=False, project=destination_id)
+        moved_entities = 0
+        for entity_type in (Task, Model):
+            moved_entities += entity_type.objects(
+                company=company,
+                project=source_id,
+                system_tags__nin=[EntityVisibility.archived.value],
+            ).update(upsert=False, project=destination_id)
 
-            moved_sub_projects = 0
-            for child in Project.objects(company=company, parent=source_id):
-                _reposition_project_with_children(
-                    project=child,
-                    children=[c for c in children if c.parent == child.id],
-                    parent=destination,
-                )
-                moved_sub_projects += 1
+        moved_sub_projects = 0
+        for child in Project.objects(company=company, parent=source_id):
+            _reposition_project_with_children(
+                project=child,
+                children=[c for c in children if c.parent == child.id],
+                parent=destination,
+            )
+            moved_sub_projects += 1
 
-            affected = {source.id, *(source.path or [])}
-            source.delete()
+        affected = {source.id, *(source.path or [])}
+        source.delete()
 
-            if destination:
-                destination.update(last_update=datetime.utcnow())
-                affected.update({destination.id, *(destination.path or [])})
+        if destination:
+            destination.update(last_update=datetime.utcnow())
+            affected.update({destination.id, *(destination.path or [])})
 
         return moved_entities, moved_sub_projects, affected
 
@@ -127,78 +125,76 @@ class ProjectBLL:
         it should be writable. The source location should be writable too.
         Return the number of moved projects + set of all the affected project ids
         """
-        with TimingContext("mongo", "move_project"):
-            project = Project.get(company, project_id)
-            old_parent_id = project.parent
-            old_parent = (
-                Project.get_for_writing(company=project.company, id=old_parent_id)
-                if old_parent_id
-                else None
+        project = Project.get(company, project_id)
+        old_parent_id = project.parent
+        old_parent = (
+            Project.get_for_writing(company=project.company, id=old_parent_id)
+            if old_parent_id
+            else None
+        )
+
+        children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
+            project.id
+        ]
+        cls.validate_projects_depth(
+            projects=[project, *children],
+            old_parent_depth=len(project.path),
+            new_parent_depth=_get_project_depth(new_location),
+        )
+
+        new_parent = _ensure_project(company=company, user=user, name=new_location)
+        new_parent_id = new_parent.id if new_parent else None
+        if old_parent_id == new_parent_id:
+            raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
+                location=new_parent.name if new_parent else ""
             )
-
-            children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
-                project.id
-            ]
-            cls.validate_projects_depth(
-                projects=[project, *children],
-                old_parent_depth=len(project.path),
-                new_parent_depth=_get_project_depth(new_location),
+        if new_parent and (
+            project_id == new_parent.id or project_id in new_parent.path
+        ):
+            raise errors.bad_request.ProjectCannotBeMovedUnderItself(
+                project=project_id, parent=new_parent.id
             )
+        moved = _reposition_project_with_children(
+            project, children=children, parent=new_parent
+        )
 
-            new_parent = _ensure_project(company=company, user=user, name=new_location)
-            new_parent_id = new_parent.id if new_parent else None
-            if old_parent_id == new_parent_id:
-                raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
-                    location=new_parent.name if new_parent else ""
-                )
-            if new_parent and (
-                project_id == new_parent.id or project_id in new_parent.path
-            ):
-                raise errors.bad_request.ProjectCannotBeMovedUnderItself(
-                    project=project_id, parent=new_parent.id
-                )
-            moved = _reposition_project_with_children(
-                project, children=children, parent=new_parent
-            )
+        now = datetime.utcnow()
+        affected = set()
+        for p in filter(None, (old_parent, new_parent)):
+            p.update(last_update=now)
+            affected.update({p.id, *(p.path or [])})
 
-            now = datetime.utcnow()
-            affected = set()
-            for p in filter(None, (old_parent, new_parent)):
-                p.update(last_update=now)
-                affected.update({p.id, *(p.path or [])})
-
-            return moved, affected
+        return moved, affected
 
     @classmethod
     def update(cls, company: str, project_id: str, **fields):
-        with TimingContext("mongo", "projects_update"):
-            project = Project.get_for_writing(company=company, id=project_id)
-            if not project:
-                raise errors.bad_request.InvalidProjectId(id=project_id)
+        project = Project.get_for_writing(company=company, id=project_id)
+        if not project:
+            raise errors.bad_request.InvalidProjectId(id=project_id)
 
-            new_name = fields.pop("name", None)
-            if new_name:
-                new_name, new_location = _validate_project_name(new_name)
-                old_name, old_location = _validate_project_name(project.name)
-                if new_location != old_location:
-                    raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
-                fields["name"] = new_name
-                fields["basename"] = new_name.split("/")[-1]
+        new_name = fields.pop("name", None)
+        if new_name:
+            new_name, new_location = _validate_project_name(new_name)
+            old_name, old_location = _validate_project_name(project.name)
+            if new_location != old_location:
+            raise errors.bad_request.CannotUpdateProjectLocation(name=new_name)
+        fields["name"] = new_name
+        fields["basename"] = new_name.split("/")[-1]
 
-            fields["last_update"] = datetime.utcnow()
-            updated = project.update(upsert=False, **fields)
+        fields["last_update"] = datetime.utcnow()
+        updated = project.update(upsert=False, **fields)
 
-            if new_name:
-                old_name = project.name
-                project.name = new_name
-                children = _get_sub_projects(
-                    [project.id], _only=("id", "name", "path")
-                )[project.id]
-                _update_subproject_names(
-                    project=project, children=children, old_name=old_name
-                )
+        if new_name:
+            old_name = project.name
+            project.name = new_name
+            children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
+                project.id
+            ]
+            _update_subproject_names(
+                project=project, children=children, old_name=old_name
+            )
 
-            return updated
+        return updated
 
     @classmethod
     def create(
@@ -301,24 +297,23 @@ class ProjectBLL:
         """
         Move a batch of entities to `project` or a project named `project_name` (create if does not exist)
         """
-        with TimingContext("mongo", "move_under_project"):
-            project = cls.find_or_create(
-                user=user,
-                company=company,
-                project_id=project,
-                project_name=project_name,
-                description="",
-            )
-            extra = (
-                {"set__last_change": datetime.utcnow()}
-                if hasattr(entity_cls, "last_change")
-                else {}
-            )
-            entity_cls.objects(company=company, id__in=ids).update(
-                set__project=project, **extra
-            )
+        project = cls.find_or_create(
+            user=user,
+            company=company,
+            project_id=project,
+            project_name=project_name,
+            description="",
+        )
+        extra = (
+            {"set__last_change": datetime.utcnow()}
+            if hasattr(entity_cls, "last_change")
+            else {}
+        )
+        entity_cls.objects(company=company, id__in=ids).update(
+            set__project=project, **extra
+        )
 
-            return project
+        return project
 
     archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
     visibility_states = [EntityVisibility.archived, EntityVisibility.active]
@@ -716,22 +711,21 @@ class ProjectBLL:
         If project_ids is empty then all projects are examined
        If user_ids are passed then only subset of these users is returned
         """
-        with TimingContext("mongo", "active_users_in_projects"):
-            query = Q(company=company)
-            if user_ids:
-                query &= Q(user__in=user_ids)
+        query = Q(company=company)
+        if user_ids:
+            query &= Q(user__in=user_ids)
 
-            projects_query = query
-            if project_ids:
-                project_ids = _ids_with_children(project_ids)
-                query &= Q(project__in=project_ids)
-                projects_query &= Q(id__in=project_ids)
+        projects_query = query
+        if project_ids:
+            project_ids = _ids_with_children(project_ids)
+            query &= Q(project__in=project_ids)
+            projects_query &= Q(id__in=project_ids)
 
-            res = set(Project.objects(projects_query).distinct(field="user"))
-            for cls_ in (Task, Model):
-                res |= set(cls_.objects(query).distinct(field="user"))
+        res = set(Project.objects(projects_query).distinct(field="user"))
+        for cls_ in (Task, Model):
+            res |= set(cls_.objects(query).distinct(field="user"))
 
-            return res
+        return res
 
     @classmethod
     def get_project_tags(
@@ -741,21 +735,20 @@ class ProjectBLL:
         projects: Sequence[str] = None,
         filter_: Dict[str, Sequence[str]] = None,
     ) -> Tuple[Sequence[str], Sequence[str]]:
-        with TimingContext("mongo", "get_tags_from_db"):
-            query = Q(company=company_id)
-            if filter_:
-                for name, vals in filter_.items():
-                    if vals:
-                        query &= GetMixin.get_list_field_query(name, vals)
+        query = Q(company=company_id)
+        if filter_:
+            for name, vals in filter_.items():
+                if vals:
+                    query &= GetMixin.get_list_field_query(name, vals)
 
-            if projects:
-                query &= Q(id__in=_ids_with_children(projects))
+        if projects:
+            query &= Q(id__in=_ids_with_children(projects))
 
-            tags = Project.objects(query).distinct("tags")
-            system_tags = (
-                Project.objects(query).distinct("system_tags") if include_system else []
-            )
-            return tags, system_tags
+        tags = Project.objects(query).distinct("tags")
+        system_tags = (
+            Project.objects(query).distinct("system_tags") if include_system else []
+        )
+        return tags, system_tags
 
     @classmethod
     def get_projects_with_active_user(
@@ -930,10 +923,9 @@ class ProjectBLL:
         def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
             return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
 
-        with TimingContext("mongo", "get_security_groups"):
-            tasks = get_agrregate_res(Task)
-            models = get_agrregate_res(Model)
-            return {
-                pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
-                for pid in project_ids
-            }
+        tasks = get_agrregate_res(Task)
+        models = get_agrregate_res(Model)
+        return {
+            pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
+            for pid in project_ids
+        }
diff --git a/apiserver/bll/project/project_cleanup.py b/apiserver/bll/project/project_cleanup.py
index 2151426..9d36194 100644
--- a/apiserver/bll/project/project_cleanup.py
+++ b/apiserver/bll/project/project_cleanup.py
@@ -14,7 +14,6 @@ from apiserver.database.model import EntityVisibility
 from apiserver.database.model.model import Model
 from apiserver.database.model.project import Project
 from apiserver.database.model.task.task import Task, ArtifactModes, TaskType
-from apiserver.timing_context import TimingContext
 from .sub_projects import _ids_with_children
 
 log = config.logger(__file__)
@@ -88,11 +87,8 @@ def delete_project(
         )
 
     if not delete_contents:
-        with TimingContext("mongo", "update_children"):
-            for cls in (Model, Task):
-                updated_count = cls.objects(project__in=project_ids).update(
-                    project=None
-                )
+        for cls in (Model, Task):
+            updated_count = cls.objects(project__in=project_ids).update(project=None)
         res = DeleteProjectResult(disassociated_tasks=updated_count)
     else:
         deleted_models, model_urls = _delete_models(projects=project_ids)
@@ -127,9 +123,8 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]:
         return 0, set(), set()
 
     task_ids = {t.id for t in tasks}
-    with TimingContext("mongo", "delete_tasks_update_children"):
-        Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
-        Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
+    Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
+    Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
 
     event_urls, artifact_urls = set(), set()
     for task in tasks:
@@ -154,36 +149,35 @@ def _delete_models(projects: Sequence[str]) -> Tuple[int, Set[str]]:
     Delete project models and update the tasks from other projects that reference them
     to reference None.
""" - with TimingContext("mongo", "delete_models"): - models = Model.objects(project__in=projects).only("task", "id", "uri") - if not models: - return 0, set() + models = Model.objects(project__in=projects).only("task", "id", "uri") + if not models: + return 0, set() - model_ids = list({m.id for m in models}) + model_ids = list({m.id for m in models}) + Task._get_collection().update_many( + filter={ + "project": {"$nin": projects}, + "models.input.model": {"$in": model_ids}, + }, + update={"$set": {"models.input.$[elem].model": None}}, + array_filters=[{"elem.model": {"$in": model_ids}}], + upsert=False, + ) + + model_tasks = list({m.task for m in models if m.task}) + if model_tasks: Task._get_collection().update_many( filter={ + "_id": {"$in": model_tasks}, "project": {"$nin": projects}, - "models.input.model": {"$in": model_ids}, + "models.output.model": {"$in": model_ids}, }, - update={"$set": {"models.input.$[elem].model": None}}, + update={"$set": {"models.output.$[elem].model": None}}, array_filters=[{"elem.model": {"$in": model_ids}}], upsert=False, ) - model_tasks = list({m.task for m in models if m.task}) - if model_tasks: - Task._get_collection().update_many( - filter={ - "_id": {"$in": model_tasks}, - "project": {"$nin": projects}, - "models.output.model": {"$in": model_ids}, - }, - update={"$set": {"models.output.$[elem].model": None}}, - array_filters=[{"elem.model": {"$in": model_ids}}], - upsert=False, - ) - - urls = {m.uri for m in models if m.uri} - deleted = models.delete() - return deleted, urls + urls = {m.uri for m in models if m.uri} + deleted = models.delete() + return deleted, urls diff --git a/apiserver/bll/queue/queue_metrics.py b/apiserver/bll/queue/queue_metrics.py index 3462526..ba5cdaf 100644 --- a/apiserver/bll/queue/queue_metrics.py +++ b/apiserver/bll/queue/queue_metrics.py @@ -14,7 +14,6 @@ from apiserver.config_repo import config from apiserver.database.errors import translate_errors_context from apiserver.database.model.queue import Queue, Entry from apiserver.redis_manager import redman -from apiserver.timing_context import TimingContext from apiserver.utilities.threads_manager import ThreadsManager log = config.logger(__file__) @@ -182,7 +181,7 @@ class QueueMetrics: "aggs": self._get_dates_agg(interval), } - with translate_errors_context(), TimingContext("es", "get_queue_metrics"): + with translate_errors_context(): res = self._search_company_metrics(company_id, es_req) if "aggregations" not in res: @@ -285,6 +284,7 @@ class MetricsRefresher: if not queue_metrics: from .queue_bll import QueueBLL + queue_metrics = QueueBLL().metrics sleep(10) diff --git a/apiserver/bll/redis_cache_manager.py b/apiserver/bll/redis_cache_manager.py index 0f1b651..131893d 100644 --- a/apiserver/bll/redis_cache_manager.py +++ b/apiserver/bll/redis_cache_manager.py @@ -4,7 +4,6 @@ from typing import Optional, TypeVar, Generic, Type, Callable from redis import StrictRedis from apiserver import database -from apiserver.timing_context import TimingContext T = TypeVar("T") @@ -31,20 +30,17 @@ class RedisCacheManager(Generic[T]): def set_state(self, state: T) -> None: redis_key = self._get_redis_key(state.id) - with TimingContext("redis", "cache_set_state"): - self.redis.set(redis_key, state.to_json()) - self.redis.expire(redis_key, self.expiration_interval) + self.redis.set(redis_key, state.to_json()) + self.redis.expire(redis_key, self.expiration_interval) def get_state(self, state_id) -> Optional[T]: redis_key = self._get_redis_key(state_id) - with TimingContext("redis", 
"cache_get_state"): - response = self.redis.get(redis_key) + response = self.redis.get(redis_key) if response: return self.state_class.from_json(response) def delete_state(self, state_id) -> None: - with TimingContext("redis", "cache_delete_state"): - self.redis.delete(self._get_redis_key(state_id)) + self.redis.delete(self._get_redis_key(state_id)) def _get_redis_key(self, state_id): return f"{self.state_class}/{state_id}" diff --git a/apiserver/bll/task/artifacts.py b/apiserver/bll/task/artifacts.py index b44e428..0305f5e 100644 --- a/apiserver/bll/task/artifacts.py +++ b/apiserver/bll/task/artifacts.py @@ -5,7 +5,6 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId from apiserver.bll.task.utils import get_task_for_update, update_task from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact from apiserver.database.utils import hash_field_name -from apiserver.timing_context import TimingContext from apiserver.utilities.dicts import nested_get, nested_set from apiserver.utilities.parameter_key_escaper import mongoengine_safe @@ -53,23 +52,18 @@ class Artifacts: artifacts: Sequence[ApiArtifact], force: bool, ) -> int: - with TimingContext("mongo", "update_artifacts"): - task = get_task_for_update( - company_id=company_id, - task_id=task_id, - force=force, - ) + task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,) - artifacts = { - get_artifact_id(a): Artifact(**a) - for a in (api_artifact.to_struct() for api_artifact in artifacts) - } + artifacts = { + get_artifact_id(a): Artifact(**a) + for a in (api_artifact.to_struct() for api_artifact in artifacts) + } - update_cmds = { - f"set__execution__artifacts__{mongoengine_safe(name)}": value - for name, value in artifacts.items() - } - return update_task(task, update_cmds=update_cmds) + update_cmds = { + f"set__execution__artifacts__{mongoengine_safe(name)}": value + for name, value in artifacts.items() + } + return update_task(task, update_cmds=update_cmds) @classmethod def delete_artifacts( @@ -79,19 +73,14 @@ class Artifacts: artifact_ids: Sequence[ArtifactId], force: bool, ) -> int: - with TimingContext("mongo", "delete_artifacts"): - task = get_task_for_update( - company_id=company_id, - task_id=task_id, - force=force, - ) + task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,) - artifact_ids = [ - get_artifact_id(a) - for a in (artifact_id.to_struct() for artifact_id in artifact_ids) - ] - delete_cmds = { - f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids) - } + artifact_ids = [ + get_artifact_id(a) + for a in (artifact_id.to_struct() for artifact_id in artifact_ids) + ] + delete_cmds = { + f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids) + } - return update_task(task, update_cmds=delete_cmds) + return update_task(task, update_cmds=delete_cmds) diff --git a/apiserver/bll/task/hyperparams.py b/apiserver/bll/task/hyperparams.py index d61cc7b..f8d66c9 100644 --- a/apiserver/bll/task/hyperparams.py +++ b/apiserver/bll/task/hyperparams.py @@ -15,7 +15,6 @@ from apiserver.bll.task import TaskBLL from apiserver.bll.task.utils import get_task_for_update, update_task from apiserver.config_repo import config from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem -from apiserver.timing_context import TimingContext from apiserver.utilities.parameter_key_escaper import ( ParameterKeyEscaper, mongoengine_safe, @@ -68,36 +67,35 @@ class HyperParams: hyperparams: 
         force: bool,
     ) -> int:
-        with TimingContext("mongo", "delete_hyperparams"):
-            properties_only = cls._normalize_params(hyperparams)
-            task = get_task_for_update(
-                company_id=company_id,
-                task_id=task_id,
-                allow_all_statuses=properties_only,
-                force=force,
-            )
+        properties_only = cls._normalize_params(hyperparams)
+        task = get_task_for_update(
+            company_id=company_id,
+            task_id=task_id,
+            allow_all_statuses=properties_only,
+            force=force,
+        )
 
-            with_param, without_param = iterutils.partition(
-                hyperparams, key=lambda p: bool(p.name)
-            )
-            sections_to_delete = {p.section for p in without_param}
-            delete_cmds = {
-                f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
-                for section in sections_to_delete
-            }
+        with_param, without_param = iterutils.partition(
+            hyperparams, key=lambda p: bool(p.name)
+        )
+        sections_to_delete = {p.section for p in without_param}
+        delete_cmds = {
+            f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
+            for section in sections_to_delete
+        }
 
-            for item in with_param:
-                section = ParameterKeyEscaper.escape(item.section)
-                if item.section in sections_to_delete:
-                    raise errors.bad_request.FieldsConflict(
-                        "Cannot delete section field if the whole section was scheduled for deletion"
-                    )
-                name = ParameterKeyEscaper.escape(item.name)
-                delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
+        for item in with_param:
+            section = ParameterKeyEscaper.escape(item.section)
+            if item.section in sections_to_delete:
+                raise errors.bad_request.FieldsConflict(
+                    "Cannot delete section field if the whole section was scheduled for deletion"
+                )
+            name = ParameterKeyEscaper.escape(item.name)
+            delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
 
-            return update_task(
-                task, update_cmds=delete_cmds, set_last_update=not properties_only
-            )
+        return update_task(
+            task, update_cmds=delete_cmds, set_last_update=not properties_only
+        )
 
     @classmethod
     def edit_params(
@@ -108,34 +106,31 @@ class HyperParams:
         replace_hyperparams: str,
         force: bool,
     ) -> int:
-        with TimingContext("mongo", "edit_hyperparams"):
-            properties_only = cls._normalize_params(hyperparams)
-            task = get_task_for_update(
-                company_id=company_id,
-                task_id=task_id,
-                allow_all_statuses=properties_only,
-                force=force,
-            )
+        properties_only = cls._normalize_params(hyperparams)
+        task = get_task_for_update(
+            company_id=company_id,
+            task_id=task_id,
+            allow_all_statuses=properties_only,
+            force=force,
+        )
 
-            update_cmds = dict()
-            hyperparams = cls._db_dicts_from_list(hyperparams)
-            if replace_hyperparams == ReplaceHyperparams.all:
-                update_cmds["set__hyperparams"] = hyperparams
-            elif replace_hyperparams == ReplaceHyperparams.section:
-                for section, value in hyperparams.items():
+        update_cmds = dict()
+        hyperparams = cls._db_dicts_from_list(hyperparams)
+        if replace_hyperparams == ReplaceHyperparams.all:
+            update_cmds["set__hyperparams"] = hyperparams
+        elif replace_hyperparams == ReplaceHyperparams.section:
+            for section, value in hyperparams.items():
+                update_cmds[f"set__hyperparams__{mongoengine_safe(section)}"] = value
+        else:
+            for section, section_params in hyperparams.items():
+                for name, value in section_params.items():
                     update_cmds[
-                        f"set__hyperparams__{mongoengine_safe(section)}"
+                        f"set__hyperparams__{section}__{mongoengine_safe(name)}"
                     ] = value
-            else:
-                for section, section_params in hyperparams.items():
-                    for name, value in section_params.items():
-                        update_cmds[
-                            f"set__hyperparams__{section}__{mongoengine_safe(name)}"
-                        ] = value
-            return update_task(
-                task, update_cmds=update_cmds, set_last_update=not properties_only
-            )
+        return update_task(
+            task, update_cmds=update_cmds, set_last_update=not properties_only
+        )
 
     @classmethod
     def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
@@ -191,17 +186,16 @@ class HyperParams:
             {"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
         ]
 
-        with TimingContext("mongo", "get_configuration_names"):
-            tasks = Task.aggregate(pipeline)
+        tasks = Task.aggregate(pipeline)
 
-            return {
-                task["_id"]: {
-                    "names": sorted(
-                        ParameterKeyEscaper.unescape(name) for name in task["names"]
-                    )
-                }
-                for task in tasks
+        return {
+            task["_id"]: {
+                "names": sorted(
+                    ParameterKeyEscaper.unescape(name) for name in task["names"]
+                )
             }
+            for task in tasks
+        }
 
     @classmethod
     def edit_configuration(
@@ -212,36 +206,30 @@ class HyperParams:
         replace_configuration: bool,
         force: bool,
     ) -> int:
-        with TimingContext("mongo", "edit_configuration"):
-            task = get_task_for_update(
-                company_id=company_id, task_id=task_id, force=force
-            )
+        task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
 
-            update_cmds = dict()
-            configuration = {
-                ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
-                for c in configuration
-            }
-            if replace_configuration:
-                update_cmds["set__configuration"] = configuration
-            else:
-                for name, value in configuration.items():
-                    update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
+        update_cmds = dict()
+        configuration = {
+            ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
+            for c in configuration
+        }
+        if replace_configuration:
+            update_cmds["set__configuration"] = configuration
+        else:
+            for name, value in configuration.items():
+                update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
 
-            return update_task(task, update_cmds=update_cmds)
+        return update_task(task, update_cmds=update_cmds)
 
     @classmethod
     def delete_configuration(
         cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
     ) -> int:
-        with TimingContext("mongo", "delete_configuration"):
-            task = get_task_for_update(
-                company_id=company_id, task_id=task_id, force=force
-            )
+        task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
 
-            delete_cmds = {
-                f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
-                for name in set(configuration)
-            }
+        delete_cmds = {
+            f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
+            for name in set(configuration)
+        }
 
-            return update_task(task, update_cmds=delete_cmds)
+        return update_task(task, update_cmds=delete_cmds)
diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py
index 2a22a06..ae181ab 100644
--- a/apiserver/bll/task/task_bll.py
+++ b/apiserver/bll/task/task_bll.py
@@ -35,7 +35,6 @@ from apiserver.es_factory import es_factory
 from apiserver.redis_manager import redman
 from apiserver.service_repo import APICall
 from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
-from apiserver.timing_context import TimingContext
 from .artifacts import artifacts_prepare_for_save
 from .param_utils import params_prepare_for_save
 from .utils import (
@@ -66,11 +65,10 @@ class TaskBLL:
         """
         with translate_errors_context():
             query = dict(id=task_id, company=company_id)
-            with TimingContext("mongo", "task_with_access"):
-                if requires_write_access:
-                    task = Task.get_for_writing(_only=only, **query)
-                else:
-                    task = Task.get(_only=only, **query, include_public=allow_public)
+            if requires_write_access:
+                task = Task.get_for_writing(_only=only, **query)
+            else:
+                task = Task.get(_only=only, **query, include_public=allow_public)
 
             if not task:
                 raise errors.bad_request.InvalidTaskId(**query)
@@ -88,15 +86,14 @@ class TaskBLL:
             only_fields = list(only_fields)
             only_fields = only_fields + ["status"]
 
-        with TimingContext("mongo", "task_by_id_all"):
-            tasks = Task.get_many(
-                company=company_id,
-                query=Q(id=task_id),
-                allow_public=allow_public,
-                override_projection=only_fields,
-                return_dicts=False,
-            )
-            task = None if not tasks else tasks[0]
+        tasks = Task.get_many(
+            company=company_id,
+            query=Q(id=task_id),
+            allow_public=allow_public,
+            override_projection=only_fields,
+            return_dicts=False,
+        )
+        task = None if not tasks else tasks[0]
 
         if not task:
             raise errors.bad_request.InvalidTaskId(id=task_id)
@@ -111,7 +108,7 @@ class TaskBLL:
         company_id, task_ids, only=None, allow_public=False, return_tasks=True
     ) -> Optional[Sequence[Task]]:
         task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
-        with translate_errors_context(), TimingContext("mongo", "task_exists"):
+        with translate_errors_context():
             ids = set(task_ids)
             q = Task.get_many(
                 company=company_id,
@@ -260,58 +257,55 @@ class TaskBLL:
             not in [TaskSystemTags.development, EntityVisibility.archived.value]
         ]
 
-        with TimingContext("mongo", "clone task"):
-            parent_task = (
-                task.parent
-                if task.parent and not task.parent.startswith(deleted_prefix)
-                else task.id
-            )
-            new_task = Task(
-                id=create_id(),
-                user=user_id,
-                company=company_id,
-                created=now,
-                last_update=now,
-                last_change=now,
-                name=name or task.name,
-                comment=comment or task.comment,
-                parent=parent or parent_task,
-                project=project or task.project,
-                tags=tags or task.tags,
-                system_tags=system_tags or clean_system_tags(task.system_tags),
-                type=task.type,
-                script=task.script,
-                output=Output(destination=task.output.destination)
-                if task.output
-                else None,
-                models=Models(input=input_models or task.models.input),
-                container=escape_dict(container) or task.container,
-                execution=execution_dict,
-                configuration=params_dict.get("configuration") or task.configuration,
-                hyperparams=params_dict.get("hyperparams") or task.hyperparams,
-            )
-            cls.validate(
-                new_task,
-                validate_models=validate_references or input_models,
-                validate_parent=validate_references or parent,
-                validate_project=validate_references or project,
-            )
-            new_task.save()
+        parent_task = (
+            task.parent
+            if task.parent and not task.parent.startswith(deleted_prefix)
+            else task.id
+        )
+        new_task = Task(
+            id=create_id(),
+            user=user_id,
+            company=company_id,
+            created=now,
+            last_update=now,
+            last_change=now,
+            name=name or task.name,
+            comment=comment or task.comment,
+            parent=parent or parent_task,
+            project=project or task.project,
+            tags=tags or task.tags,
+            system_tags=system_tags or clean_system_tags(task.system_tags),
+            type=task.type,
+            script=task.script,
+            output=Output(destination=task.output.destination) if task.output else None,
+            models=Models(input=input_models or task.models.input),
+            container=escape_dict(container) or task.container,
+            execution=execution_dict,
+            configuration=params_dict.get("configuration") or task.configuration,
+            hyperparams=params_dict.get("hyperparams") or task.hyperparams,
+        )
+        cls.validate(
+            new_task,
+            validate_models=validate_references or input_models,
+            validate_parent=validate_references or parent,
+            validate_project=validate_references or project,
+        )
+        new_task.save()
 
-            if task.project == new_task.project:
-                updated_tags = tags
-                updated_system_tags = system_tags
-            else:
-                updated_tags = new_task.tags
-                updated_system_tags = new_task.system_tags
-            org_bll.update_tags(
-                company_id,
-                Tags.Task,
-                project=new_task.project,
-                tags=updated_tags,
-                system_tags=updated_system_tags,
-            )
-            update_project_time(new_task.project)
+        if task.project == new_task.project:
+            updated_tags = tags
+            updated_system_tags = system_tags
+        else:
+            updated_tags = new_task.tags
+            updated_system_tags = new_task.system_tags
+        org_bll.update_tags(
+            company_id,
+            Tags.Task,
+            project=new_task.project,
+            tags=updated_tags,
+            system_tags=updated_system_tags,
+        )
+        update_project_time(new_task.project)
 
         return new_task, new_project_data
diff --git a/apiserver/bll/task/task_cleanup.py b/apiserver/bll/task/task_cleanup.py
index c64de14..bdb05d5 100644
--- a/apiserver/bll/task/task_cleanup.py
+++ b/apiserver/bll/task/task_cleanup.py
@@ -15,8 +15,12 @@ from apiserver.bll.task.utils import deleted_prefix
 from apiserver.config_repo import config
 from apiserver.database.model.model import Model
 from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
-from apiserver.database.model.url_to_delete import StorageType, UrlToDelete, FileType, DeletionStatus
-from apiserver.timing_context import TimingContext
+from apiserver.database.model.url_to_delete import (
+    StorageType,
+    UrlToDelete,
+    FileType,
+    DeletionStatus,
+)
 from apiserver.database.utils import id as db_id
 
 event_bll = EventBLL()
@@ -65,17 +69,16 @@ class CleanupResult:
 def collect_plot_image_urls(company: str, task: str) -> Set[str]:
     urls = set()
     next_scroll_id = None
-    with TimingContext("es", "collect_plot_image_urls"):
-        while True:
-            events, next_scroll_id = event_bll.get_plot_image_urls(
-                company_id=company, task_id=task, scroll_id=next_scroll_id
-            )
-            if not events:
-                break
-            for event in events:
-                event_urls = event.get(PlotFields.source_urls)
-                if event_urls:
-                    urls.update(set(event_urls))
+    while True:
+        events, next_scroll_id = event_bll.get_plot_image_urls(
+            company_id=company, task_id=task, scroll_id=next_scroll_id
+        )
+        if not events:
+            break
+        for event in events:
+            event_urls = event.get(PlotFields.source_urls)
+            if event_urls:
+                urls.update(set(event_urls))
 
     return urls
 
@@ -89,9 +92,7 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
     urls = set()
     while True:
         res, after_key = event_bll.get_debug_image_urls(
-            company_id=company,
-            task_id=task,
-            after_key=after_key,
+            company_id=company, task_id=task, after_key=after_key,
         )
         urls.update(res)
         if not after_key:
@@ -198,10 +199,7 @@ def cleanup_task(
     deleted_task_id = f"{deleted_prefix}{task.id}"
     updated_children = 0
     if update_children:
-        with TimingContext("mongo", "update_task_children"):
-            updated_children = Task.objects(parent=task.id).update(
-                parent=deleted_task_id
-            )
+        updated_children = Task.objects(parent=task.id).update(parent=deleted_task_id)
 
     deleted_models = 0
     updated_models = 0
@@ -256,16 +254,15 @@ def verify_task_children_and_ouptuts(
     task, force: bool
 ) -> Tuple[Sequence[Model], Sequence[Model], Set[str]]:
     if not force:
-        with TimingContext("mongo", "count_published_children"):
-            published_children_count = Task.objects(
-                parent=task.id, status=TaskStatus.published
-            ).count()
-            if published_children_count:
-                raise errors.bad_request.TaskCannotBeDeleted(
-                    "has children, use force=True",
-                    task=task.id,
-                    children=published_children_count,
-                )
+        published_children_count = Task.objects(
+            parent=task.id, status=TaskStatus.published
+        ).count()
+        if published_children_count:
+            raise errors.bad_request.TaskCannotBeDeleted(
+                "has children, use force=True",
+                task=task.id,
+                children=published_children_count,
+            )
 
     model_fields = ["id", "ready", "uri"]
     published_models, draft_models = partition(
diff --git a/apiserver/bll/task/utils.py b/apiserver/bll/task/utils.py
index bb18be9..070309d 100644
--- a/apiserver/bll/task/utils.py
+++ b/apiserver/bll/task/utils.py
@@ -9,7 +9,6 @@ from apiserver.database.errors import translate_errors_context
 from apiserver.database.model.project import Project
 from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
 from apiserver.database.utils import get_options
-from apiserver.timing_context import TimingContext
 from apiserver.utilities.attrs import typed_attrs
 
 valid_statuses = get_options(TaskStatus)
@@ -55,7 +54,7 @@ class ChangeStatusRequest(object):
 
         fields.update({safe_mongoengine_key(k): v for k, v in kwargs.items()})
 
-        with translate_errors_context(), TimingContext("mongo", "task_status"):
+        with translate_errors_context():
             # atomic change of task status by querying the task with the EXPECTED status before modifying it
             params = fields.copy()
             params.update(control)
diff --git a/apiserver/bll/workers/__init__.py b/apiserver/bll/workers/__init__.py
index fa00de5..28db111 100644
--- a/apiserver/bll/workers/__init__.py
+++ b/apiserver/bll/workers/__init__.py
@@ -25,7 +25,6 @@ from apiserver.database.model.project import Project
 from apiserver.database.model.queue import Queue
 from apiserver.database.model.task.task import Task
 from apiserver.redis_manager import redman
-from apiserver.timing_context import TimingContext
 from apiserver.tools import safe_get
 from .stats import WorkerStats
 
@@ -109,15 +108,19 @@ class WorkerBLL:
         :param worker: worker ID
         :raise bad_request.WorkerNotRegistered: the worker was not previously registered
         """
-        with TimingContext("redis", "workers_unregister"):
-            res = self.redis.delete(
-                company_id, self._get_worker_key(company_id, user_id, worker)
-            )
+        res = self.redis.delete(
+            company_id, self._get_worker_key(company_id, user_id, worker)
+        )
 
         if not res and not config.get("apiserver.workers.auto_unregister", False):
             raise bad_request.WorkerNotRegistered(worker=worker)
 
     def status_report(
-        self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
+        self,
+        company_id: str,
+        user_id: str,
+        ip: str,
+        report: StatusReportRequest,
+        tags: Sequence[str] = None,
     ) -> None:
         """
         Write worker status report
@@ -176,7 +179,9 @@ class WorkerBLL:
                     if task.project:
                         project = Project.objects(id=task.project).only("name").first()
                         if project:
-                            entry.project = IdNameEntry(id=project.id, name=project.name)
+                            entry.project = IdNameEntry(
+                                id=project.id, name=project.name
+                            )
 
                 entry.last_report_time = now
             except APIError:
@@ -323,8 +328,7 @@ class WorkerBLL:
         """
         key = self._get_worker_key(company_id, user_id, worker)
 
-        with TimingContext("redis", "get_worker"):
-            data = self.redis.get(key)
+        data = self.redis.get(key)
 
         if data:
             try:
@@ -367,11 +371,10 @@ class WorkerBLL:
         """Get worker entries matching the company and user, worker patterns"""
         entries = []
         match = self._get_worker_key(company, user, worker_id)
-        with TimingContext("redis", "workers_get_all"):
-            for r in self.redis.scan_iter(match):
-                data = self.redis.get(r)
-                if data:
-                    entries.append(WorkerEntry.from_json(data))
+        for r in self.redis.scan_iter(match):
+            data = self.redis.get(r)
+            if data:
+                entries.append(WorkerEntry.from_json(data))
 
         return entries
diff --git a/apiserver/bll/workers/stats.py b/apiserver/bll/workers/stats.py
diff --git a/apiserver/bll/workers/stats.py b/apiserver/bll/workers/stats.py
index 1a9d8af..eab81a2 100644
--- a/apiserver/bll/workers/stats.py
+++ b/apiserver/bll/workers/stats.py
@@ -8,7 +8,6 @@ from apiserver.apimodels.workers import AggregationType, GetStatsRequest, StatIt
 from apiserver.bll.query import Builder as QueryBuilder
 from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
-from apiserver.timing_context import TimingContext
 
 log = config.logger(__file__)
 
@@ -126,7 +125,7 @@ class WorkerStats:
             query_terms.append(QueryBuilder.terms("worker", request.worker_ids))
         es_req["query"] = {"bool": {"must": query_terms}}
 
-        with translate_errors_context(), TimingContext("es", "get_worker_stats"):
+        with translate_errors_context():
             data = self._search_company_stats(company_id, es_req)
 
         return self._extract_results(data, request.items, request.split_by_variant)
@@ -223,9 +222,7 @@ class WorkerStats:
             "query": {"bool": {"must": must}},
         }
 
-        with translate_errors_context(), TimingContext(
-            "es", "get_worker_activity_report"
-        ):
+        with translate_errors_context():
             data = self._search_company_stats(company_id, es_req)
 
         if "aggregations" not in data:
diff --git a/apiserver/service_repo/auth/auth.py b/apiserver/service_repo/auth/auth.py
index 2dbd9aa..5ce1191 100644
--- a/apiserver/service_repo/auth/auth.py
+++ b/apiserver/service_repo/auth/auth.py
@@ -11,7 +11,6 @@ from apiserver.database.errors import translate_errors_context
 from apiserver.database.model.auth import User, Entities, Credentials
 from apiserver.database.model.company import Company
 from apiserver.database.utils import get_options
-from apiserver.timing_context import TimingContext
 from .fixed_user import FixedUser
 from .identity import Identity
 from .payload import Payload, Token, Basic, AuthType
@@ -88,9 +87,7 @@ def authorize_credentials(auth_data, service, action, call):
 
             query = Q(id=fixed_user.user_id)
 
-    with TimingContext("mongo", "user_by_cred"), translate_errors_context(
-        "authorizing request"
-    ):
+    with translate_errors_context("authorizing request"):
         user = User.objects(query).first()
         if not user:
             raise errors.unauthorized.InvalidCredentials(
@@ -108,8 +105,7 @@ def authorize_credentials(auth_data, service, action, call):
             }
         )
 
-    with TimingContext("mongo", "company_by_id"):
-        company = Company.objects(id=user.company).only("id", "name").first()
+    company = Company.objects(id=user.company).only("id", "name").first()
     if not company:
         raise errors.unauthorized.InvalidCredentials("invalid user company")
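Both `WorkerStats` hunks above assemble the same Elasticsearch body shape, `{"bool": {"must": [...]}}`, before handing it to `_search_company_stats`. A small sketch of that assembly with plain dicts (field names here are illustrative; the server builds its clauses through the internal `QueryBuilder`):

def build_stats_query(company_id: str, from_ts: int, to_ts: int, worker_ids=None) -> dict:
    # Every clause in a bool/must list has to match; "terms" matches any of
    # the listed values.
    must = [
        {"term": {"company": company_id}},
        {"range": {"timestamp": {"gte": from_ts, "lte": to_ts}}},
    ]
    if worker_ids:
        must.append({"terms": {"worker": worker_ids}})
    return {"query": {"bool": {"must": must}}}

print(build_stats_query("c1", 0, 100, worker_ids=["w1", "w2"]))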
diff --git a/apiserver/services/models.py b/apiserver/services/models.py
index 0e8cef1..00f672c 100644
--- a/apiserver/services/models.py
+++ b/apiserver/services/models.py
@@ -52,7 +52,6 @@ from apiserver.services.utils import (
     unescape_metadata,
     escape_metadata,
 )
-from apiserver.timing_context import TimingContext
 
 log = config.logger(__file__)
 org_bll = OrgBLL()
@@ -123,14 +122,13 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
     conform_tag_fields(call, call.data)
     _process_include_subprojects(call.data)
     Metadata.escape_query_parameters(call)
-    with TimingContext("mongo", "models_get_all_ex"):
-        ret_params = {}
-        models = Model.get_many_with_join(
-            company=company_id,
-            query_dict=call.data,
-            allow_public=True,
-            ret_params=ret_params,
-        )
+    ret_params = {}
+    models = Model.get_many_with_join(
+        company=company_id,
+        query_dict=call.data,
+        allow_public=True,
+        ret_params=ret_params,
+    )
     conform_output_tags(call, models)
     unescape_metadata(call, models)
 
@@ -139,10 +137,7 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
         return
 
     model_ids = {model["id"] for model in models}
-    stats = ModelBLL.get_model_stats(
-        company=company_id,
-        model_ids=list(model_ids),
-    )
+    stats = ModelBLL.get_model_stats(company=company_id, model_ids=list(model_ids),)
 
     for model in models:
         model["stats"] = stats.get(model["id"])
@@ -154,10 +149,9 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
 def get_by_id_ex(call: APICall, company_id, _):
     conform_tag_fields(call, call.data)
     Metadata.escape_query_parameters(call)
-    with TimingContext("mongo", "models_get_by_id_ex"):
-        models = Model.get_many_with_join(
-            company=company_id, query_dict=call.data, allow_public=True
-        )
+    models = Model.get_many_with_join(
+        company=company_id, query_dict=call.data, allow_public=True
+    )
     conform_output_tags(call, models)
     unescape_metadata(call, models)
     call.result.data = {"models": models}
@@ -167,15 +161,14 @@
 def get_all(call: APICall, company_id, _):
     conform_tag_fields(call, call.data)
     Metadata.escape_query_parameters(call)
-    with TimingContext("mongo", "models_get_all"):
-        ret_params = {}
-        models = Model.get_many(
-            company=company_id,
-            parameters=call.data,
-            query_dict=call.data,
-            allow_public=True,
-            ret_params=ret_params,
-        )
+    ret_params = {}
+    models = Model.get_many(
+        company=company_id,
+        parameters=call.data,
+        query_dict=call.data,
+        allow_public=True,
+        ret_params=ret_params,
+    )
     conform_output_tags(call, models)
     unescape_metadata(call, models)
     call.result.data = {"models": models, **ret_params}
@@ -414,9 +407,7 @@ def validate_task(company_id, fields: dict):
 def edit(call: APICall, company_id, _):
     model_id = call.data["model"]
 
-    model = ModelBLL.get_company_model_by_id(
-        company_id=company_id, model_id=model_id
-    )
+    model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
 
     fields = parse_model_fields(call, create_fields)
     fields = prepare_update_fields(call, company_id, fields)
@@ -424,11 +415,7 @@
     for key in fields:
         field = getattr(model, key, None)
         value = fields[key]
-        if (
-            field
-            and isinstance(value, dict)
-            and isinstance(field, EmbeddedDocument)
-        ):
+        if field and isinstance(value, dict) and isinstance(field, EmbeddedDocument):
             d = field.to_mongo(use_db_field=False).to_dict()
             d.update(value)
             fields[key] = d
@@ -448,13 +435,9 @@ def edit(call: APICall, company_id, _):
     if updated:
         new_project = fields.get("project", model.project)
         if new_project != model.project:
-            _reset_cached_tags(
-                company_id, projects=[new_project, model.project]
-            )
+            _reset_cached_tags(company_id, projects=[new_project, model.project])
         else:
-            _update_cached_tags(
-                company_id, project=model.project, fields=fields
-            )
+            _update_cached_tags(company_id, project=model.project, fields=fields)
     conform_output_tags(call, fields)
     unescape_metadata(call, fields)
     call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@@ -465,9 +448,7 @@ def edit(call: APICall, company_id, _):
 
 def _update_model(call: APICall, company_id, model_id=None):
     model_id = model_id or call.data["model"]
 
-    model = ModelBLL.get_company_model_by_id(
-        company_id=company_id, model_id=model_id
-    )
+    model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
 
     data = prepare_update_fields(call, company_id, call.data)
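The reflowed condition in `edit` above guards a merge idiom worth calling out: dump the stored embedded document to a plain dict with python-side field names, overlay the partial update, and use the merged dict as the full new field value. A sketch with a toy mongoengine schema (the `Execution` class is illustrative, not the real model):

from mongoengine import EmbeddedDocument, IntField, StringField

class Execution(EmbeddedDocument):  # illustrative schema, not the real model
    queue = StringField()
    docker = StringField()
    attempts = IntField(default=0)

existing = Execution(queue="default", docker="python:3.9", attempts=2)
incoming = {"docker": "python:3.10"}  # partial update arriving from the API call

# Same merge as in edit() above: use_db_field=False keeps python attribute
# names so the incoming dict overlays cleanly.
d = existing.to_mongo(use_db_field=False).to_dict()
d.update(incoming)
assert d == {"queue": "default", "docker": "python:3.10", "attempts": 2}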
diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py
index 8be6545..7746679 100644
--- a/apiserver/services/projects.py
+++ b/apiserver/services/projects.py
@@ -39,7 +39,6 @@ from apiserver.services.utils import (
     get_tags_filter_dictionary,
     sort_tags_response,
 )
-from apiserver.timing_context import TimingContext
 
 org_bll = OrgBLL()
 project_bll = ProjectBLL()
@@ -60,11 +59,8 @@ def get_by_id(call):
     project_id = call.data["project"]
 
     with translate_errors_context():
-        with TimingContext("mongo", "projects_by_id"):
-            query = Q(id=project_id) & get_company_or_none_constraint(
-                call.identity.company
-            )
-            project = Project.objects(query).first()
+        query = Q(id=project_id) & get_company_or_none_constraint(call.identity.company)
+        project = Project.objects(query).first()
         if not project:
             raise errors.bad_request.InvalidProjectId(id=project_id)
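The collapsed `get_by_id` body now reads as a single `Q` expression: the id filter ANDed with a company-or-public constraint. A sketch of how such composable filters combine (the `Project` stand-in and the constraint helper are assumptions mirroring `get_company_or_none_constraint`, whose implementation is not shown in this patch):

from mongoengine import Document, Q, StringField

class Project(Document):  # minimal stand-in for the real Project document
    name = StringField()
    company = StringField()

def company_or_none_constraint(company: str) -> Q:
    # Hypothetical equivalent of get_company_or_none_constraint(): match
    # documents owned by this company, or public ones with no company set.
    return Q(company=company) | Q(company=None)

query = Q(id="p1") & company_or_none_constraint("acme")
# Project.objects(query).first() would run it against a live connection;
# the Q object itself is just a composable filter expression.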
@@ -109,68 +105,65 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
     _adjust_search_parameters(
         data, shallow_search=request.shallow_search,
     )
-    with TimingContext("mongo", "projects_get_all"):
-        user_active_project_ids = None
-        if request.active_users:
-            ids, user_active_project_ids = project_bll.get_projects_with_active_user(
-                company=company_id,
-                users=request.active_users,
-                project_ids=requested_ids,
-                allow_public=allow_public,
-            )
-            if not ids:
-                return {"projects": []}
-            data["id"] = ids
-
-        ret_params = {}
-        projects: Sequence[dict] = Project.get_many_with_join(
+    user_active_project_ids = None
+    if request.active_users:
+        ids, user_active_project_ids = project_bll.get_projects_with_active_user(
             company=company_id,
-            query_dict=data,
-            query=_hidden_query(search_hidden=request.search_hidden, ids=requested_ids),
+            users=request.active_users,
+            project_ids=requested_ids,
             allow_public=allow_public,
-            ret_params=ret_params,
         )
-        if not projects:
-            return {"projects": projects, **ret_params}
+        if not ids:
+            return {"projects": []}
+        data["id"] = ids
 
-        project_ids = list({project["id"] for project in projects})
-        if request.check_own_contents:
-            contents = project_bll.calc_own_contents(
-                company=company_id,
-                project_ids=project_ids,
-                filter_=request.include_stats_filter,
-                users=request.active_users,
-            )
-            for project in projects:
-                project.update(**contents.get(project["id"], {}))
+    ret_params = {}
+    projects: Sequence[dict] = Project.get_many_with_join(
+        company=company_id,
+        query_dict=data,
+        query=_hidden_query(search_hidden=request.search_hidden, ids=requested_ids),
+        allow_public=allow_public,
+        ret_params=ret_params,
+    )
+    if not projects:
+        return {"projects": projects, **ret_params}
 
-        conform_output_tags(call, projects)
-        if request.include_stats:
-            stats, children = project_bll.get_project_stats(
-                company=company_id,
-                project_ids=project_ids,
-                specific_state=request.stats_for_state,
-                include_children=request.stats_with_children,
-                search_hidden=request.search_hidden,
-                filter_=request.include_stats_filter,
-                users=request.active_users,
-                user_active_project_ids=user_active_project_ids,
-            )
+    project_ids = list({project["id"] for project in projects})
+    if request.check_own_contents:
+        contents = project_bll.calc_own_contents(
+            company=company_id,
+            project_ids=project_ids,
+            filter_=request.include_stats_filter,
+            users=request.active_users,
+        )
+        for project in projects:
+            project.update(**contents.get(project["id"], {}))
 
-            for project in projects:
-                project["stats"] = stats[project["id"]]
-                project["sub_projects"] = children[project["id"]]
+    conform_output_tags(call, projects)
+    if request.include_stats:
+        stats, children = project_bll.get_project_stats(
+            company=company_id,
+            project_ids=project_ids,
+            specific_state=request.stats_for_state,
+            include_children=request.stats_with_children,
+            search_hidden=request.search_hidden,
+            filter_=request.include_stats_filter,
+            users=request.active_users,
+            user_active_project_ids=user_active_project_ids,
+        )
 
-        if request.include_dataset_stats:
-            dataset_stats = project_bll.get_dataset_stats(
-                company=company_id,
-                project_ids=project_ids,
-                users=request.active_users,
-            )
-            for project in projects:
-                project["dataset_stats"] = dataset_stats.get(project["id"])
+        for project in projects:
+            project["stats"] = stats[project["id"]]
+            project["sub_projects"] = children[project["id"]]
 
-        call.result.data = {"projects": projects, **ret_params}
+    if request.include_dataset_stats:
+        dataset_stats = project_bll.get_dataset_stats(
+            company=company_id, project_ids=project_ids, users=request.active_users,
+        )
+        for project in projects:
+            project["dataset_stats"] = dataset_stats.get(project["id"])
+
+    call.result.data = {"projects": projects, **ret_params}
 
 
 @endpoint("projects.get_all")
@@ -180,20 +173,19 @@ def get_all(call: APICall):
     _adjust_search_parameters(
         data, shallow_search=data.get("shallow_search", False),
    )
-    with TimingContext("mongo", "projects_get_all"):
-        ret_params = {}
-        projects = Project.get_many(
-            company=call.identity.company,
-            query_dict=data,
-            query=_hidden_query(
-                search_hidden=data.get("search_hidden"), ids=data.get("id")
-            ),
-            parameters=data,
-            allow_public=True,
-            ret_params=ret_params,
-        )
-        conform_output_tags(call, projects)
-        call.result.data = {"projects": projects, **ret_params}
+    ret_params = {}
+    projects = Project.get_many(
+        company=call.identity.company,
+        query_dict=data,
+        query=_hidden_query(
+            search_hidden=data.get("search_hidden"), ids=data.get("id")
+        ),
+        parameters=data,
+        allow_public=True,
+        ret_params=ret_params,
+    )
+    conform_output_tags(call, projects)
+    call.result.data = {"projects": projects, **ret_params}
 
 
 @endpoint(
diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py
index 1676782..138f9ad 100644
--- a/apiserver/services/tasks.py
+++ b/apiserver/services/tasks.py
@@ -119,7 +119,6 @@ from apiserver.services.utils import (
     escape_dict_field,
     unescape_dict_field,
 )
-from apiserver.timing_context import TimingContext
 from apiserver.utilities.dicts import nested_get
 from apiserver.utilities.partial_version import PartialVersion
 
@@ -231,16 +230,15 @@ def get_all_ex(call: APICall, company_id, _):
 
     call_data = escape_execution_parameters(call)
 
-    with TimingContext("mongo", "task_get_all_ex"):
-        _process_include_subprojects(call_data)
-        ret_params = {}
-        tasks = Task.get_many_with_join(
-            company=company_id,
-            query_dict=call_data,
-            query=_hidden_query(call_data),
-            allow_public=True,
-            ret_params=ret_params,
-        )
+    _process_include_subprojects(call_data)
+    ret_params = {}
+    tasks = Task.get_many_with_join(
+        company=company_id,
+        query_dict=call_data,
+        query=_hidden_query(call_data),
+        allow_public=True,
+        ret_params=ret_params,
+    )
     unprepare_from_saved(call, tasks)
     call.result.data = {"tasks": tasks, **ret_params}
 
@@ -251,10 +249,9 @@ def get_by_id_ex(call: APICall, company_id, _):
 
     call_data = escape_execution_parameters(call)
 
-    with TimingContext("mongo", "task_get_by_id_ex"):
-        tasks = Task.get_many_with_join(
-            company=company_id, query_dict=call_data, allow_public=True,
-        )
+    tasks = Task.get_many_with_join(
+        company=company_id, query_dict=call_data, allow_public=True,
+    )
     unprepare_from_saved(call, tasks)
     call.result.data = {"tasks": tasks}
 
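A pattern repeated throughout these handler hunks: a `ret_params` dict is passed into `get_many`/`get_many_with_join` as an out-parameter, the query layer fills it with paging metadata, and the handler spreads it into the response next to the results. A toy sketch of the shape (helper and key names are assumptions):

from typing import Optional, Sequence

def get_many(query: dict, page: int = 0, page_size: int = 500,
             ret_params: Optional[dict] = None) -> Sequence[dict]:
    # Toy counterpart of Model/Project/Task.get_many: results travel through
    # the return value, paging metadata through the caller-supplied dict.
    results = [{"id": f"task-{i}"} for i in range(3)]  # stand-in for a DB query
    if ret_params is not None:
        ret_params["page"] = page
        ret_params["page_size"] = page_size
        ret_params["total"] = len(results)
    return results

ret_params = {}
tasks = get_many({"status": "queued"}, ret_params=ret_params)
response = {"tasks": tasks, **ret_params}  # the same spread as call.result.data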
@@ -266,16 +263,15 @@ def get_all(call: APICall, company_id, _):
 
     call_data = escape_execution_parameters(call)
 
-    with TimingContext("mongo", "task_get_all"):
-        ret_params = {}
-        tasks = Task.get_many(
-            company=company_id,
-            parameters=call_data,
-            query_dict=call_data,
-            query=_hidden_query(call_data),
-            allow_public=True,
-            ret_params=ret_params,
-        )
+    ret_params = {}
+    tasks = Task.get_many(
+        company=company_id,
+        parameters=call_data,
+        query_dict=call_data,
+        query=_hidden_query(call_data),
+        allow_public=True,
+        ret_params=ret_params,
+    )
     unprepare_from_saved(call, tasks)
     call.result.data = {"tasks": tasks, **ret_params}
 
@@ -487,12 +483,11 @@ def prepare_create_fields(
 def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
     with translate_errors_context(
         field_does_not_exist_cls=errors.bad_request.ValidationError
-    ), TimingContext("code", "parse_call"):
+    ):
         fields = prepare_create_fields(call, **kwargs)
         task = task_bll.create(call, fields)
 
-    with TimingContext("code", "validate"):
-        task_bll.validate(task)
+    task_bll.validate(task)
 
     return task, fields
 
@@ -525,7 +520,7 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
 def create(call: APICall, company_id, req_model: CreateRequest):
     task, fields = _validate_and_get_task_from_call(call)
 
-    with translate_errors_context(), TimingContext("mongo", "save_task"):
+    with translate_errors_context():
         task.save()
         _update_cached_tags(company_id, project=task.project, fields=fields)
         update_project_time(task.project)
@@ -708,7 +703,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
 
     with translate_errors_context(
         field_does_not_exist_cls=errors.bad_request.ValidationError
-    ), TimingContext("code", "parse_and_validate"):
+    ):
         fields = prepare_create_fields(
             call, valid_fields=edit_fields, output=task.output, previous_task=task
         )
diff --git a/apiserver/timing_context.py b/apiserver/timing_context.py
deleted file mode 100644
index 005131c..0000000
--- a/apiserver/timing_context.py
+++ /dev/null
@@ -1,15 +0,0 @@
-class TimingStats:
-    @classmethod
-    def aggregate(cls):
-        return {}
-
-
-class TimingContext:
-    def __init__(self, *_, **__):
-        pass
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, *args):
-        pass
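The deleted module above is the punchline of the whole patch: `TimingContext` had already been reduced to a stub that accepts anything and does nothing, so every remaining `with TimingContext(...)` block (including the combined `with translate_errors_context(), TimingContext(...)` forms) was pure noise. A quick check that the removal is behavior-preserving, using the stub exactly as it appears in the deleted file:

class TimingContext:
    # The stub from the deleted apiserver/timing_context.py: swallows any
    # constructor arguments and does nothing on enter or exit.
    def __init__(self, *_, **__):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *args):
        pass

def work():
    return sum(range(10))

with TimingContext("mongo", "task_status"):
    wrapped = work()

assert wrapped == work()  # wrapped and unwrapped calls behave identically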