Add last_change and last_changed_by to DB Model

This commit is contained in:
allegroai 2023-11-17 09:35:22 +02:00
parent 6d507616b3
commit d62ecb5e6e
15 changed files with 231 additions and 67 deletions

View File

@ -130,7 +130,11 @@ class EventBLL(object):
return res
def add_events(
self, company_id, events, worker
self,
company_id: str,
user_id: str,
events: Sequence[dict],
worker: str,
) -> Tuple[int, int, dict]:
task_ids = {}
model_ids = {}
@ -311,21 +315,24 @@ class EventBLL(object):
else:
errors_per_type["Error when indexing events batch"] += 1
now = datetime.utcnow()
for model_id in used_model_ids:
ModelBLL.update_statistics(
company_id=company_id,
user_id=user_id,
model_id=model_id,
last_update=now,
last_iteration_max=task_iteration.get(model_id),
last_scalar_events=task_last_scalar_events.get(model_id),
)
remaining_tasks = set()
now = datetime.utcnow()
for task_id in used_task_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those whose events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
user_id=user_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
@ -336,7 +343,12 @@ class EventBLL(object):
continue
if remaining_tasks:
TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
TaskBLL.set_last_update(
remaining_tasks,
company_id=company_id,
user_id=user_id,
last_update=now,
)
# this is for backwards compatibility with streaming bulk throwing exception on those
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
@ -466,9 +478,10 @@ class EventBLL(object):
def _update_task(
self,
company_id,
task_id,
now,
company_id: str,
user_id: str,
task_id: str,
now: datetime,
iter_max=None,
last_scalar_events=None,
last_events=None,
@ -484,8 +497,9 @@ class EventBLL(object):
return False
return TaskBLL.update_statistics(
task_id,
company_id,
task_id=task_id,
company_id=company_id,
user_id=user_id,
last_update=now,
last_iteration_max=iter_max,
last_scalar_events=last_scalar_events,

View File

@ -80,7 +80,14 @@ class ModelBLL:
id=model.task, data=task_publish_res
)
updated = model.update(upsert=False, ready=True, last_update=datetime.utcnow())
now = datetime.utcnow()
updated = model.update(
upsert=False,
ready=True,
last_update=now,
last_change=now,
last_changed_by=user_id,
)
return updated, published_task
@classmethod
@ -125,6 +132,7 @@ class ModelBLL:
"models.output.$[elem].model": deleted_model_id,
"output.error": f"model deleted on {now.isoformat()}",
"last_change": now,
"last_changed_by": user_id,
},
},
array_filters=[{"elem.model": model_id}],
@ -132,7 +140,9 @@ class ModelBLL:
)
else:
task.update(
pull__models__output__model=model_id, set__last_change=now
pull__models__output__model=model_id,
set__last_change=now,
set__last_changed_by=user_id,
)
delete_external_artifacts = delete_external_artifacts and config.get(
@ -167,25 +177,29 @@ class ModelBLL:
return del_count, model
@classmethod
def archive_model(cls, model_id: str, company_id: str):
def archive_model(cls, model_id: str, company_id: str, user_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
now = datetime.utcnow()
archived = Model.objects(company=company_id, id=model_id).update(
add_to_set__system_tags=EntityVisibility.archived.value,
last_update=datetime.utcnow(),
last_change=now,
last_changed_by=user_id,
)
return archived
@classmethod
def unarchive_model(cls, model_id: str, company_id: str):
def unarchive_model(cls, model_id: str, company_id: str, user_id: str):
cls.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
now = datetime.utcnow()
unarchived = Model.objects(company=company_id, id=model_id).update(
pull__system_tags=EntityVisibility.archived.value,
last_update=datetime.utcnow(),
last_change=now,
last_changed_by=user_id,
)
return unarchived
@ -218,11 +232,18 @@ class ModelBLL:
@staticmethod
def update_statistics(
company_id: str,
user_id: str,
model_id: str,
last_update: datetime = None,
last_iteration_max: int = None,
last_scalar_events: Dict[str, Dict[str, dict]] = None,
):
updates = {"last_update": datetime.utcnow()}
last_update = last_update or datetime.utcnow()
updates = {
"last_update": datetime.utcnow(),
"last_change": last_update,
"last_changed_by": user_id,
}
if last_iteration_max is not None:
updates.update(max__last_iteration=last_iteration_max)

View File

@ -315,11 +315,12 @@ class ProjectBLL:
description="",
)
extra = (
{"set__last_change": datetime.utcnow()}
if hasattr(entity_cls, "last_change")
else {}
)
extra = {}
if hasattr(entity_cls, "last_change"):
extra["set__last_change"] = datetime.utcnow()
if hasattr(entity_cls, "last_changed_by"):
extra["set__last_changed_by"] = user
entity_cls.objects(company=company, id__in=ids).update(
set__project=project, **extra
)

View File

@ -185,10 +185,10 @@ def delete_project(
res = DeleteProjectResult(disassociated_tasks=disassociated[Task])
else:
deleted_models, model_event_urls, model_urls = _delete_models(
company=company, projects=project_ids
company=company, user=user, projects=project_ids
)
deleted_tasks, task_event_urls, artifact_urls = _delete_tasks(
company=company, projects=project_ids
company=company, user=user, projects=project_ids
)
event_urls = task_event_urls | model_event_urls
if delete_external_artifacts:
@ -217,7 +217,7 @@ def delete_project(
return res, affected
def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]:
def _delete_tasks(company: str, user: str, projects: Sequence[str]) -> Tuple[int, Set, Set]:
"""
Delete only the task themselves and their non published version.
Child models under the same project are deleted separately.
@ -229,8 +229,17 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
return 0, set(), set()
task_ids = {t.id for t in tasks}
Task.objects(parent__in=task_ids, project__nin=projects).update(parent=None)
Model.objects(task__in=task_ids, project__nin=projects).update(task=None)
now = datetime.utcnow()
Task.objects(parent__in=task_ids, project__nin=projects).update(
parent=None,
last_change=now,
last_changed_by=user,
)
Model.objects(task__in=task_ids, project__nin=projects).update(
task=None,
last_change=now,
last_changed_by=user,
)
event_urls, artifact_urls = set(), set()
for task in tasks:
@ -253,7 +262,7 @@ def _delete_tasks(company: str, projects: Sequence[str]) -> Tuple[int, Set, Set]
def _delete_models(
company: str, projects: Sequence[str]
company: str, user: str, projects: Sequence[str]
) -> Tuple[int, Set[str], Set[str]]:
"""
Delete project models and update the tasks from other projects
@ -287,7 +296,11 @@ def _delete_models(
"status": TaskStatus.published,
},
update={
"$set": {"models.output.$[elem].model": deleted, "last_change": now,}
"$set": {
"models.output.$[elem].model": deleted,
"last_change": now,
"last_changed_by": user,
}
},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
@ -295,7 +308,11 @@ def _delete_models(
# update unpublished tasks
Task.objects(
id__in=model_tasks, project__nin=projects, status__ne=TaskStatus.published,
).update(pull__models__output__model__in=model_ids, set__last_change=now)
).update(
pull__models__output__model__in=model_ids,
set__last_change=now,
set__last_changed_by=user,
)
event_urls, model_urls = set(), set()
for m in models:

View File

@ -85,6 +85,7 @@ class NonResponsiveTasksWatchdog:
status_changed=now,
last_update=now,
last_change=now,
last_changed_by="__apiserver__",
)
if updated:
project_ids.add(task.project)

View File

@ -356,6 +356,7 @@ class TaskBLL:
def set_last_update(
task_ids: Collection[str],
company_id: str,
user_id: str,
last_update: datetime,
**extra_updates,
):
@ -376,6 +377,7 @@ class TaskBLL:
upsert=False,
last_update=last_update,
last_change=last_update,
last_changed_by=user_id,
**updates,
)
return count
@ -384,6 +386,7 @@ class TaskBLL:
def update_statistics(
task_id: str,
company_id: str,
user_id: str,
last_update: datetime = None,
last_iteration: int = None,
last_iteration_max: int = None,
@ -440,6 +443,7 @@ class TaskBLL:
ret = TaskBLL.set_last_update(
task_ids=[task_id],
company_id=company_id,
user_id=user_id,
last_update=last_update,
**extra_updates,
)

View File

@ -222,8 +222,13 @@ def cleanup_task(
deleted_task_id = f"{deleted_prefix}{task.id}"
updated_children = 0
now = datetime.utcnow()
if update_children:
updated_children = Task.objects(parent=task.id).update(parent=deleted_task_id)
updated_children = Task.objects(parent=task.id).update(
parent=deleted_task_id,
last_change=now,
last_changed_by=user,
)
deleted_models = 0
updated_models = 0
@ -249,15 +254,25 @@ def cleanup_task(
deleted_models += Model.objects(id__in=list(model_ids)).delete()
if in_use_model_ids:
Model.objects(id__in=list(in_use_model_ids)).update(unset__task=1)
Model.objects(id__in=list(in_use_model_ids)).update(
unset__task=1,
set__last_change=now,
set__last_changed_by=user,
)
continue
if update_children:
updated_models += Model.objects(id__in=[m.id for m in models]).update(
task=deleted_task_id
task=deleted_task_id,
last_change=now,
last_changed_by=user,
)
else:
Model.objects(id__in=[m.id for m in models]).update(unset__task=1)
Model.objects(id__in=[m.id for m in models]).update(
unset__task=1,
set__last_change=now,
set__last_changed_by=user,
)
event_bll.delete_task_events(
task.company, task.id, allow_locked=force, async_delete=async_events_delete

View File

@ -175,6 +175,7 @@ class WorkerBLL:
last_worker_report=now,
last_update=now,
last_change=now,
last_changed_by=user_id,
)
# modify(new=True, ...) returns the modified object
task = Task.objects(**query).modify(new=True, **update)

View File

@ -1,5 +1,6 @@
import re
from collections import namedtuple, defaultdict
from datetime import datetime
from functools import reduce, partial
from typing import (
Collection,
@ -147,8 +148,7 @@ class GetMixin(PropsMixin):
}
default_operator = Q.OR
mongo_modifiers = {
# not_all modifier currently not supported due to the backwards compatibility
Q.AND: {True: "all", False: "nin"},
Q.AND: {True: "all", False: "not__all"},
Q.OR: {True: "in", False: "nin"},
}
@ -1234,25 +1234,31 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
def set_public(
cls: Type[Document],
company_id: str,
user_id: str,
ids: Sequence[str],
invalid_cls: Type[BaseError],
enabled: bool = True,
):
if enabled:
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
update = dict(set__company_origin=company_id, set__company="")
update: dict = dict(set__company_origin=company_id, set__company="")
else:
items = list(
cls.objects(
id__in=ids, company__in=(None, ""), company_origin=company_id
).only("id")
)
update = dict(set__company=company_id, unset__company_origin=1)
update: dict = dict(set__company=company_id, unset__company_origin=1)
if len(items) < len(ids):
missing = tuple(set(ids).difference(i.id for i in items))
raise invalid_cls(ids=missing)
if hasattr(cls, "last_change"):
update["set__last_change"] = datetime.utcnow()
if hasattr(cls, "last_changed_by"):
update["set__last_changed_by"] = user_id
return {"updated": cls.objects(id__in=ids).update(**update)}

View File

@ -90,6 +90,8 @@ class Model(AttributedDocument):
labels = ModelLabels()
ready = BooleanField(required=True)
last_update = DateTimeField()
last_change = DateTimeField()
last_changed_by = StringField()
ui_cache = SafeDictField(
default=dict, user_set_allowed=True, exclude_by_default=True
)

View File

@ -960,7 +960,7 @@ class PrePopulate:
return tasks
@classmethod
def _import_events(cls, f: IO[bytes], company_id: str, _, task_id: str):
def _import_events(cls, f: IO[bytes], company_id: str, user_id: str, task_id: str):
print(f"Writing events for task {task_id} into database")
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
events = [json.loads(item) for item in events_chunk]
@ -969,5 +969,8 @@ class PrePopulate:
ev["company_id"] = company_id
ev["allow_locked"] = True
cls.event_bll.add_events(
company_id, events=events, worker=""
company_id=company_id,
user_id=user_id,
events=events,
worker="",
)

View File

@ -71,7 +71,12 @@ def _assert_task_or_model_exists(
@endpoint("events.add")
def add(call: APICall, company_id, _):
data = call.data.copy()
added, err_count, err_info = event_bll.add_events(company_id, [data], call.worker)
added, err_count, err_info = event_bll.add_events(
company_id=company_id,
user_id=call.identity.user,
events=[data],
worker=call.worker,
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@ -82,9 +87,10 @@ def add_batch(call: APICall, company_id, _):
raise errors.bad_request.BatchContainsNoItems()
added, err_count, err_info = event_bll.add_events(
company_id,
events,
call.worker,
company_id=company_id,
user_id=call.identity.user,
events=events,
worker=call.worker,
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)

View File

@ -76,7 +76,9 @@ def get_by_id(call: APICall, company_id, _):
)
if not models:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
"no such public or company model",
id=model_id,
company=company_id,
)
conform_model_data(call, models[0])
call.result.data = {"model": models[0]}
@ -102,7 +104,9 @@ def get_by_task_id(call: APICall, company_id, _):
).first()
if not model:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
"no such public or company model",
id=model_id,
company=company_id,
)
model_dict = model.to_proper_dict()
conform_model_data(call, model_dict)
@ -128,7 +132,10 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
return
model_ids = {model["id"] for model in models}
stats = ModelBLL.get_model_stats(company=company_id, model_ids=list(model_ids),)
stats = ModelBLL.get_model_stats(
company=company_id,
model_ids=list(model_ids),
)
for model in models:
model["stats"] = stats.get(model["id"])
@ -220,7 +227,9 @@ def _update_cached_tags(company: str, project: str, fields: dict):
def _reset_cached_tags(company: str, projects: Sequence[str]):
org_bll.reset_tags(
company, Tags.Model, projects=projects,
company,
Tags.Model,
projects=projects,
)
@ -283,6 +292,8 @@ def update_for_task(call: APICall, company_id, _):
id=database.utils.id(),
created=now,
last_update=now,
last_change=now,
last_changed_by=call.identity.user,
user=call.identity.user,
company=company_id,
project=task.project,
@ -301,6 +312,7 @@ def update_for_task(call: APICall, company_id, _):
TaskBLL.update_statistics(
task_id=task_id,
company_id=company_id,
user_id=call.identity.user,
last_iteration_max=iteration,
models__output=[
ModelItem(
@ -320,7 +332,6 @@ def update_for_task(call: APICall, company_id, _):
response_data_model=CreateModelResponse,
)
def create(call: APICall, company_id, req_model: CreateModelRequest):
if req_model.public:
company_id = ""
@ -345,6 +356,8 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
company=company_id,
created=now,
last_update=now,
last_change=now,
last_changed_by=call.identity.user,
**fields,
)
model.save()
@ -414,12 +427,20 @@ def edit(call: APICall, company_id, _):
task_id = model.task or fields.get("task")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
task_id=task_id,
company_id=company_id,
user_id=call.identity.user,
last_iteration_max=iteration,
)
if fields:
now = datetime.utcnow()
fields.update(
last_change=now,
last_changed_by=call.identity.user,
)
if any(uf in fields for uf in last_update_fields):
fields.update(last_update=datetime.utcnow())
fields.update(last_update=now)
updated = model.update(upsert=False, **fields)
if updated:
@ -445,13 +466,25 @@ def _update_model(call: APICall, company_id, model_id=None):
iteration = data.get("iteration")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
task_id=task_id,
company_id=company_id,
user_id=call.identity.user,
last_iteration_max=iteration,
)
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
now = datetime.utcnow()
updated_count, updated_fields = Model.safe_update(
company_id,
model.id,
data,
injected_update=dict(
last_change=now,
last_changed_by=call.identity.user,
),
)
if updated_count:
if any(uf in updated_fields for uf in last_update_fields):
model.update(upsert=False, last_update=datetime.utcnow())
model.update(upsert=False, last_update=now)
new_project = updated_fields.get("project", model.project)
if new_project != model.project:
@ -573,7 +606,10 @@ def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
)
def archive_many(call: APICall, company_id, request: BatchRequest):
results, failures = run_batch_operation(
func=partial(ModelBLL.archive_model, company_id=company_id), ids=request.ids,
func=partial(
ModelBLL.archive_model, company_id=company_id, user_id=call.identity.user
),
ids=request.ids,
)
call.result.data_model = BatchResponse(
succeeded=[dict(id=_id, archived=bool(archived)) for _id, archived in results],
@ -588,7 +624,8 @@ def archive_many(call: APICall, company_id, request: BatchRequest):
)
def unarchive_many(call: APICall, company_id, request: BatchRequest):
results, failures = run_batch_operation(
func=partial(ModelBLL.unarchive_model, company_id=company_id), ids=request.ids,
func=partial(ModelBLL.unarchive_model, company_id=company_id, user_id=call.identity.user),
ids=request.ids,
)
call.result.data_model = BatchResponse(
succeeded=[
@ -601,7 +638,11 @@ def unarchive_many(call: APICall, company_id, request: BatchRequest):
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Model.set_public(
company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidModelId,
enabled=True,
)
@ -610,7 +651,11 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Model.set_public(
company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidModelId,
enabled=False,
)
@ -635,28 +680,36 @@ def move(call: APICall, company_id: str, request: MoveRequest):
@endpoint("models.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
call: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
model_id = request.model
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
now = datetime.utcnow()
return {
"updated": Metadata.edit_metadata(
model,
items=request.metadata,
replace_metadata=request.replace_metadata,
last_update=datetime.utcnow(),
last_update=now,
last_change=now,
last_changed_by=call.identity.user,
)
}
@endpoint("models.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataRequest):
model_id = request.model
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
now = datetime.utcnow()
return {
"updated": Metadata.delete_metadata(
model, keys=request.keys, last_update=datetime.utcnow()
model,
keys=request.keys,
last_update=now,
last_change=now,
last_changed_by=call.identity.user,
)
}

View File

@ -506,7 +506,11 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Project.set_public(
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=True
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidProjectId,
enabled=True,
)
@ -515,7 +519,11 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Project.set_public(
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidProjectId,
enabled=False,
)

View File

@ -1232,9 +1232,12 @@ def completed(call: APICall, company_id, request: CompletedRequest):
@endpoint("tasks.ping", request_data_model=PingRequest)
def ping(_, company_id, request: PingRequest):
def ping(call: APICall, company_id, request: PingRequest):
TaskBLL.set_last_update(
task_ids=[request.task], company_id=company_id, last_update=datetime.utcnow()
task_ids=[request.task],
company_id=company_id,
user_id=call.identity.user,
last_update=datetime.utcnow(),
)
@ -1277,14 +1280,22 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidTaskId,
enabled=True,
)
@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
company_id=company_id,
user_id=call.identity.user,
ids=request.ids,
invalid_cls=InvalidTaskId,
enabled=False,
)
@ -1315,7 +1326,7 @@ def move(call: APICall, company_id: str, request: MoveRequest):
@endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelRequest):
get_task_for_update(company_id=company_id, task_id=request.task, force=True)
models_field = f"models__{request.type}"
@ -1326,6 +1337,7 @@ def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequ
updated = TaskBLL.update_statistics(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
last_iteration_max=request.iteration,
**({f"push__{models_field}": model} if not updated else {}),
)