Fix: move task to trash is not thread-safe

allegroai 2022-05-18 10:31:20 +03:00
parent e0cde2f7c9
commit 710443b078
2 changed files with 40 additions and 20 deletions

apiserver/bll/task/task_operations.py (View File)

@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Callable, Any, Tuple, Union
+from typing import Callable, Any, Tuple, Union, Sequence

 from apiserver.apierrors import errors, APIError
 from apiserver.bll.queue import QueueBLL
@@ -25,6 +25,7 @@ from apiserver.database.model.task.task import (
 )
 from apiserver.utilities.dicts import nested_set

+log = config.logger(__file__)

 queue_bll = QueueBLL()
@@ -83,10 +84,7 @@ def unarchive_task(


 def dequeue_task(
-    task_id: str,
-    company_id: str,
-    status_message: str,
-    status_reason: str,
+    task_id: str, company_id: str, status_message: str, status_reason: str,
 ) -> Tuple[int, dict]:
     query = dict(id=task_id, company=company_id)
     task = Task.get_for_writing(**query)
@@ -94,10 +92,7 @@ def dequeue_task(
         raise errors.bad_request.InvalidTaskId(**query)

     res = TaskBLL.dequeue_and_change_status(
-        task,
-        company_id,
-        status_message=status_message,
-        status_reason=status_reason,
+        task, company_id, status_message=status_message, status_reason=status_reason,
     )

     return 1, res
@@ -169,6 +164,30 @@ def enqueue_task(
     return 1, res


+def move_tasks_to_trash(tasks: Sequence[str]) -> int:
+    try:
+        collection_name = Task._get_collection_name()
+        trash_collection_name = f"{collection_name}__trash"
+        Task.aggregate(
+            [
+                {"$match": {"_id": {"$in": tasks}}},
+                {
+                    "$merge": {
+                        "into": trash_collection_name,
+                        "on": "_id",
+                        "whenMatched": "replace",
+                        "whenNotMatched": "insert",
+                    }
+                },
+            ],
+            allow_disk_use=True,
+        )
+    except Exception as ex:
+        log.error(f"Error copying tasks to trash {str(ex)}")
+
+    return Task.objects(id__in=tasks).delete()
+
+
 def delete_task(
     task_id: str,
     company_id: str,
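
Note on the helper added above: the copy runs entirely inside MongoDB via the $merge aggregation stage (available since MongoDB 4.2), which upserts each matched task document into the trash collection keyed on _id. Because no document state round-trips through the apiserver process, two threads trashing overlapping task sets can both run the pipeline safely. A minimal standalone sketch of the same pipeline in raw pymongo (the connection details and collection names here are illustrative assumptions; Task.aggregate in the diff appears to wrap the equivalent driver call):

    from pymongo import MongoClient

    # Assumed connection details; the real server reads these from its config.
    client = MongoClient("mongodb://localhost:27017")
    db = client["backend"]


    def copy_tasks_to_trash(task_ids):
        # $merge executes on the server: each matched document is written
        # into task__trash keyed on _id, replacing an existing copy or
        # inserting a new one. No document passes through the application.
        db["task"].aggregate(
            [
                {"$match": {"_id": {"$in": task_ids}}},
                {
                    "$merge": {
                        "into": "task__trash",
                        "on": "_id",
                        "whenMatched": "replace",
                        "whenNotMatched": "insert",
                    }
                },
            ],
            allowDiskUse=True,
        )

Note also that in move_tasks_to_trash the bulk delete still runs even if the copy failed (the error is only logged), mirroring the old code's intent of never letting trash bookkeeping block deletion.
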
@@ -214,18 +233,12 @@ def delete_task(
     )

     if move_to_trash:
-        collection_name = task._get_collection_name()
-        archived_collection = "{}__trash".format(collection_name)
-        task.switch_collection(archived_collection)
-        try:
-            # A simple save() won't do due to mongoengine caching (nothing will be saved), so we have to force
-            # an insert. However, if for some reason such an ID exists, let's make sure we'll keep going.
-            task.save(force_insert=True)
-        except Exception:
-            pass
-        task.switch_collection(collection_name)
-        task.delete()
+        # make sure that whatever changes were done to the task are saved
+        # the task itself will be deleted later in the move_tasks_to_trash operation
+        task.save()
+    else:
+        task.delete()

     update_project_time(task.project)

     return 1, task, cleanup_res
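
For context on why the replaced block was not thread-safe (our reading; the commit message does not spell it out): mongoengine resolves a document's target collection through state held on the Document class, and switch_collection temporarily re-points that state. Between switching to the trash collection and switching back, Task operations issued by other threads could resolve to the wrong collection. A simplified, self-contained illustration of the hazardous pattern using mongoengine's documented context manager (the Task model and database here are hypothetical stand-ins, not the real apiserver model):

    from mongoengine import Document, StringField, connect
    from mongoengine.context_managers import switch_collection

    connect("example_db")  # hypothetical database


    class Task(Document):  # simplified stand-in for the real Task document
        name = StringField()


    def unsafe_move_to_trash(task: Task):
        # The context manager patches Task's collection resolution at the
        # *class* level, so a Task.save() or Task.objects query issued by a
        # concurrent thread inside this window may hit task__trash instead.
        with switch_collection(Task, "task__trash"):
            task.save(force_insert=True)

The fix avoids collection switching entirely: the live document is saved in place, and the server-side $merge performs the copy, so no shared class state is ever mutated.
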

apiserver/services/tasks.py (View File)

@@ -94,6 +94,7 @@ from apiserver.bll.task.task_operations import (
     delete_task,
     publish_task,
     unarchive_task,
+    move_tasks_to_trash,
 )
 from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
 from apiserver.bll.util import SetFieldsResolver, run_batch_operation
@@ -1075,6 +1076,8 @@ def delete(call: APICall, company_id, request: DeleteRequest):
         status_reason=request.status_reason,
     )
     if deleted:
+        if request.move_to_trash:
+            move_tasks_to_trash([request.task])
         _reset_cached_tags(company_id, projects=[task.project] if task.project else [])
     call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
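
The ordering in this handler is deliberate: delete_task() now finishes by saving the task's final, cleaned-up state into the live collection, and only afterwards does move_tasks_to_trash copy that state into trash and remove the original. A condensed sketch of the control flow (the remaining delete_task parameters are elided; names follow the diff):

    deleted, task, cleanup_res = delete_task(
        task_id=request.task,
        company_id=company_id,
        move_to_trash=request.move_to_trash,
        status_reason=request.status_reason,
        # ...other parameters elided
    )
    if deleted and request.move_to_trash:
        # the trash copy reflects the post-cleanup state saved by delete_task
        move_tasks_to_trash([request.task])
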
@@ -1096,6 +1099,10 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
     )

     if results:
+        if request.move_to_trash:
+            task_ids = set(task.id for _, (_, task, _) in results)
+            if task_ids:
+                move_tasks_to_trash(list(task_ids))
         projects = set(task.project for _, (_, task, _) in results)
         _reset_cached_tags(company_id, projects=list(projects))
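
In the batch path, collecting every id before calling move_tasks_to_trash turns N per-task trash moves into a single bulk $merge plus one bulk delete. The unpacking assumes each entry of results is an (item, (deleted, task, cleanup)) pair as produced by run_batch_operation; a self-contained sketch of that assumed shape, with hypothetical stand-in objects:

    from types import SimpleNamespace

    # Hypothetical stand-ins for the Task documents returned by delete_task
    task_a = SimpleNamespace(id="id-a", project="project-1")
    task_b = SimpleNamespace(id="id-b", project=None)

    # Assumed shape: (request_item, (deleted_count, task_doc, cleanup_res))
    results = [("id-a", (1, task_a, {})), ("id-b", (1, task_b, {}))]

    task_ids = set(task.id for _, (_, task, _) in results)       # {"id-a", "id-b"}
    projects = set(task.project for _, (_, task, _) in results)  # {"project-1", None}
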