Fix task and model last_change handling

Improve db model index
Improve db model infrastructure
This commit is contained in:
allegroai 2021-01-05 18:17:29 +02:00
parent 29c792d459
commit 59994ccf9c
14 changed files with 80 additions and 35 deletions

View File

@ -82,7 +82,7 @@ class DictField(fields.BaseField):
"""Cast value to proper collection."""
result = self.get_default_value()
if not values:
if values is None:
return result
if not self.value_types or not isinstance(values, dict):

View File

@ -127,7 +127,11 @@ class ProjectBLL:
project_name=project_name,
description="Auto-generated during move",
)
entity_cls.objects(company=company, id__in=ids).update(set__project=project)
extra = (
{"set__last_change": datetime.utcnow()}
if hasattr(entity_cls, "last_change")
else {}
)
entity_cls.objects(company=company, id__in=ids).update(set__project=project, **extra)
return project

View File

@ -1,10 +1,9 @@
from datetime import datetime
from hashlib import md5
from operator import itemgetter
from typing import Sequence
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
from apiserver.bll.task.utils import get_task_for_update
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get, nested_set
@ -70,7 +69,7 @@ class Artifacts:
f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items()
}
return task.update(**update_cmds, last_update=datetime.utcnow())
return update_task(task, update_cmds=update_cmds)
@classmethod
def delete_artifacts(
@ -95,4 +94,4 @@ class Artifacts:
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
}
return task.update(**delete_cmds, last_update=datetime.utcnow())
return update_task(task, update_cmds=delete_cmds)

View File

@ -1,4 +1,3 @@
from datetime import datetime
from itertools import chain
from operator import attrgetter
from typing import Sequence, Dict
@ -13,7 +12,7 @@ from apiserver.apimodels.tasks import (
Configuration,
)
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_for_update
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.config_repo import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.timing_context import TimingContext
@ -96,7 +95,9 @@ class HyperParams:
name = ParameterKeyEscaper.escape(item.name)
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
return task.update(**delete_cmds, last_update=datetime.utcnow())
return update_task(
task, update_cmds=delete_cmds, set_last_update=not properties_only
)
@classmethod
def edit_params(
@ -132,7 +133,9 @@ class HyperParams:
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
] = value
return task.update(**update_cmds, last_update=datetime.utcnow())
return update_task(
task, update_cmds=update_cmds, set_last_update=not properties_only
)
@classmethod
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
@ -223,7 +226,7 @@ class HyperParams:
for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
return task.update(**update_cmds, last_update=datetime.utcnow())
return update_task(task, update_cmds=update_cmds)
@classmethod
def delete_configuration(
@ -239,4 +242,4 @@ class HyperParams:
for name in set(configuration)
}
return task.update(**delete_cmds, last_update=datetime.utcnow())
return update_task(task, update_cmds=delete_cmds)

View File

@ -145,6 +145,7 @@ class TaskBLL:
company=identity.company,
created=now,
last_update=now,
last_change=now,
**fields,
)
@ -237,6 +238,7 @@ class TaskBLL:
company=company_id,
created=now,
last_update=now,
last_change=now,
name=name or task.name,
comment=comment or task.comment,
parent=parent or task.parent,
@ -367,7 +369,10 @@ class TaskBLL:
**extra_updates,
}
Task.objects(id=task.id, company=company_id).update(
upsert=False, last_update=last_update, **updates
upsert=False,
last_update=last_update,
last_change=last_update,
**updates,
)
@staticmethod
@ -653,11 +658,7 @@ class TaskBLL:
@classmethod
def dequeue_and_change_status(
cls,
task: Task,
company_id: str,
status_message: str,
status_reason: str,
cls, task: Task, company_id: str, status_message: str, status_reason: str,
):
cls.dequeue(task, company_id)

View File

@ -43,6 +43,7 @@ class ChangeStatusRequest(object):
status_message=self.status_message,
status_changed=now,
last_update=now,
last_change=now,
)
if self.new_status == TaskStatus.queued:
@ -194,3 +195,11 @@ def get_task_for_update(
expected=TaskStatus.created, status=task.status
)
return task
def update_task(task: Task, update_cmds: dict, set_last_update: bool = True):
now = datetime.utcnow()
last_updates = dict(last_change=now)
if set_last_update:
last_updates.update(last_update=now)
return task.update(**update_cmds, **last_updates)

View File

@ -164,6 +164,7 @@ class WorkerBLL:
last_worker=report.worker,
last_worker_report=now,
last_update=now,
last_change=now,
)
# modify(new=True, ...) returns the modified object
task = Task.objects(**query).modify(new=True, **update)

View File

@ -52,7 +52,7 @@ class Credentials(EmbeddedDocument):
class User(DbModelMixin, AuthDocument):
meta = {"db_alias": Database.auth, "strict": strict}
meta = {"db_alias": Database.auth, "strict": strict, "indexes": ["email"]}
id = StringField(primary_key=True)
name = StringField()

View File

@ -701,14 +701,24 @@ class GetMixin(PropsMixin):
class UpdateMixin(object):
__user_set_allowed_fields = None
__locked_when_published_fields = None
@classmethod
def user_set_allowed(cls):
res = getattr(cls, "__user_set_allowed_fields", None)
if res is None:
res = cls.__user_set_allowed_fields = get_fields_choices(
cls, "user_set_allowed"
if cls.__user_set_allowed_fields is None:
cls.__user_set_allowed_fields = dict(
get_fields_choices(cls, "user_set_allowed")
)
return res
return cls.__user_set_allowed_fields
@classmethod
def locked_when_published(cls):
if cls.__locked_when_published_fields is None:
cls.__locked_when_published_fields = dict(
get_fields_choices(cls, "locked_when_published")
)
return cls.__locked_when_published_fields
@classmethod
def get_safe_update_dict(cls, fields):

View File

@ -155,6 +155,8 @@ class Task(AttributedDocument):
"project",
("company", "name"),
("company", "user"),
("company", "status", "type"),
("company", "system_tags", "last_update"),
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
@ -215,6 +217,7 @@ class Task(AttributedDocument):
last_worker = StringField()
last_worker_report = DateTimeField()
last_update = DateTimeField()
last_change = DateTimeField()
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))

View File

@ -501,6 +501,11 @@ _definitions {
type: string
format: "date-time"
}
last_change {
description: "Last time any update was done to the task"
type: string
format: "date-time"
}
last_iteration {
description: "Last iteration reported for this task"
type: integer

View File

@ -472,9 +472,11 @@ def update(call: APICall, company_id, _):
raise errors.bad_request.ModelCreatingTaskExists(
"and published, use force=True to delete", task=model.task
)
now = datetime.utcnow()
task.update(
output__model=deleted_model_id,
output__error=f"model deleted on {datetime.utcnow().isoformat()}",
output__error=f"model deleted on {now.isoformat()}",
last_change=now,
upsert=False,
)

View File

@ -69,6 +69,7 @@ from apiserver.bll.task.param_utils import (
params_unprepare_from_saved,
escape_paths,
)
from apiserver.bll.task.utils import update_task
from apiserver.bll.util import SetFieldsResolver
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
@ -439,7 +440,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
def prepare_update_fields(call: APICall, task, call_data):
valid_fields = deepcopy(task.__class__.user_set_allowed())
valid_fields = deepcopy(Task.user_set_allowed())
update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
update_fields["output__error"] = None
t_fields = task_fields
@ -467,7 +468,10 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
return UpdateResponse(updated=0)
updated_count, updated_fields = Task.safe_update(
company_id=company_id, id=task_id, partial_update_dict=partial_update_dict,
company_id=company_id,
id=task_id,
partial_update_dict=partial_update_dict,
injected_update=dict(last_change=datetime.utcnow()),
)
if updated_count:
new_project = updated_fields.get("project", task.project)
@ -500,8 +504,8 @@ def set_requirements(call: APICall, company_id, req_model: SetRequirementsReques
raise errors.bad_request.MissingTaskFields(
"Task has no script field", task=task.id
)
res = task.update(
script__requirements=requirements, last_update=datetime.utcnow()
res = update_task(
task, update_cmds=dict(script__requirements=requirements)
)
call.result.data_model = UpdateResponse(updated=res)
if res:
@ -537,7 +541,7 @@ def update_batch(call: APICall, company_id, _):
partial_update_dict = Task.get_safe_update_dict(fields)
if not partial_update_dict:
continue
partial_update_dict.update(last_update=now)
partial_update_dict.update(last_change=now)
update_op = UpdateOne(
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
)
@ -608,8 +612,11 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
}
if fixed_fields:
now = datetime.utcnow()
fields.update(last_update=now)
fixed_fields.update(last_update=now)
last_change = dict(last_change=now)
if not set(fields).issubset(Task.user_set_allowed()):
last_change.update(last_update=now)
fields.update(**last_change)
fixed_fields.update(**last_change)
updated = task.update(upsert=False, **fixed_fields)
if updated:
new_project = fixed_fields.get("project", task.project)
@ -920,6 +927,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
system_tags=sorted(
set(task.system_tags) | {EntityVisibility.archived.value}
),
last_change=datetime.utcnow(),
)
archived += 1

View File

@ -105,7 +105,7 @@ class TestTasksHyperparams(TestService):
)
# clone task
new_task = self.api.tasks.clone(task=task, new_hyperparams=new_params_dict).id
new_task = self.api.tasks.clone(task=task, new_task_hyperparams=new_params_dict).id
try:
res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0]
self.assertEqual(new_params, res.hyperparams)
@ -223,7 +223,7 @@ class TestTasksHyperparams(TestService):
self.assertEqual(old_config + new_config[:1], res.configuration)
# clone task
new_task = self.api.tasks.clone(task=task, new_configuration=new_config_dict).id
new_task = self.api.tasks.clone(task=task, new_task_configuration=new_config_dict).id
try:
res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0]
self.assertEqual(new_config, res.configuration)