Remove obsolete duration field

This commit is contained in:
allegroai 2023-07-26 18:19:41 +03:00
parent a7ed46979f
commit 14d18a7aba
3 changed files with 4 additions and 76 deletions

View File

@ -4,9 +4,6 @@ from concurrent.futures.thread import ThreadPoolExecutor
from typing import (
Optional,
Callable,
Dict,
Any,
Set,
Iterable,
Tuple,
Sequence,
@ -16,61 +13,9 @@ from typing import (
from boltons import iterutils
from apiserver.apierrors import APIError
from apiserver.database.model import AttributedDocument
from apiserver.database.model.settings import Settings
class SetFieldsResolver:
"""
The class receives set fields dictionary
and for the set fields that require 'min' or 'max'
operation replace them with a simple set in case the
DB document does not have these fields set
"""
SET_MODIFIERS = ("min", "max")
def __init__(self, set_fields: Dict[str, Any]):
self.orig_fields = {}
self.fields = {}
self.add_fields(**set_fields)
def add_fields(self, **set_fields: Any):
self.orig_fields.update(set_fields)
self.fields.update(
{
f: fname
for f, modifier, dunder, fname in (
(f,) + f.partition("__") for f in set_fields.keys()
)
if dunder and modifier in self.SET_MODIFIERS
}
)
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
if name in self.fields and doc.get_field_value(self.fields[name]) is None:
return self.fields[name]
return name
def get_fields(self, doc: AttributedDocument):
"""
For the given document return the set fields instructions
with min/max operations replaced with a single set in case
the document does not have the field set
"""
return {
self._get_updated_name(doc, name): value
for name, value in self.orig_fields.items()
}
def get_names(self) -> Set[str]:
"""
Returns the names of the fields that had min/max modifiers
in the format suitable for projection (dot separated)
"""
return set(name.replace("__", ".") for name in self.fields.values())
@functools.lru_cache()
def get_server_uuid() -> Optional[str]:
return Settings.get_by_key("server.uuid")

View File

@ -271,7 +271,7 @@ class Task(AttributedDocument):
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
company_origin = StringField(exclude_by_default=True)
duration = IntField() # task duration in seconds
duration = IntField() # obsolete, do not use
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
runtime = SafeDictField(default=dict)

View File

@ -101,7 +101,7 @@ from apiserver.bll.task.task_operations import (
move_tasks_to_trash,
)
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
from apiserver.bll.util import SetFieldsResolver, run_batch_operation
from apiserver.bll.util import run_batch_operation
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.task.output import Output
@ -141,30 +141,13 @@ project_bll = ProjectBLL()
def set_task_status_from_call(
request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields
) -> dict:
fields_resolver = SetFieldsResolver(set_fields)
task = TaskBLL.get_task_with_access(
request.task,
company_id=company_id,
only=tuple(
{"status", "project", "started", "duration"} | fields_resolver.get_names()
),
only=("id", "status", "project"),
requires_write_access=True,
)
if "duration" not in fields_resolver.get_names():
if new_status == Task.started:
fields_resolver.add_fields(min__duration=max(0, task.duration or 0))
elif new_status in (
TaskStatus.completed,
TaskStatus.failed,
TaskStatus.stopped,
):
fields_resolver.add_fields(
duration=int((task.started - datetime.utcnow()).total_seconds())
if task.started
else 0
)
status_reason = request.status_reason
status_message = request.status_message
force = request.force
@ -175,7 +158,7 @@ def set_task_status_from_call(
status_message=status_message,
force=force,
user_id=user_id,
).execute(**fields_resolver.get_fields(task))
).execute(**set_fields)
@endpoint("tasks.get_by_id", request_data_model=TaskRequest)