From 14d18a7abad6b7bbcd4bcb63fb58e3ba6bf5dddc Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 26 Jul 2023 18:19:41 +0300 Subject: [PATCH] Remove obsolete duration field --- apiserver/bll/util.py | 55 --------------------------- apiserver/database/model/task/task.py | 2 +- apiserver/services/tasks.py | 23 ++--------- 3 files changed, 4 insertions(+), 76 deletions(-) diff --git a/apiserver/bll/util.py b/apiserver/bll/util.py index 62c5e41..410ec38 100644 --- a/apiserver/bll/util.py +++ b/apiserver/bll/util.py @@ -4,9 +4,6 @@ from concurrent.futures.thread import ThreadPoolExecutor from typing import ( Optional, Callable, - Dict, - Any, - Set, Iterable, Tuple, Sequence, @@ -16,61 +13,9 @@ from typing import ( from boltons import iterutils from apiserver.apierrors import APIError -from apiserver.database.model import AttributedDocument from apiserver.database.model.settings import Settings -class SetFieldsResolver: - """ - The class receives set fields dictionary - and for the set fields that require 'min' or 'max' - operation replace them with a simple set in case the - DB document does not have these fields set - """ - - SET_MODIFIERS = ("min", "max") - - def __init__(self, set_fields: Dict[str, Any]): - self.orig_fields = {} - self.fields = {} - self.add_fields(**set_fields) - - def add_fields(self, **set_fields: Any): - self.orig_fields.update(set_fields) - self.fields.update( - { - f: fname - for f, modifier, dunder, fname in ( - (f,) + f.partition("__") for f in set_fields.keys() - ) - if dunder and modifier in self.SET_MODIFIERS - } - ) - - def _get_updated_name(self, doc: AttributedDocument, name: str) -> str: - if name in self.fields and doc.get_field_value(self.fields[name]) is None: - return self.fields[name] - return name - - def get_fields(self, doc: AttributedDocument): - """ - For the given document return the set fields instructions - with min/max operations replaced with a single set in case - the document does not have the field set - """ - return { - self._get_updated_name(doc, name): value - for name, value in self.orig_fields.items() - } - - def get_names(self) -> Set[str]: - """ - Returns the names of the fields that had min/max modifiers - in the format suitable for projection (dot separated) - """ - return set(name.replace("__", ".") for name in self.fields.values()) - - @functools.lru_cache() def get_server_uuid() -> Optional[str]: return Settings.get_by_key("server.uuid") diff --git a/apiserver/database/model/task/task.py b/apiserver/database/model/task/task.py index 307f419..a2323dd 100644 --- a/apiserver/database/model/task/task.py +++ b/apiserver/database/model/task/task.py @@ -271,7 +271,7 @@ class Task(AttributedDocument): unique_metrics = ListField(StringField(required=True), exclude_by_default=True) metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats)) company_origin = StringField(exclude_by_default=True) - duration = IntField() # task duration in seconds + duration = IntField() # obsolete, do not use hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem))) configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem)) runtime = SafeDictField(default=dict) diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index d20d207..776f983 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -101,7 +101,7 @@ from apiserver.bll.task.task_operations import ( move_tasks_to_trash, ) from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix -from apiserver.bll.util import SetFieldsResolver, run_batch_operation +from apiserver.bll.util import run_batch_operation from apiserver.database.errors import translate_errors_context from apiserver.database.model import EntityVisibility from apiserver.database.model.task.output import Output @@ -141,30 +141,13 @@ project_bll = ProjectBLL() def set_task_status_from_call( request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields ) -> dict: - fields_resolver = SetFieldsResolver(set_fields) task = TaskBLL.get_task_with_access( request.task, company_id=company_id, - only=tuple( - {"status", "project", "started", "duration"} | fields_resolver.get_names() - ), + only=("id", "status", "project"), requires_write_access=True, ) - if "duration" not in fields_resolver.get_names(): - if new_status == Task.started: - fields_resolver.add_fields(min__duration=max(0, task.duration or 0)) - elif new_status in ( - TaskStatus.completed, - TaskStatus.failed, - TaskStatus.stopped, - ): - fields_resolver.add_fields( - duration=int((task.started - datetime.utcnow()).total_seconds()) - if task.started - else 0 - ) - status_reason = request.status_reason status_message = request.status_message force = request.force @@ -175,7 +158,7 @@ def set_task_status_from_call( status_message=status_message, force=force, user_id=user_id, - ).execute(**fields_resolver.get_fields(task)) + ).execute(**set_fields) @endpoint("tasks.get_by_id", request_data_model=TaskRequest)