diff --git a/apiserver/apimodels/base.py b/apiserver/apimodels/base.py index 1d900f7..d4db847 100644 --- a/apiserver/apimodels/base.py +++ b/apiserver/apimodels/base.py @@ -82,7 +82,7 @@ class DictField(fields.BaseField): """Cast value to proper collection.""" result = self.get_default_value() - if not values: + if values is None: return result if not self.value_types or not isinstance(values, dict): diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py index 407dca6..0ad70b1 100644 --- a/apiserver/bll/project/project_bll.py +++ b/apiserver/bll/project/project_bll.py @@ -127,7 +127,11 @@ class ProjectBLL: project_name=project_name, description="Auto-generated during move", ) - - entity_cls.objects(company=company, id__in=ids).update(set__project=project) + extra = ( + {"set__last_change": datetime.utcnow()} + if hasattr(entity_cls, "last_change") + else {} + ) + entity_cls.objects(company=company, id__in=ids).update(set__project=project, **extra) return project diff --git a/apiserver/bll/task/artifacts.py b/apiserver/bll/task/artifacts.py index ee8c3db..11db064 100644 --- a/apiserver/bll/task/artifacts.py +++ b/apiserver/bll/task/artifacts.py @@ -1,10 +1,9 @@ -from datetime import datetime from hashlib import md5 from operator import itemgetter from typing import Sequence from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId -from apiserver.bll.task.utils import get_task_for_update +from apiserver.bll.task.utils import get_task_for_update, update_task from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact from apiserver.timing_context import TimingContext from apiserver.utilities.dicts import nested_get, nested_set @@ -70,7 +69,7 @@ class Artifacts: f"set__execution__artifacts__{mongoengine_safe(name)}": value for name, value in artifacts.items() } - return task.update(**update_cmds, last_update=datetime.utcnow()) + return update_task(task, update_cmds=update_cmds) @classmethod def delete_artifacts( @@ -95,4 +94,4 @@ class Artifacts: f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids) } - return task.update(**delete_cmds, last_update=datetime.utcnow()) + return update_task(task, update_cmds=delete_cmds) diff --git a/apiserver/bll/task/hyperparams.py b/apiserver/bll/task/hyperparams.py index bd5c043..eb8e8bb 100644 --- a/apiserver/bll/task/hyperparams.py +++ b/apiserver/bll/task/hyperparams.py @@ -1,4 +1,3 @@ -from datetime import datetime from itertools import chain from operator import attrgetter from typing import Sequence, Dict @@ -13,7 +12,7 @@ from apiserver.apimodels.tasks import ( Configuration, ) from apiserver.bll.task import TaskBLL -from apiserver.bll.task.utils import get_task_for_update +from apiserver.bll.task.utils import get_task_for_update, update_task from apiserver.config_repo import config from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem from apiserver.timing_context import TimingContext @@ -96,7 +95,9 @@ class HyperParams: name = ParameterKeyEscaper.escape(item.name) delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1 - return task.update(**delete_cmds, last_update=datetime.utcnow()) + return update_task( + task, update_cmds=delete_cmds, set_last_update=not properties_only + ) @classmethod def edit_params( @@ -132,7 +133,9 @@ class HyperParams: f"set__hyperparams__{section}__{mongoengine_safe(name)}" ] = value - return task.update(**update_cmds, last_update=datetime.utcnow()) + return update_task( + task, update_cmds=update_cmds, set_last_update=not properties_only + ) @classmethod def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]: @@ -223,7 +226,7 @@ class HyperParams: for name, value in configuration.items(): update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value - return task.update(**update_cmds, last_update=datetime.utcnow()) + return update_task(task, update_cmds=update_cmds) @classmethod def delete_configuration( @@ -239,4 +242,4 @@ class HyperParams: for name in set(configuration) } - return task.update(**delete_cmds, last_update=datetime.utcnow()) + return update_task(task, update_cmds=delete_cmds) diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index aa3873e..bbc106d 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -145,6 +145,7 @@ class TaskBLL: company=identity.company, created=now, last_update=now, + last_change=now, **fields, ) @@ -237,6 +238,7 @@ class TaskBLL: company=company_id, created=now, last_update=now, + last_change=now, name=name or task.name, comment=comment or task.comment, parent=parent or task.parent, @@ -367,7 +369,10 @@ class TaskBLL: **extra_updates, } Task.objects(id=task.id, company=company_id).update( - upsert=False, last_update=last_update, **updates + upsert=False, + last_update=last_update, + last_change=last_update, + **updates, ) @staticmethod @@ -653,11 +658,7 @@ class TaskBLL: @classmethod def dequeue_and_change_status( - cls, - task: Task, - company_id: str, - status_message: str, - status_reason: str, + cls, task: Task, company_id: str, status_message: str, status_reason: str, ): cls.dequeue(task, company_id) diff --git a/apiserver/bll/task/utils.py b/apiserver/bll/task/utils.py index 82a3bc6..e2a23a6 100644 --- a/apiserver/bll/task/utils.py +++ b/apiserver/bll/task/utils.py @@ -43,6 +43,7 @@ class ChangeStatusRequest(object): status_message=self.status_message, status_changed=now, last_update=now, + last_change=now, ) if self.new_status == TaskStatus.queued: @@ -194,3 +195,11 @@ def get_task_for_update( expected=TaskStatus.created, status=task.status ) return task + + +def update_task(task: Task, update_cmds: dict, set_last_update: bool = True): + now = datetime.utcnow() + last_updates = dict(last_change=now) + if set_last_update: + last_updates.update(last_update=now) + return task.update(**update_cmds, **last_updates) diff --git a/apiserver/bll/workers/__init__.py b/apiserver/bll/workers/__init__.py index bd8a5c1..3ecf6e9 100644 --- a/apiserver/bll/workers/__init__.py +++ b/apiserver/bll/workers/__init__.py @@ -164,6 +164,7 @@ class WorkerBLL: last_worker=report.worker, last_worker_report=now, last_update=now, + last_change=now, ) # modify(new=True, ...) returns the modified object task = Task.objects(**query).modify(new=True, **update) diff --git a/apiserver/database/model/auth.py b/apiserver/database/model/auth.py index f599afd..341e2ea 100644 --- a/apiserver/database/model/auth.py +++ b/apiserver/database/model/auth.py @@ -52,7 +52,7 @@ class Credentials(EmbeddedDocument): class User(DbModelMixin, AuthDocument): - meta = {"db_alias": Database.auth, "strict": strict} + meta = {"db_alias": Database.auth, "strict": strict, "indexes": ["email"]} id = StringField(primary_key=True) name = StringField() diff --git a/apiserver/database/model/base.py b/apiserver/database/model/base.py index 66f0f39..5ff3267 100644 --- a/apiserver/database/model/base.py +++ b/apiserver/database/model/base.py @@ -701,14 +701,24 @@ class GetMixin(PropsMixin): class UpdateMixin(object): + __user_set_allowed_fields = None + __locked_when_published_fields = None + @classmethod def user_set_allowed(cls): - res = getattr(cls, "__user_set_allowed_fields", None) - if res is None: - res = cls.__user_set_allowed_fields = get_fields_choices( - cls, "user_set_allowed" + if cls.__user_set_allowed_fields is None: + cls.__user_set_allowed_fields = dict( + get_fields_choices(cls, "user_set_allowed") ) - return res + return cls.__user_set_allowed_fields + + @classmethod + def locked_when_published(cls): + if cls.__locked_when_published_fields is None: + cls.__locked_when_published_fields = dict( + get_fields_choices(cls, "locked_when_published") + ) + return cls.__locked_when_published_fields @classmethod def get_safe_update_dict(cls, fields): diff --git a/apiserver/database/model/task/task.py b/apiserver/database/model/task/task.py index fb5c9d3..05d0762 100644 --- a/apiserver/database/model/task/task.py +++ b/apiserver/database/model/task/task.py @@ -155,6 +155,8 @@ class Task(AttributedDocument): "project", ("company", "name"), ("company", "user"), + ("company", "status", "type"), + ("company", "system_tags", "last_update"), ("company", "type", "system_tags", "status"), ("company", "project", "type", "system_tags", "status"), ("status", "last_update"), # for maintenance tasks @@ -215,6 +217,7 @@ class Task(AttributedDocument): last_worker = StringField() last_worker_report = DateTimeField() last_update = DateTimeField() + last_change = DateTimeField() last_iteration = IntField(default=DEFAULT_LAST_ITERATION) last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent))) metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats)) diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index e0e1143..4c18db9 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -501,6 +501,11 @@ _definitions { type: string format: "date-time" } + last_change { + description: "Last time any update was done to the task" + type: string + format: "date-time" + } last_iteration { description: "Last iteration reported for this task" type: integer diff --git a/apiserver/services/models.py b/apiserver/services/models.py index 287c004..40f6ef2 100644 --- a/apiserver/services/models.py +++ b/apiserver/services/models.py @@ -472,9 +472,11 @@ def update(call: APICall, company_id, _): raise errors.bad_request.ModelCreatingTaskExists( "and published, use force=True to delete", task=model.task ) + now = datetime.utcnow() task.update( output__model=deleted_model_id, - output__error=f"model deleted on {datetime.utcnow().isoformat()}", + output__error=f"model deleted on {now.isoformat()}", + last_change=now, upsert=False, ) diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 1daa8f2..f015ee0 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -69,6 +69,7 @@ from apiserver.bll.task.param_utils import ( params_unprepare_from_saved, escape_paths, ) +from apiserver.bll.task.utils import update_task from apiserver.bll.util import SetFieldsResolver from apiserver.database.errors import translate_errors_context from apiserver.database.model import EntityVisibility @@ -439,7 +440,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest): def prepare_update_fields(call: APICall, task, call_data): - valid_fields = deepcopy(task.__class__.user_set_allowed()) + valid_fields = deepcopy(Task.user_set_allowed()) update_fields = {k: v for k, v in create_fields.items() if k in valid_fields} update_fields["output__error"] = None t_fields = task_fields @@ -467,7 +468,10 @@ def update(call: APICall, company_id, req_model: UpdateRequest): return UpdateResponse(updated=0) updated_count, updated_fields = Task.safe_update( - company_id=company_id, id=task_id, partial_update_dict=partial_update_dict, + company_id=company_id, + id=task_id, + partial_update_dict=partial_update_dict, + injected_update=dict(last_change=datetime.utcnow()), ) if updated_count: new_project = updated_fields.get("project", task.project) @@ -500,8 +504,8 @@ def set_requirements(call: APICall, company_id, req_model: SetRequirementsReques raise errors.bad_request.MissingTaskFields( "Task has no script field", task=task.id ) - res = task.update( - script__requirements=requirements, last_update=datetime.utcnow() + res = update_task( + task, update_cmds=dict(script__requirements=requirements) ) call.result.data_model = UpdateResponse(updated=res) if res: @@ -537,7 +541,7 @@ def update_batch(call: APICall, company_id, _): partial_update_dict = Task.get_safe_update_dict(fields) if not partial_update_dict: continue - partial_update_dict.update(last_update=now) + partial_update_dict.update(last_change=now) update_op = UpdateOne( {"_id": id, "company": company_id}, {"$set": partial_update_dict} ) @@ -608,8 +612,11 @@ def edit(call: APICall, company_id, req_model: UpdateRequest): } if fixed_fields: now = datetime.utcnow() - fields.update(last_update=now) - fixed_fields.update(last_update=now) + last_change = dict(last_change=now) + if not set(fields).issubset(Task.user_set_allowed()): + last_change.update(last_update=now) + fields.update(**last_change) + fixed_fields.update(**last_change) updated = task.update(upsert=False, **fixed_fields) if updated: new_project = fixed_fields.get("project", task.project) @@ -920,6 +927,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest): system_tags=sorted( set(task.system_tags) | {EntityVisibility.archived.value} ), + last_change=datetime.utcnow(), ) archived += 1 diff --git a/apiserver/tests/automated/test_task_hyperparams.py b/apiserver/tests/automated/test_task_hyperparams.py index 7ecb1c9..67f0376 100644 --- a/apiserver/tests/automated/test_task_hyperparams.py +++ b/apiserver/tests/automated/test_task_hyperparams.py @@ -105,7 +105,7 @@ class TestTasksHyperparams(TestService): ) # clone task - new_task = self.api.tasks.clone(task=task, new_hyperparams=new_params_dict).id + new_task = self.api.tasks.clone(task=task, new_task_hyperparams=new_params_dict).id try: res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0] self.assertEqual(new_params, res.hyperparams) @@ -223,7 +223,7 @@ class TestTasksHyperparams(TestService): self.assertEqual(old_config + new_config[:1], res.configuration) # clone task - new_task = self.api.tasks.clone(task=task, new_configuration=new_config_dict).id + new_task = self.api.tasks.clone(task=task, new_task_configuration=new_config_dict).id try: res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0] self.assertEqual(new_config, res.configuration)