mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 10:56:48 +00:00
Fix task and model last_change handling
Improve db model index Improve db model infrastructure
This commit is contained in:
parent
29c792d459
commit
59994ccf9c
@ -82,7 +82,7 @@ class DictField(fields.BaseField):
|
||||
"""Cast value to proper collection."""
|
||||
result = self.get_default_value()
|
||||
|
||||
if not values:
|
||||
if values is None:
|
||||
return result
|
||||
|
||||
if not self.value_types or not isinstance(values, dict):
|
||||
|
@ -127,7 +127,11 @@ class ProjectBLL:
|
||||
project_name=project_name,
|
||||
description="Auto-generated during move",
|
||||
)
|
||||
|
||||
entity_cls.objects(company=company, id__in=ids).update(set__project=project)
|
||||
extra = (
|
||||
{"set__last_change": datetime.utcnow()}
|
||||
if hasattr(entity_cls, "last_change")
|
||||
else {}
|
||||
)
|
||||
entity_cls.objects(company=company, id__in=ids).update(set__project=project, **extra)
|
||||
|
||||
return project
|
||||
|
@ -1,10 +1,9 @@
|
||||
from datetime import datetime
|
||||
from hashlib import md5
|
||||
from operator import itemgetter
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
|
||||
from apiserver.bll.task.utils import get_task_for_update
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
@ -70,7 +69,7 @@ class Artifacts:
|
||||
f"set__execution__artifacts__{mongoengine_safe(name)}": value
|
||||
for name, value in artifacts.items()
|
||||
}
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
return update_task(task, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_artifacts(
|
||||
@ -95,4 +94,4 @@ class Artifacts:
|
||||
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
|
||||
}
|
||||
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
return update_task(task, update_cmds=delete_cmds)
|
||||
|
@ -1,4 +1,3 @@
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Dict
|
||||
@ -13,7 +12,7 @@ from apiserver.apimodels.tasks import (
|
||||
Configuration,
|
||||
)
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.utils import get_task_for_update
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
|
||||
from apiserver.timing_context import TimingContext
|
||||
@ -96,7 +95,9 @@ class HyperParams:
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
return update_task(
|
||||
task, update_cmds=delete_cmds, set_last_update=not properties_only
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def edit_params(
|
||||
@ -132,7 +133,9 @@ class HyperParams:
|
||||
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
|
||||
] = value
|
||||
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
return update_task(
|
||||
task, update_cmds=update_cmds, set_last_update=not properties_only
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
|
||||
@ -223,7 +226,7 @@ class HyperParams:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
|
||||
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
return update_task(task, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_configuration(
|
||||
@ -239,4 +242,4 @@ class HyperParams:
|
||||
for name in set(configuration)
|
||||
}
|
||||
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
return update_task(task, update_cmds=delete_cmds)
|
||||
|
@ -145,6 +145,7 @@ class TaskBLL:
|
||||
company=identity.company,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
**fields,
|
||||
)
|
||||
|
||||
@ -237,6 +238,7 @@ class TaskBLL:
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or task.parent,
|
||||
@ -367,7 +369,10 @@ class TaskBLL:
|
||||
**extra_updates,
|
||||
}
|
||||
Task.objects(id=task.id, company=company_id).update(
|
||||
upsert=False, last_update=last_update, **updates
|
||||
upsert=False,
|
||||
last_update=last_update,
|
||||
last_change=last_update,
|
||||
**updates,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -653,11 +658,7 @@ class TaskBLL:
|
||||
|
||||
@classmethod
|
||||
def dequeue_and_change_status(
|
||||
cls,
|
||||
task: Task,
|
||||
company_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
cls, task: Task, company_id: str, status_message: str, status_reason: str,
|
||||
):
|
||||
cls.dequeue(task, company_id)
|
||||
|
||||
|
@ -43,6 +43,7 @@ class ChangeStatusRequest(object):
|
||||
status_message=self.status_message,
|
||||
status_changed=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
)
|
||||
|
||||
if self.new_status == TaskStatus.queued:
|
||||
@ -194,3 +195,11 @@ def get_task_for_update(
|
||||
expected=TaskStatus.created, status=task.status
|
||||
)
|
||||
return task
|
||||
|
||||
|
||||
def update_task(task: Task, update_cmds: dict, set_last_update: bool = True):
|
||||
now = datetime.utcnow()
|
||||
last_updates = dict(last_change=now)
|
||||
if set_last_update:
|
||||
last_updates.update(last_update=now)
|
||||
return task.update(**update_cmds, **last_updates)
|
||||
|
@ -164,6 +164,7 @@ class WorkerBLL:
|
||||
last_worker=report.worker,
|
||||
last_worker_report=now,
|
||||
last_update=now,
|
||||
last_change=now,
|
||||
)
|
||||
# modify(new=True, ...) returns the modified object
|
||||
task = Task.objects(**query).modify(new=True, **update)
|
||||
|
@ -52,7 +52,7 @@ class Credentials(EmbeddedDocument):
|
||||
|
||||
|
||||
class User(DbModelMixin, AuthDocument):
|
||||
meta = {"db_alias": Database.auth, "strict": strict}
|
||||
meta = {"db_alias": Database.auth, "strict": strict, "indexes": ["email"]}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StringField()
|
||||
|
@ -701,14 +701,24 @@ class GetMixin(PropsMixin):
|
||||
|
||||
|
||||
class UpdateMixin(object):
|
||||
__user_set_allowed_fields = None
|
||||
__locked_when_published_fields = None
|
||||
|
||||
@classmethod
|
||||
def user_set_allowed(cls):
|
||||
res = getattr(cls, "__user_set_allowed_fields", None)
|
||||
if res is None:
|
||||
res = cls.__user_set_allowed_fields = get_fields_choices(
|
||||
cls, "user_set_allowed"
|
||||
if cls.__user_set_allowed_fields is None:
|
||||
cls.__user_set_allowed_fields = dict(
|
||||
get_fields_choices(cls, "user_set_allowed")
|
||||
)
|
||||
return res
|
||||
return cls.__user_set_allowed_fields
|
||||
|
||||
@classmethod
|
||||
def locked_when_published(cls):
|
||||
if cls.__locked_when_published_fields is None:
|
||||
cls.__locked_when_published_fields = dict(
|
||||
get_fields_choices(cls, "locked_when_published")
|
||||
)
|
||||
return cls.__locked_when_published_fields
|
||||
|
||||
@classmethod
|
||||
def get_safe_update_dict(cls, fields):
|
||||
|
@ -155,6 +155,8 @@ class Task(AttributedDocument):
|
||||
"project",
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
("company", "status", "type"),
|
||||
("company", "system_tags", "last_update"),
|
||||
("company", "type", "system_tags", "status"),
|
||||
("company", "project", "type", "system_tags", "status"),
|
||||
("status", "last_update"), # for maintenance tasks
|
||||
@ -215,6 +217,7 @@ class Task(AttributedDocument):
|
||||
last_worker = StringField()
|
||||
last_worker_report = DateTimeField()
|
||||
last_update = DateTimeField()
|
||||
last_change = DateTimeField()
|
||||
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
|
||||
|
@ -501,6 +501,11 @@ _definitions {
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_change {
|
||||
description: "Last time any update was done to the task"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_iteration {
|
||||
description: "Last iteration reported for this task"
|
||||
type: integer
|
||||
|
@ -472,9 +472,11 @@ def update(call: APICall, company_id, _):
|
||||
raise errors.bad_request.ModelCreatingTaskExists(
|
||||
"and published, use force=True to delete", task=model.task
|
||||
)
|
||||
now = datetime.utcnow()
|
||||
task.update(
|
||||
output__model=deleted_model_id,
|
||||
output__error=f"model deleted on {datetime.utcnow().isoformat()}",
|
||||
output__error=f"model deleted on {now.isoformat()}",
|
||||
last_change=now,
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
|
@ -69,6 +69,7 @@ from apiserver.bll.task.param_utils import (
|
||||
params_unprepare_from_saved,
|
||||
escape_paths,
|
||||
)
|
||||
from apiserver.bll.task.utils import update_task
|
||||
from apiserver.bll.util import SetFieldsResolver
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import EntityVisibility
|
||||
@ -439,7 +440,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
|
||||
|
||||
|
||||
def prepare_update_fields(call: APICall, task, call_data):
|
||||
valid_fields = deepcopy(task.__class__.user_set_allowed())
|
||||
valid_fields = deepcopy(Task.user_set_allowed())
|
||||
update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
|
||||
update_fields["output__error"] = None
|
||||
t_fields = task_fields
|
||||
@ -467,7 +468,10 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
return UpdateResponse(updated=0)
|
||||
|
||||
updated_count, updated_fields = Task.safe_update(
|
||||
company_id=company_id, id=task_id, partial_update_dict=partial_update_dict,
|
||||
company_id=company_id,
|
||||
id=task_id,
|
||||
partial_update_dict=partial_update_dict,
|
||||
injected_update=dict(last_change=datetime.utcnow()),
|
||||
)
|
||||
if updated_count:
|
||||
new_project = updated_fields.get("project", task.project)
|
||||
@ -500,8 +504,8 @@ def set_requirements(call: APICall, company_id, req_model: SetRequirementsReques
|
||||
raise errors.bad_request.MissingTaskFields(
|
||||
"Task has no script field", task=task.id
|
||||
)
|
||||
res = task.update(
|
||||
script__requirements=requirements, last_update=datetime.utcnow()
|
||||
res = update_task(
|
||||
task, update_cmds=dict(script__requirements=requirements)
|
||||
)
|
||||
call.result.data_model = UpdateResponse(updated=res)
|
||||
if res:
|
||||
@ -537,7 +541,7 @@ def update_batch(call: APICall, company_id, _):
|
||||
partial_update_dict = Task.get_safe_update_dict(fields)
|
||||
if not partial_update_dict:
|
||||
continue
|
||||
partial_update_dict.update(last_update=now)
|
||||
partial_update_dict.update(last_change=now)
|
||||
update_op = UpdateOne(
|
||||
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
|
||||
)
|
||||
@ -608,8 +612,11 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
}
|
||||
if fixed_fields:
|
||||
now = datetime.utcnow()
|
||||
fields.update(last_update=now)
|
||||
fixed_fields.update(last_update=now)
|
||||
last_change = dict(last_change=now)
|
||||
if not set(fields).issubset(Task.user_set_allowed()):
|
||||
last_change.update(last_update=now)
|
||||
fields.update(**last_change)
|
||||
fixed_fields.update(**last_change)
|
||||
updated = task.update(upsert=False, **fixed_fields)
|
||||
if updated:
|
||||
new_project = fixed_fields.get("project", task.project)
|
||||
@ -920,6 +927,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
|
||||
system_tags=sorted(
|
||||
set(task.system_tags) | {EntityVisibility.archived.value}
|
||||
),
|
||||
last_change=datetime.utcnow(),
|
||||
)
|
||||
|
||||
archived += 1
|
||||
|
@ -105,7 +105,7 @@ class TestTasksHyperparams(TestService):
|
||||
)
|
||||
|
||||
# clone task
|
||||
new_task = self.api.tasks.clone(task=task, new_hyperparams=new_params_dict).id
|
||||
new_task = self.api.tasks.clone(task=task, new_task_hyperparams=new_params_dict).id
|
||||
try:
|
||||
res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0]
|
||||
self.assertEqual(new_params, res.hyperparams)
|
||||
@ -223,7 +223,7 @@ class TestTasksHyperparams(TestService):
|
||||
self.assertEqual(old_config + new_config[:1], res.configuration)
|
||||
|
||||
# clone task
|
||||
new_task = self.api.tasks.clone(task=task, new_configuration=new_config_dict).id
|
||||
new_task = self.api.tasks.clone(task=task, new_task_configuration=new_config_dict).id
|
||||
try:
|
||||
res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0]
|
||||
self.assertEqual(new_config, res.configuration)
|
||||
|
Loading…
Reference in New Issue
Block a user