Removed stub timing context

allegroai 2022-09-29 19:37:15 +03:00
parent 0c9e2f92ee
commit de1f823213
23 changed files with 594 additions and 741 deletions
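
The pattern repeated across these files is mechanical: TimingContext was a do-nothing stub (its definition is the deleted module at the end of this diff), so every call site simply drops it from the enclosing with statement. A minimal before/after sketch, using names taken verbatim from the hunks below:

    # before: the stub contributed an extra, do-nothing context manager
    with translate_errors_context(), TimingContext("es", "task_log_events"):
        es_res = search_company_events(body=es_req, **search_args)

    # after: only the error-translation context remains; runtime behavior is unchanged
    with translate_errors_context():
        es_res = search_company_events(body=es_req, **search_args)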

View File

@ -39,7 +39,6 @@ from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.redis_manager import redman
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.json import loads
@ -97,7 +96,7 @@ class EventBLL(object):
if not task_ids:
return set()
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
with translate_errors_context():
query = Q(id__in=task_ids, company=company_id)
if not allow_locked_tasks:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
@ -228,7 +227,6 @@ class EventBLL(object):
with translate_errors_context():
if actions:
chunk_size = 500
with TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
elasticsearch.helpers.streaming_bulk(
@ -425,7 +423,7 @@ class EventBLL(object):
return [], scroll_id, 0
if scroll_id:
with translate_errors_context(), TimingContext("es", "task_log_events"):
with translate_errors_context():
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
size = min(batch_size, 10000)
@ -438,7 +436,7 @@ class EventBLL(object):
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
}
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
@ -468,9 +466,7 @@ class EventBLL(object):
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
@ -508,9 +504,7 @@ class EventBLL(object):
"query": query,
}
with translate_errors_context(), TimingContext(
"es", "task_last_iter_metric_variant"
):
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
if "aggregations" not in es_res:
@ -538,7 +532,7 @@ class EventBLL(object):
return TaskEventsResult()
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
with translate_errors_context():
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
event_type = EventType.metrics_plot
@ -602,7 +596,7 @@ class EventBLL(object):
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext("es", "get_task_plots"):
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
@ -720,7 +714,7 @@ class EventBLL(object):
return TaskEventsResult()
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
with translate_errors_context():
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
@ -768,7 +762,7 @@ class EventBLL(object):
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext("es", "get_task_events"):
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
@ -793,9 +787,7 @@ class EventBLL(object):
return {}
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
@ -822,9 +814,7 @@ class EventBLL(object):
"query": query,
}
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
metrics = {}
@ -851,9 +841,7 @@ class EventBLL(object):
]
}
}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
@ -899,9 +887,7 @@ class EventBLL(object):
},
"_source": {"excludes": []},
}
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
metrics = []
@ -947,7 +933,7 @@ class EventBLL(object):
"_source": ["iter", "value"],
"sort": ["iter"],
}
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
with translate_errors_context():
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
@ -990,7 +976,7 @@ class EventBLL(object):
"query": {"bool": {"must": [{"terms": {"task": task_ids}}]}},
}
with translate_errors_context(), TimingContext("es", "task_last_iter"):
with translate_errors_context():
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
@ -1022,7 +1008,7 @@ class EventBLL(object):
)
es_req = {"query": {"term": {"task": task_id}}}
with translate_errors_context(), TimingContext("es", "delete_task_events"):
with translate_errors_context():
es_res = delete_company_events(
es=self.es,
company_id=company_id,
@ -1048,7 +1034,7 @@ class EventBLL(object):
):
return 0
with translate_errors_context(), TimingContext("es", "clear_task_log"):
with translate_errors_context():
must = [{"term": {"task": task_id}}]
sort = None
if threshold_sec:
@ -1082,9 +1068,7 @@ class EventBLL(object):
so it should be checked by the calling code
"""
es_req = {"query": {"terms": {"task": task_ids}}}
with translate_errors_context(), TimingContext(
"es", "delete_multi_tasks_events"
):
with translate_errors_context():
es_res = delete_company_events(
es=self.es,
company_id=company_id,
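
The events_add_batch hunk above keeps the streaming_bulk call that the removed timing block used to wrap. A minimal sketch of that bulk-write pattern, assuming an Elasticsearch 7.x client; the helper name and the raise_on_error flag are illustrative, not taken from this file:

    from contextlib import closing
    import elasticsearch.helpers

    def write_events(es, actions, chunk_size=500):
        # streaming_bulk lazily yields one (ok, result) tuple per action;
        # closing() guarantees the generator is cleaned up even on early exit
        errors = []
        with closing(
            elasticsearch.helpers.streaming_bulk(
                es, actions, chunk_size=chunk_size, raise_on_error=False
            )
        ) as it:
            for ok, result in it:
                if not ok:
                    errors.append(result)
        return errors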

View File

@ -8,7 +8,6 @@ from elasticsearch import Elasticsearch
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
@ -21,7 +20,7 @@ class EventType(Enum):
all = "*"
SINGLE_SCALAR_ITERATION = -2**31
SINGLE_SCALAR_ITERATION = -(2 ** 31)
MetricVariants = Mapping[str, Sequence[str]]
@ -80,9 +79,7 @@ def delete_company_events(
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
) -> dict:
es_index = get_index_name(company_id, event_type.value)
return es.delete_by_query(
index=es_index, body=body, conflicts="proceed", **kwargs
)
return es.delete_by_query(index=es_index, body=body, conflicts="proceed", **kwargs)
def count_company_events(
@ -116,9 +113,7 @@ def get_max_metric_and_variant_counts(
"query": query,
"aggs": {"metrics_count": {"cardinality": {"field": "metric"}}},
}
with translate_errors_context(), TimingContext(
"es", "get_max_metric_and_variant_counts"
):
with translate_errors_context():
es_res = search_company_events(
es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
)
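
Besides dropping TimingContext, this file only reformats the SINGLE_SCALAR_ITERATION constant. The added parentheses are purely cosmetic: in Python, ** binds tighter than unary minus, so both spellings already denote the same value:

    assert -2**31 == -(2 ** 31) == -2147483648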

View File

@ -25,7 +25,6 @@ from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
log = config.logger(__file__)
@ -180,7 +179,6 @@ class EventMetrics:
):
return {}
with TimingContext("es", "get_task_single_value_metrics"):
task_events = self._get_task_single_value_metrics(company_id, task_ids)
def _get_value(event: dict):
@ -277,9 +275,7 @@ class EventMetrics:
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
@ -312,7 +308,6 @@ class EventMetrics:
},
}
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
es_res = search_company_events(body=es_req, **search_args)
aggs_result = es_res.get("aggregations")
@ -366,9 +361,7 @@ class EventMetrics:
interval, metrics = metrics_interval
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type
)
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
@ -493,7 +486,6 @@ class EventMetrics:
},
}
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)

View File

@ -17,7 +17,6 @@ from apiserver.bll.event.event_common import (
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@attr.s(auto_attribs=True)
@ -76,7 +75,7 @@ class EventsIterator:
"query": query,
}
with translate_errors_context(), TimingContext("es", "count_task_events"):
with translate_errors_context():
es_result = count_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
@ -113,7 +112,7 @@ class EventsIterator:
if from_key_value:
es_req["search_after"] = [from_key_value]
with translate_errors_context(), TimingContext("es", "get_task_events"):
with translate_errors_context():
es_result = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
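
The search_after assignment above is standard Elasticsearch deep pagination: sort deterministically, then resume each page strictly after the previous page's last sort key. A minimal sketch under assumed index and field names (timestamp is a placeholder, not this file's actual sort key):

    def iter_events(es, index, task_id, page_size=500):
        body = {
            "size": page_size,
            "sort": [{"timestamp": "asc"}, {"_id": "asc"}],  # tiebreaker keeps the order total
            "query": {"term": {"task": task_id}},
        }
        while True:
            hits = es.search(index=index, body=body)["hits"]["hits"]
            if not hits:
                return
            yield from (h["_source"] for h in hits)
            body["search_after"] = hits[-1]["sort"]  # resume strictly after the last hit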

View File

@ -21,7 +21,6 @@ from apiserver.bll.event.event_common import (
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
@ -174,14 +173,9 @@ class HistorySampleIterator(abc.ABC):
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_next_for_current_iteration"
):
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
@ -235,14 +229,9 @@ class HistorySampleIterator(abc.ABC):
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_next_for_another_iteration"
):
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
@ -335,9 +324,7 @@ class HistorySampleIterator(abc.ABC):
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_history_sample_for_variant"
):
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
@ -421,9 +408,7 @@ class HistorySampleIterator(abc.ABC):
},
}
with translate_errors_context(), TimingContext(
"es", "get_history_sample_iterations"
):
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args,)
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:

View File

@ -19,14 +19,14 @@ from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
get_metric_variants_condition, get_max_metric_and_variant_counts,
get_metric_variants_condition,
get_max_metric_and_variant_counts,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
class VariantState(Base):
@ -226,7 +226,9 @@ class MetricEventsIterator:
pass
@abc.abstractmethod
def _get_variant_state_aggs(self) -> Tuple[dict, Callable[[dict, VariantState], None]]:
def _get_variant_state_aggs(
self,
) -> Tuple[dict, Callable[[dict, VariantState], None]]:
pass
def _init_metric_states_for_task(
@ -268,14 +270,18 @@ class MetricEventsIterator:
"size": max_variants,
"order": {"_key": "asc"},
},
**({"aggs": variant_state_aggs} if variant_state_aggs else {}),
**(
{"aggs": variant_state_aggs}
if variant_state_aggs
else {}
),
},
},
}
},
}
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
if "aggregations" not in es_res:
return []
@ -383,12 +389,9 @@ class MetricEventsIterator:
}
},
}
with translate_errors_context(), TimingContext("es", "_get_task_metric_events"):
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
)
if "aggregations" not in es_res:
return task_state.task, []

View File

@ -11,7 +11,6 @@ from apiserver.utilities.parameter_key_escaper import (
mongoengine_safe,
)
from apiserver.config_repo import config
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
@ -42,7 +41,6 @@ class Metadata:
replace_metadata: bool,
**more_updates,
) -> int:
with TimingContext("mongo", "edit_metadata"):
update_cmds = dict()
metadata = cls.metadata_from_api(items)
if replace_metadata:
@ -55,7 +53,6 @@ class Metadata:
@classmethod
def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
with TimingContext("mongo", "delete_metadata"):
return obj.update(
**{
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1

View File

@ -29,7 +29,6 @@ from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
from apiserver.database.utils import get_options, get_company_or_none_constraint
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
from .sub_projects import (
_reposition_project_with_children,
@ -57,7 +56,6 @@ class ProjectBLL:
Remove the source project
Return the amounts of moved entities and subprojects + set of all the affected project ids
"""
with TimingContext("mongo", "move_project"):
if source_id == destination_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
source=source_id
@ -127,7 +125,6 @@ class ProjectBLL:
it should be writable. The source location should be writable too.
Return the number of moved projects + set of all the affected project ids
"""
with TimingContext("mongo", "move_project"):
project = Project.get(company, project_id)
old_parent_id = project.parent
old_parent = (
@ -171,7 +168,6 @@ class ProjectBLL:
@classmethod
def update(cls, company: str, project_id: str, **fields):
with TimingContext("mongo", "projects_update"):
project = Project.get_for_writing(company=company, id=project_id)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
@ -191,9 +187,9 @@ class ProjectBLL:
if new_name:
old_name = project.name
project.name = new_name
children = _get_sub_projects(
[project.id], _only=("id", "name", "path")
)[project.id]
children = _get_sub_projects([project.id], _only=("id", "name", "path"))[
project.id
]
_update_subproject_names(
project=project, children=children, old_name=old_name
)
@ -301,7 +297,6 @@ class ProjectBLL:
"""
Move a batch of entities to `project` or a project named `project_name` (create if does not exist)
"""
with TimingContext("mongo", "move_under_project"):
project = cls.find_or_create(
user=user,
company=company,
@ -716,7 +711,6 @@ class ProjectBLL:
If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned
"""
with TimingContext("mongo", "active_users_in_projects"):
query = Q(company=company)
if user_ids:
query &= Q(user__in=user_ids)
@ -741,7 +735,6 @@ class ProjectBLL:
projects: Sequence[str] = None,
filter_: Dict[str, Sequence[str]] = None,
) -> Tuple[Sequence[str], Sequence[str]]:
with TimingContext("mongo", "get_tags_from_db"):
query = Q(company=company_id)
if filter_:
for name, vals in filter_.items():
@ -930,7 +923,6 @@ class ProjectBLL:
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
with TimingContext("mongo", "get_security_groups"):
tasks = get_agrregate_res(Task)
models = get_agrregate_res(Model)
return {

View File

@ -14,7 +14,6 @@ from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType
from apiserver.timing_context import TimingContext
from .sub_projects import _ids_with_children
log = config.logger(__file__)
@ -88,11 +87,8 @@ def delete_project(
)
if not delete_contents:
with TimingContext("mongo", "update_children"):
for cls in (Model, Task):
updated_count = cls.objects(project__in=project_ids).update(
project=None
)
updated_count = cls.objects(project__in=project_ids).update(project=None)
res = DeleteProjectResult(disassociated_tasks=updated_count)
else:
deleted_models, model_urls = _delete_models(projects=project_ids)
@ -127,7 +123,6 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
return 0, set(), set()
task_ids = {t.id for t in tasks}
with TimingContext("mongo", "delete_tasks_update_children"):
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
@ -154,7 +149,6 @@ def _delete_models(projects: Sequence[str]) -> Tuple[int, Set[str]]:
Delete project models and update the tasks from other projects
that reference them to reference None.
"""
with TimingContext("mongo", "delete_models"):
models = Model.objects(project__in=projects).only("task", "id", "uri")
if not models:
return 0, set()

View File

@ -14,7 +14,6 @@ from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.queue import Queue, Entry
from apiserver.redis_manager import redman
from apiserver.timing_context import TimingContext
from apiserver.utilities.threads_manager import ThreadsManager
log = config.logger(__file__)
@ -182,7 +181,7 @@ class QueueMetrics:
"aggs": self._get_dates_agg(interval),
}
with translate_errors_context(), TimingContext("es", "get_queue_metrics"):
with translate_errors_context():
res = self._search_company_metrics(company_id, es_req)
if "aggregations" not in res:
@ -285,6 +284,7 @@ class MetricsRefresher:
if not queue_metrics:
from .queue_bll import QueueBLL
queue_metrics = QueueBLL().metrics
sleep(10)

View File

@ -4,7 +4,6 @@ from typing import Optional, TypeVar, Generic, Type, Callable
from redis import StrictRedis
from apiserver import database
from apiserver.timing_context import TimingContext
T = TypeVar("T")
@ -31,19 +30,16 @@ class RedisCacheManager(Generic[T]):
def set_state(self, state: T) -> None:
redis_key = self._get_redis_key(state.id)
with TimingContext("redis", "cache_set_state"):
self.redis.set(redis_key, state.to_json())
self.redis.expire(redis_key, self.expiration_interval)
def get_state(self, state_id) -> Optional[T]:
redis_key = self._get_redis_key(state_id)
with TimingContext("redis", "cache_get_state"):
response = self.redis.get(redis_key)
if response:
return self.state_class.from_json(response)
def delete_state(self, state_id) -> None:
with TimingContext("redis", "cache_delete_state"):
self.redis.delete(self._get_redis_key(state_id))
def _get_redis_key(self, state_id):
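
The surrounding methods implement a small JSON-over-Redis cache with a TTL. A minimal sketch of the same pattern, assuming redis-py; note that set(..., ex=ttl) can express the set/expire pair shown above as a single call:

    import json
    from typing import Optional
    from redis import StrictRedis

    class Cache:
        def __init__(self, redis: StrictRedis, prefix: str, ttl: int):
            self.redis, self.prefix, self.ttl = redis, prefix, ttl

        def set_state(self, state_id: str, state: dict) -> None:
            # one call instead of set() followed by expire()
            self.redis.set(f"{self.prefix}/{state_id}", json.dumps(state), ex=self.ttl)

        def get_state(self, state_id: str) -> Optional[dict]:
            raw = self.redis.get(f"{self.prefix}/{state_id}")
            return json.loads(raw) if raw else None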

View File

@ -5,7 +5,6 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.database.utils import hash_field_name
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
@ -53,12 +52,7 @@ class Artifacts:
artifacts: Sequence[ApiArtifact],
force: bool,
) -> int:
with TimingContext("mongo", "update_artifacts"):
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
force=force,
)
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
artifacts = {
get_artifact_id(a): Artifact(**a)
@ -79,12 +73,7 @@ class Artifacts:
artifact_ids: Sequence[ArtifactId],
force: bool,
) -> int:
with TimingContext("mongo", "delete_artifacts"):
task = get_task_for_update(
company_id=company_id,
task_id=task_id,
force=force,
)
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
artifact_ids = [
get_artifact_id(a)

View File

@ -15,7 +15,6 @@ from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.config_repo import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
@ -68,7 +67,6 @@ class HyperParams:
hyperparams: Sequence[HyperParamKey],
force: bool,
) -> int:
with TimingContext("mongo", "delete_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
@ -108,7 +106,6 @@ class HyperParams:
replace_hyperparams: str,
force: bool,
) -> int:
with TimingContext("mongo", "edit_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id,
@ -123,9 +120,7 @@ class HyperParams:
update_cmds["set__hyperparams"] = hyperparams
elif replace_hyperparams == ReplaceHyperparams.section:
for section, value in hyperparams.items():
update_cmds[
f"set__hyperparams__{mongoengine_safe(section)}"
] = value
update_cmds[f"set__hyperparams__{mongoengine_safe(section)}"] = value
else:
for section, section_params in hyperparams.items():
for name, value in section_params.items():
@ -191,7 +186,6 @@ class HyperParams:
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
with TimingContext("mongo", "get_configuration_names"):
tasks = Task.aggregate(pipeline)
return {
@ -212,10 +206,7 @@ class HyperParams:
replace_configuration: bool,
force: bool,
) -> int:
with TimingContext("mongo", "edit_configuration"):
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force
)
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
update_cmds = dict()
configuration = {
@ -234,10 +225,7 @@ class HyperParams:
def delete_configuration(
cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
) -> int:
with TimingContext("mongo", "delete_configuration"):
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force
)
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1

View File

@ -35,7 +35,6 @@ from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from apiserver.timing_context import TimingContext
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
@ -66,7 +65,6 @@ class TaskBLL:
"""
with translate_errors_context():
query = dict(id=task_id, company=company_id)
with TimingContext("mongo", "task_with_access"):
if requires_write_access:
task = Task.get_for_writing(_only=only, **query)
else:
@ -88,7 +86,6 @@ class TaskBLL:
only_fields = list(only_fields)
only_fields = only_fields + ["status"]
with TimingContext("mongo", "task_by_id_all"):
tasks = Task.get_many(
company=company_id,
query=Q(id=task_id),
@ -111,7 +108,7 @@ class TaskBLL:
company_id, task_ids, only=None, allow_public=False, return_tasks=True
) -> Optional[Sequence[Task]]:
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
with translate_errors_context(), TimingContext("mongo", "task_exists"):
with translate_errors_context():
ids = set(task_ids)
q = Task.get_many(
company=company_id,
@ -260,7 +257,6 @@ class TaskBLL:
not in [TaskSystemTags.development, EntityVisibility.archived.value]
]
with TimingContext("mongo", "clone task"):
parent_task = (
task.parent
if task.parent and not task.parent.startswith(deleted_prefix)
@ -281,9 +277,7 @@ class TaskBLL:
system_tags=system_tags or clean_system_tags(task.system_tags),
type=task.type,
script=task.script,
output=Output(destination=task.output.destination)
if task.output
else None,
output=Output(destination=task.output.destination) if task.output else None,
models=Models(input=input_models or task.models.input),
container=escape_dict(container) or task.container,
execution=execution_dict,

View File

@ -15,8 +15,12 @@ from apiserver.bll.task.utils import deleted_prefix
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
from apiserver.database.model.url_to_delete import StorageType, UrlToDelete, FileType, DeletionStatus
from apiserver.timing_context import TimingContext
from apiserver.database.model.url_to_delete import (
StorageType,
UrlToDelete,
FileType,
DeletionStatus,
)
from apiserver.database.utils import id as db_id
event_bll = EventBLL()
@ -65,7 +69,6 @@ class CleanupResult:
def collect_plot_image_urls(company: str, task: str) -> Set[str]:
urls = set()
next_scroll_id = None
with TimingContext("es", "collect_plot_image_urls"):
while True:
events, next_scroll_id = event_bll.get_plot_image_urls(
company_id=company, task_id=task, scroll_id=next_scroll_id
@ -89,9 +92,7 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
urls = set()
while True:
res, after_key = event_bll.get_debug_image_urls(
company_id=company,
task_id=task,
after_key=after_key,
company_id=company, task_id=task, after_key=after_key,
)
urls.update(res)
if not after_key:
@ -198,10 +199,7 @@ def cleanup_task(
deleted_task_id = f"{deleted_prefix}{task.id}"
updated_children = 0
if update_children:
with TimingContext("mongo", "update_task_children"):
updated_children = Task.objects(parent=task.id).update(
parent=deleted_task_id
)
updated_children = Task.objects(parent=task.id).update(parent=deleted_task_id)
deleted_models = 0
updated_models = 0
@ -256,7 +254,6 @@ def verify_task_children_and_ouptuts(
task, force: bool
) -> Tuple[Sequence[Model], Sequence[Model], Set[str]]:
if not force:
with TimingContext("mongo", "count_published_children"):
published_children_count = Task.objects(
parent=task.id, status=TaskStatus.published
).count()
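
The after_key loop in collect_debug_image_urls above follows Elasticsearch's composite-aggregation paging: each response carries an after_key cursor that is fed back until no buckets return. A rough sketch under assumed index and field names (the actual aggregation lives in EventBLL and may differ):

    def iter_url_buckets(es, index: str):
        after_key = None
        while True:
            comp = {"size": 100, "sources": [{"url": {"terms": {"field": "url"}}}]}
            if after_key:
                comp["after"] = after_key
            res = es.search(index=index, body={"size": 0, "aggs": {"urls": {"composite": comp}}})
            agg = res["aggregations"]["urls"]
            yield from (b["key"]["url"] for b in agg["buckets"])
            after_key = agg.get("after_key")
            if not after_key:
                return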

View File

@ -9,7 +9,6 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
from apiserver.timing_context import TimingContext
from apiserver.utilities.attrs import typed_attrs
valid_statuses = get_options(TaskStatus)
@ -55,7 +54,7 @@ class ChangeStatusRequest(object):
fields.update({safe_mongoengine_key(k): v for k, v in kwargs.items()})
with translate_errors_context(), TimingContext("mongo", "task_status"):
with translate_errors_context():
# atomic change of task status by querying the task with the EXPECTED status before modifying it
params = fields.copy()
params.update(control)
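
The comment above describes a compare-and-set: the query filter includes the EXPECTED status, so the update succeeds only if no concurrent writer changed it first. A minimal sketch against the Task document imported above; the helper name is illustrative:

    def change_status(task_id: str, company: str, expected: str, new_status: str) -> bool:
        updated = Task.objects(id=task_id, company=company, status=expected).update(
            status=new_status
        )
        return updated == 1  # 0 means we lost the race and should re-read the task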

View File

@ -25,7 +25,6 @@ from apiserver.database.model.project import Project
from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
from .stats import WorkerStats
@ -109,7 +108,6 @@ class WorkerBLL:
:param worker: worker ID
:raise bad_request.WorkerNotRegistered: the worker was not previously registered
"""
with TimingContext("redis", "workers_unregister"):
res = self.redis.delete(
company_id, self._get_worker_key(company_id, user_id, worker)
)
@ -117,7 +115,12 @@ class WorkerBLL:
raise bad_request.WorkerNotRegistered(worker=worker)
def status_report(
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
self,
company_id: str,
user_id: str,
ip: str,
report: StatusReportRequest,
tags: Sequence[str] = None,
) -> None:
"""
Write worker status report
@ -176,7 +179,9 @@ class WorkerBLL:
if task.project:
project = Project.objects(id=task.project).only("name").first()
if project:
entry.project = IdNameEntry(id=project.id, name=project.name)
entry.project = IdNameEntry(
id=project.id, name=project.name
)
entry.last_report_time = now
except APIError:
@ -323,7 +328,6 @@ class WorkerBLL:
"""
key = self._get_worker_key(company_id, user_id, worker)
with TimingContext("redis", "get_worker"):
data = self.redis.get(key)
if data:
@ -367,7 +371,6 @@ class WorkerBLL:
"""Get worker entries matching the company and user, worker patterns"""
entries = []
match = self._get_worker_key(company, user, worker_id)
with TimingContext("redis", "workers_get_all"):
for r in self.redis.scan_iter(match):
data = self.redis.get(r)
if data:
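
The loop above enumerates worker entries with SCAN rather than KEYS, which keeps Redis responsive on large keyspaces. A minimal sketch of the same pattern; the key layout in the match pattern is hypothetical:

    def get_all_worker_payloads(redis, company: str) -> list:
        payloads = []
        for key in redis.scan_iter(match=f"worker_{company}_*"):  # cursor-based, non-blocking
            data = redis.get(key)
            if data:
                payloads.append(data)
        return payloads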

View File

@ -8,7 +8,6 @@ from apiserver.apimodels.workers import AggregationType, GetStatsRequest, StatIt
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
@ -126,7 +125,7 @@ class WorkerStats:
query_terms.append(QueryBuilder.terms("worker", request.worker_ids))
es_req["query"] = {"bool": {"must": query_terms}}
with translate_errors_context(), TimingContext("es", "get_worker_stats"):
with translate_errors_context():
data = self._search_company_stats(company_id, es_req)
return self._extract_results(data, request.items, request.split_by_variant)
@ -223,9 +222,7 @@ class WorkerStats:
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext(
"es", "get_worker_activity_report"
):
with translate_errors_context():
data = self._search_company_stats(company_id, es_req)
if "aggregations" not in data:

View File

@ -11,7 +11,6 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import User, Entities, Credentials
from apiserver.database.model.company import Company
from apiserver.database.utils import get_options
from apiserver.timing_context import TimingContext
from .fixed_user import FixedUser
from .identity import Identity
from .payload import Payload, Token, Basic, AuthType
@ -88,9 +87,7 @@ def authorize_credentials(auth_data, service, action, call):
query = Q(id=fixed_user.user_id)
with TimingContext("mongo", "user_by_cred"), translate_errors_context(
"authorizing request"
):
with translate_errors_context("authorizing request"):
user = User.objects(query).first()
if not user:
raise errors.unauthorized.InvalidCredentials(
@ -108,7 +105,6 @@ def authorize_credentials(auth_data, service, action, call):
}
)
with TimingContext("mongo", "company_by_id"):
company = Company.objects(id=user.company).only("id", "name").first()
if not company:

View File

@ -52,7 +52,6 @@ from apiserver.services.utils import (
unescape_metadata,
escape_metadata,
)
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
org_bll = OrgBLL()
@ -123,7 +122,6 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
conform_tag_fields(call, call.data)
_process_include_subprojects(call.data)
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_all_ex"):
ret_params = {}
models = Model.get_many_with_join(
company=company_id,
@ -139,10 +137,7 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
return
model_ids = {model["id"] for model in models}
stats = ModelBLL.get_model_stats(
company=company_id,
model_ids=list(model_ids),
)
stats = ModelBLL.get_model_stats(company=company_id, model_ids=list(model_ids),)
for model in models:
model["stats"] = stats.get(model["id"])
@ -154,7 +149,6 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_by_id_ex"):
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
)
@ -167,7 +161,6 @@ def get_by_id_ex(call: APICall, company_id, _):
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_all"):
ret_params = {}
models = Model.get_many(
company=company_id,
@ -414,9 +407,7 @@ def validate_task(company_id, fields: dict):
def edit(call: APICall, company_id, _):
model_id = call.data["model"]
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
)
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
fields = parse_model_fields(call, create_fields)
fields = prepare_update_fields(call, company_id, fields)
@ -424,11 +415,7 @@ def edit(call: APICall, company_id, _):
for key in fields:
field = getattr(model, key, None)
value = fields[key]
if (
field
and isinstance(value, dict)
and isinstance(field, EmbeddedDocument)
):
if field and isinstance(value, dict) and isinstance(field, EmbeddedDocument):
d = field.to_mongo(use_db_field=False).to_dict()
d.update(value)
fields[key] = d
@ -448,13 +435,9 @@ def edit(call: APICall, company_id, _):
if updated:
new_project = fields.get("project", model.project)
if new_project != model.project:
_reset_cached_tags(
company_id, projects=[new_project, model.project]
)
_reset_cached_tags(company_id, projects=[new_project, model.project])
else:
_update_cached_tags(
company_id, project=model.project, fields=fields
)
_update_cached_tags(company_id, project=model.project, fields=fields)
conform_output_tags(call, fields)
unescape_metadata(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@ -465,9 +448,7 @@ def edit(call: APICall, company_id, _):
def _update_model(call: APICall, company_id, model_id=None):
model_id = model_id or call.data["model"]
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
)
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
data = prepare_update_fields(call, company_id, call.data)
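
The edit path above merges partial API updates into EmbeddedDocument fields instead of overwriting them: the stored value is dumped to a dict, the incoming keys are overlaid, and the merged dict is written back. A minimal sketch of that merge step:

    def merge_embedded(field, value: dict) -> dict:
        current = field.to_mongo(use_db_field=False).to_dict()  # EmbeddedDocument -> plain dict
        current.update(value)  # overlay only the keys the caller actually sent
        return current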

View File

@ -39,7 +39,6 @@ from apiserver.services.utils import (
get_tags_filter_dictionary,
sort_tags_response,
)
from apiserver.timing_context import TimingContext
org_bll = OrgBLL()
project_bll = ProjectBLL()
@ -60,10 +59,7 @@ def get_by_id(call):
project_id = call.data["project"]
with translate_errors_context():
with TimingContext("mongo", "projects_by_id"):
query = Q(id=project_id) & get_company_or_none_constraint(
call.identity.company
)
query = Q(id=project_id) & get_company_or_none_constraint(call.identity.company)
project = Project.objects(query).first()
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
@ -109,7 +105,6 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
_adjust_search_parameters(
data, shallow_search=request.shallow_search,
)
with TimingContext("mongo", "projects_get_all"):
user_active_project_ids = None
if request.active_users:
ids, user_active_project_ids = project_bll.get_projects_with_active_user(
@ -163,9 +158,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
if request.include_dataset_stats:
dataset_stats = project_bll.get_dataset_stats(
company=company_id,
project_ids=project_ids,
users=request.active_users,
company=company_id, project_ids=project_ids, users=request.active_users,
)
for project in projects:
project["dataset_stats"] = dataset_stats.get(project["id"])
@ -180,7 +173,6 @@ def get_all(call: APICall):
_adjust_search_parameters(
data, shallow_search=data.get("shallow_search", False),
)
with TimingContext("mongo", "projects_get_all"):
ret_params = {}
projects = Project.get_many(
company=call.identity.company,

View File

@ -119,7 +119,6 @@ from apiserver.services.utils import (
escape_dict_field,
unescape_dict_field,
)
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.partial_version import PartialVersion
@ -231,7 +230,6 @@ def get_all_ex(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data)
ret_params = {}
tasks = Task.get_many_with_join(
@ -251,7 +249,6 @@ def get_by_id_ex(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
@ -266,7 +263,6 @@ def get_all(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with TimingContext("mongo", "task_get_all"):
ret_params = {}
tasks = Task.get_many(
company=company_id,
@ -487,11 +483,10 @@ def prepare_create_fields(
def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
with translate_errors_context(
field_does_not_exist_cls=errors.bad_request.ValidationError
), TimingContext("code", "parse_call"):
):
fields = prepare_create_fields(call, **kwargs)
task = task_bll.create(call, fields)
with TimingContext("code", "validate"):
task_bll.validate(task)
return task, fields
@ -525,7 +520,7 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
def create(call: APICall, company_id, req_model: CreateRequest):
task, fields = _validate_and_get_task_from_call(call)
with translate_errors_context(), TimingContext("mongo", "save_task"):
with translate_errors_context():
task.save()
_update_cached_tags(company_id, project=task.project, fields=fields)
update_project_time(task.project)
@ -708,7 +703,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
with translate_errors_context(
field_does_not_exist_cls=errors.bad_request.ValidationError
), TimingContext("code", "parse_and_validate"):
):
fields = prepare_create_fields(
call, valid_fields=edit_fields, output=task.output, previous_task=task
)

View File

@ -1,15 +0,0 @@
class TimingStats:
@classmethod
def aggregate(cls):
return {}
class TimingContext:
def __init__(self, *_, **__):
pass
def __enter__(self):
return self
def __exit__(self, *args):
pass
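
The deleted stub above is a pure no-op: it accepts any constructor arguments, returns itself on __enter__, and returns None from __exit__, so exceptions still propagate. Functionally it was equivalent to contextlib.nullcontext, which is why every call site could drop it without changing behavior:

    from contextlib import nullcontext

    # with TimingContext("es", "anything"):  behaves exactly like  with nullcontext():
    with nullcontext():
        pass  # the wrapped body runs exactly as it would without the with-block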