From 14d18a7abad6b7bbcd4bcb63fb58e3ba6bf5dddc Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Wed, 26 Jul 2023 18:19:41 +0300
Subject: [PATCH] Remove obsolete duration field

---
 apiserver/bll/util.py                 | 55 ---------------------------
 apiserver/database/model/task/task.py |  2 +-
 apiserver/services/tasks.py           | 23 ++---------
 3 files changed, 4 insertions(+), 76 deletions(-)

diff --git a/apiserver/bll/util.py b/apiserver/bll/util.py
index 62c5e41..410ec38 100644
--- a/apiserver/bll/util.py
+++ b/apiserver/bll/util.py
@@ -4,9 +4,6 @@ from concurrent.futures.thread import ThreadPoolExecutor
 from typing import (
     Optional,
     Callable,
-    Dict,
-    Any,
-    Set,
     Iterable,
     Tuple,
     Sequence,
@@ -16,61 +13,9 @@ from typing import (
 from boltons import iterutils
 
 from apiserver.apierrors import APIError
-from apiserver.database.model import AttributedDocument
 from apiserver.database.model.settings import Settings
 
 
-class SetFieldsResolver:
-    """
-    The class receives set fields dictionary
-    and for the set fields that require 'min' or 'max'
-    operation replace them with a simple set in case the
-    DB document does not have these fields set
-    """
-
-    SET_MODIFIERS = ("min", "max")
-
-    def __init__(self, set_fields: Dict[str, Any]):
-        self.orig_fields = {}
-        self.fields = {}
-        self.add_fields(**set_fields)
-
-    def add_fields(self, **set_fields: Any):
-        self.orig_fields.update(set_fields)
-        self.fields.update(
-            {
-                f: fname
-                for f, modifier, dunder, fname in (
-                    (f,) + f.partition("__") for f in set_fields.keys()
-                )
-                if dunder and modifier in self.SET_MODIFIERS
-            }
-        )
-
-    def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
-        if name in self.fields and doc.get_field_value(self.fields[name]) is None:
-            return self.fields[name]
-        return name
-
-    def get_fields(self, doc: AttributedDocument):
-        """
-        For the given document return the set fields instructions
-        with min/max operations replaced with a single set in case
-        the document does not have the field set
-        """
-        return {
-            self._get_updated_name(doc, name): value
-            for name, value in self.orig_fields.items()
-        }
-
-    def get_names(self) -> Set[str]:
-        """
-        Returns the names of the fields that had min/max modifiers
-        in the format suitable for projection (dot separated)
-        """
-        return set(name.replace("__", ".") for name in self.fields.values())
-
-
 @functools.lru_cache()
 def get_server_uuid() -> Optional[str]:
     return Settings.get_by_key("server.uuid")
diff --git a/apiserver/database/model/task/task.py b/apiserver/database/model/task/task.py
index 307f419..a2323dd 100644
--- a/apiserver/database/model/task/task.py
+++ b/apiserver/database/model/task/task.py
@@ -271,7 +271,7 @@ class Task(AttributedDocument):
     unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
     metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
     company_origin = StringField(exclude_by_default=True)
-    duration = IntField()  # task duration in seconds
+    duration = IntField()  # obsolete, do not use
     hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
     configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
     runtime = SafeDictField(default=dict)
diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py
index d20d207..776f983 100644
--- a/apiserver/services/tasks.py
+++ b/apiserver/services/tasks.py
@@ -101,7 +101,7 @@ from apiserver.bll.task.task_operations import (
     move_tasks_to_trash,
 )
 from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
-from apiserver.bll.util import SetFieldsResolver, run_batch_operation
+from apiserver.bll.util import run_batch_operation
 from apiserver.database.errors import translate_errors_context
 from apiserver.database.model import EntityVisibility
 from apiserver.database.model.task.output import Output
@@ -141,30 +141,13 @@ project_bll = ProjectBLL()
 def set_task_status_from_call(
     request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields
 ) -> dict:
-    fields_resolver = SetFieldsResolver(set_fields)
     task = TaskBLL.get_task_with_access(
         request.task,
         company_id=company_id,
-        only=tuple(
-            {"status", "project", "started", "duration"} | fields_resolver.get_names()
-        ),
+        only=("id", "status", "project"),
         requires_write_access=True,
     )
 
-    if "duration" not in fields_resolver.get_names():
-        if new_status == Task.started:
-            fields_resolver.add_fields(min__duration=max(0, task.duration or 0))
-        elif new_status in (
-            TaskStatus.completed,
-            TaskStatus.failed,
-            TaskStatus.stopped,
-        ):
-            fields_resolver.add_fields(
-                duration=int((task.started - datetime.utcnow()).total_seconds())
-                if task.started
-                else 0
-            )
-
     status_reason = request.status_reason
     status_message = request.status_message
     force = request.force
@@ -175,7 +158,7 @@ def set_task_status_from_call(
         status_message=status_message,
         force=force,
         user_id=user_id,
-    ).execute(**fields_resolver.get_fields(task))
+    ).execute(**set_fields)
 
 
 @endpoint("tasks.get_by_id", request_data_model=TaskRequest)