Task move forward/backwards in queue is now atomic

This commit is contained in:
allegroai 2023-05-25 19:16:33 +03:00
parent 5c5d9b6434
commit 2e4e060a82
3 changed files with 177 additions and 43 deletions

View File

@ -1,6 +1,6 @@
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from typing import Callable, Sequence, Optional, Tuple from typing import Sequence, Optional, Tuple, Union
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
from mongoengine import Q from mongoengine import Q
@ -16,6 +16,8 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.queue import Queue, Entry from apiserver.database.model.queue import Queue, Entry
log = config.logger(__file__) log = config.logger(__file__)
MOVE_FIRST = "first"
MOVE_LAST = "last"
class QueueBLL(object): class QueueBLL(object):
@ -323,43 +325,131 @@ class QueueBLL(object):
company_id: str, company_id: str,
queue_id: str, queue_id: str,
task_id: str, task_id: str,
pos_func: Callable[[int], int], move_count: Union[int, str],
) -> int: ) -> int:
""" """
Moves the task in the queue to the position calculated by pos_func Moves the task in the queue to the position calculated by pos_func
Returns the updated task position in the queue Returns the updated task position in the queue
""" """
with translate_errors_context(): def get_queue_and_task_position():
queue = self.get_queue_with_task( q = self.get_queue_with_task(
company_id=company_id, queue_id=queue_id, task_id=task_id company_id=company_id, queue_id=queue_id, task_id=task_id
) )
return q, next(i for i, e in enumerate(q.entries) if e.task == task_id)
position = next(i for i, e in enumerate(queue.entries) if e.task == task_id) with translate_errors_context():
new_position = pos_func(position) queue, position = get_queue_and_task_position()
if move_count == MOVE_FIRST:
if new_position != position: new_position = 0
entry = queue.entries[position] elif move_count == MOVE_LAST:
query = dict(id=queue_id, company=company_id) new_position = len(queue.entries) - 1
updated = Queue.objects(entries__task=task_id, **query).update_one( else:
pull__entries=entry, last_update=datetime.utcnow() new_position = position + move_count
) if new_position == position:
if not updated:
raise errors.bad_request.RemovedDuringReposition(
task=task_id, **query
)
inst = {"$push": {"entries": {"$each": [entry.to_proper_dict()]}}}
if new_position >= 0:
inst["$push"]["entries"]["$position"] = new_position
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
__raw__=inst
)
if not res:
raise errors.bad_request.FailedAddingDuringReposition(
task=task_id, **query
)
return new_position return new_position
without_entry = {
"$filter": {
"input": "$entries",
"as": "entry",
"cond": {"$ne": ["$$entry.task", task_id]},
}
}
task_entry = {
"$filter": {
"input": "$entries",
"as": "entry",
"cond": {"$eq": ["$$entry.task", task_id]},
}
}
if move_count == MOVE_FIRST:
operations = [
{
"$set": {
"entries": {"$concatArrays": [task_entry, without_entry]}
}
}
]
elif move_count == MOVE_LAST:
operations = [
{
"$set": {
"entries": {"$concatArrays": [without_entry, task_entry]}
}
}
]
else:
operations = [
{
"$set": {
"new_pos": {
"$add": [
{"$indexOfArray": ["$entries.task", task_id]},
move_count,
]
},
"without_entry": without_entry,
"task_entry": task_entry,
}
},
{
"$set": {
"entries": {
"$switch": {
"branches": [
{
"case": {"$lte": ["$new_pos", 0]},
"then": {
"$concatArrays": [
"$task_entry",
"$without_entry",
]
},
},
{
"case": {
"$gte": [
"$new_pos",
{"$size": "$without_entry"},
]
},
"then": {
"$concatArrays": [
"$without_entry",
"$task_entry",
]
},
},
],
"default": {
"$concatArrays": [
{"$slice": ["$without_entry", "$new_pos"]},
"$task_entry",
{
"$slice": [
"$without_entry",
"$new_pos",
{"$size": "$without_entry"},
]
},
]
},
}
}
}
},
{"$unset": ["new_pos", "without_entry", "task_entry"]},
]
updated = Queue.objects(
id=queue_id, company=company_id, entries__task=task_id
).update_one(__raw__=operations)
if not updated:
raise errors.bad_request.FailedAddingDuringReposition(task=task_id)
return get_queue_and_task_position()[1]
def count_entries(self, company: str, queue_id: str) -> Optional[int]: def count_entries(self, company: str, queue_id: str) -> Optional[int]:
res = next( res = next(
Queue.aggregate( Queue.aggregate(

View File

@ -21,6 +21,7 @@ from apiserver.apimodels.queues import (
) )
from apiserver.bll.model import Metadata from apiserver.bll.model import Metadata
from apiserver.bll.queue import QueueBLL from apiserver.bll.queue import QueueBLL
from apiserver.bll.queue.queue_bll import MOVE_FIRST, MOVE_LAST
from apiserver.bll.workers import WorkerBLL from apiserver.bll.workers import WorkerBLL
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.model.task.task import Task from apiserver.database.model.task.task import Task
@ -195,7 +196,7 @@ def move_task_forward(call: APICall, company_id, req_model: MoveTaskRequest):
company_id=company_id, company_id=company_id,
queue_id=req_model.queue, queue_id=req_model.queue,
task_id=req_model.task, task_id=req_model.task,
pos_func=lambda p: max(0, p - req_model.count), move_count=-req_model.count,
) )
) )
@ -212,7 +213,7 @@ def move_task_backward(call: APICall, company_id, req_model: MoveTaskRequest):
company_id=company_id, company_id=company_id,
queue_id=req_model.queue, queue_id=req_model.queue,
task_id=req_model.task, task_id=req_model.task,
pos_func=lambda p: max(0, p + req_model.count), move_count=req_model.count,
) )
) )
@ -229,7 +230,7 @@ def move_task_to_front(call: APICall, company_id, req_model: TaskRequest):
company_id=company_id, company_id=company_id,
queue_id=req_model.queue, queue_id=req_model.queue,
task_id=req_model.task, task_id=req_model.task,
pos_func=lambda p: 0, move_count=MOVE_FIRST,
) )
) )
@ -246,7 +247,7 @@ def move_task_to_back(call: APICall, company_id, req_model: TaskRequest):
company_id=company_id, company_id=company_id,
queue_id=req_model.queue, queue_id=req_model.queue,
task_id=req_model.task, task_id=req_model.task,
pos_func=lambda p: -1, move_count=MOVE_LAST,
) )
) )

View File

@ -138,27 +138,70 @@ class TestQueues(TestService):
queue = self._temp_queue("TestTempQueue") queue = self._temp_queue("TestTempQueue")
tasks = [ tasks = [
self._create_temp_queued_task(t, queue)["id"] self._create_temp_queued_task(t, queue)["id"]
for t in ("temp task1", "temp task2", "temp task3") for t in ("temp task1", "temp task2", "temp task3", "temp task4")
] ]
res = self.api.queues.get_by_id(queue=queue) res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, tasks) self.assertQueueTasks(res.queue, tasks)
new_pos = self.api.queues.move_task_backward( # no change in position
queue=queue, task=tasks[0], count=2 new_pos = self.api.queues.move_task_to_front(
).position queue=queue, task=tasks[0]
self.assertEqual(new_pos, 2)
res = self.api.queues.get_by_id(queue=queue)
changed_tasks = tasks[1:] + tasks[:1]
self.assertQueueTasks(res.queue, changed_tasks)
new_pos = self.api.queues.move_task_forward(
queue=queue, task=tasks[0], count=2
).position ).position
self.assertEqual(new_pos, 0) self.assertEqual(new_pos, 0)
res = self.api.queues.get_by_id(queue=queue) res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, tasks) self.assertQueueTasks(res.queue, tasks)
self.assertGetNextTasks(queue, tasks) # move backwards in the middle
new_pos = self.api.queues.move_task_backward(
queue=queue, task=tasks[0], count=2
).position
self.assertEqual(new_pos, 2)
res = self.api.queues.get_by_id(queue=queue)
changed_tasks = tasks[1:3] + [tasks[0], tasks[3]]
self.assertQueueTasks(res.queue, changed_tasks)
# move backwards beyond the end
new_pos = self.api.queues.move_task_backward(
queue=queue, task=tasks[0], count=100
).position
self.assertEqual(new_pos, 3)
res = self.api.queues.get_by_id(queue=queue)
changed_tasks = tasks[1:] + [tasks[0]]
self.assertQueueTasks(res.queue, changed_tasks)
# move forwards in the middle
new_pos = self.api.queues.move_task_forward(
queue=queue, task=tasks[0], count=2
).position
self.assertEqual(new_pos, 1)
res = self.api.queues.get_by_id(queue=queue)
changed_tasks = [tasks[1], tasks[0]] + tasks[2:]
self.assertQueueTasks(res.queue, changed_tasks)
# move forwards beyond the beginning
new_pos = self.api.queues.move_task_forward(
queue=queue, task=tasks[0], count=100
).position
self.assertEqual(new_pos, 0)
res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, tasks)
# move to back
new_pos = self.api.queues.move_task_to_back(
queue=queue, task=tasks[0]
).position
self.assertEqual(new_pos, 3)
res = self.api.queues.get_by_id(queue=queue)
changed_tasks = tasks[1:] + [tasks[0]]
self.assertQueueTasks(res.queue, changed_tasks)
# move to front
new_pos = self.api.queues.move_task_to_front(
queue=queue, task=tasks[0]
).position
self.assertEqual(new_pos, 0)
res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, tasks)
def test_get_all_ex(self): def test_get_all_ex(self):
queue_name = "TestTempQueue1" queue_name = "TestTempQueue1"