Task move forward/backwards in queue is now atomic

This commit is contained in:
allegroai 2023-05-25 19:16:33 +03:00
parent 5c5d9b6434
commit 2e4e060a82
3 changed files with 177 additions and 43 deletions

View File

@ -1,6 +1,6 @@
from collections import defaultdict
from datetime import datetime
from typing import Callable, Sequence, Optional, Tuple
from typing import Sequence, Optional, Tuple, Union
from elasticsearch import Elasticsearch
from mongoengine import Q
@ -16,6 +16,8 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.queue import Queue, Entry
log = config.logger(__file__)
MOVE_FIRST = "first"
MOVE_LAST = "last"
class QueueBLL(object):
@ -323,42 +325,130 @@ class QueueBLL(object):
company_id: str,
queue_id: str,
task_id: str,
pos_func: Callable[[int], int],
move_count: Union[int, str],
) -> int:
"""
Moves the task in the queue to the position calculated by pos_func
Returns the updated task position in the queue
"""
with translate_errors_context():
queue = self.get_queue_with_task(
def get_queue_and_task_position():
q = self.get_queue_with_task(
company_id=company_id, queue_id=queue_id, task_id=task_id
)
return q, next(i for i, e in enumerate(q.entries) if e.task == task_id)
position = next(i for i, e in enumerate(queue.entries) if e.task == task_id)
new_position = pos_func(position)
with translate_errors_context():
queue, position = get_queue_and_task_position()
if move_count == MOVE_FIRST:
new_position = 0
elif move_count == MOVE_LAST:
new_position = len(queue.entries) - 1
else:
new_position = position + move_count
if new_position == position:
return new_position
if new_position != position:
entry = queue.entries[position]
query = dict(id=queue_id, company=company_id)
updated = Queue.objects(entries__task=task_id, **query).update_one(
pull__entries=entry, last_update=datetime.utcnow()
)
if not updated:
raise errors.bad_request.RemovedDuringReposition(
task=task_id, **query
)
inst = {"$push": {"entries": {"$each": [entry.to_proper_dict()]}}}
if new_position >= 0:
inst["$push"]["entries"]["$position"] = new_position
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
__raw__=inst
)
if not res:
raise errors.bad_request.FailedAddingDuringReposition(
task=task_id, **query
)
without_entry = {
"$filter": {
"input": "$entries",
"as": "entry",
"cond": {"$ne": ["$$entry.task", task_id]},
}
}
task_entry = {
"$filter": {
"input": "$entries",
"as": "entry",
"cond": {"$eq": ["$$entry.task", task_id]},
}
}
if move_count == MOVE_FIRST:
operations = [
{
"$set": {
"entries": {"$concatArrays": [task_entry, without_entry]}
}
}
]
elif move_count == MOVE_LAST:
operations = [
{
"$set": {
"entries": {"$concatArrays": [without_entry, task_entry]}
}
}
]
else:
operations = [
{
"$set": {
"new_pos": {
"$add": [
{"$indexOfArray": ["$entries.task", task_id]},
move_count,
]
},
"without_entry": without_entry,
"task_entry": task_entry,
}
},
{
"$set": {
"entries": {
"$switch": {
"branches": [
{
"case": {"$lte": ["$new_pos", 0]},
"then": {
"$concatArrays": [
"$task_entry",
"$without_entry",
]
},
},
{
"case": {
"$gte": [
"$new_pos",
{"$size": "$without_entry"},
]
},
"then": {
"$concatArrays": [
"$without_entry",
"$task_entry",
]
},
},
],
"default": {
"$concatArrays": [
{"$slice": ["$without_entry", "$new_pos"]},
"$task_entry",
{
"$slice": [
"$without_entry",
"$new_pos",
{"$size": "$without_entry"},
]
},
]
},
}
}
}
},
{"$unset": ["new_pos", "without_entry", "task_entry"]},
]
return new_position
updated = Queue.objects(
id=queue_id, company=company_id, entries__task=task_id
).update_one(__raw__=operations)
if not updated:
raise errors.bad_request.FailedAddingDuringReposition(task=task_id)
return get_queue_and_task_position()[1]
def count_entries(self, company: str, queue_id: str) -> Optional[int]:
res = next(

View File

@ -21,6 +21,7 @@ from apiserver.apimodels.queues import (
)
from apiserver.bll.model import Metadata
from apiserver.bll.queue import QueueBLL
from apiserver.bll.queue.queue_bll import MOVE_FIRST, MOVE_LAST
from apiserver.bll.workers import WorkerBLL
from apiserver.config_repo import config
from apiserver.database.model.task.task import Task
@ -195,7 +196,7 @@ def move_task_forward(call: APICall, company_id, req_model: MoveTaskRequest):
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
pos_func=lambda p: max(0, p - req_model.count),
move_count=-req_model.count,
)
)
@ -212,7 +213,7 @@ def move_task_backward(call: APICall, company_id, req_model: MoveTaskRequest):
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
pos_func=lambda p: max(0, p + req_model.count),
move_count=req_model.count,
)
)
@ -229,7 +230,7 @@ def move_task_to_front(call: APICall, company_id, req_model: TaskRequest):
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
pos_func=lambda p: 0,
move_count=MOVE_FIRST,
)
)
@ -246,7 +247,7 @@ def move_task_to_back(call: APICall, company_id, req_model: TaskRequest):
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
pos_func=lambda p: -1,
move_count=MOVE_LAST,
)
)

View File

@ -138,27 +138,70 @@ class TestQueues(TestService):
queue = self._temp_queue("TestTempQueue")
tasks = [
self._create_temp_queued_task(t, queue)["id"]
for t in ("temp task1", "temp task2", "temp task3")
for t in ("temp task1", "temp task2", "temp task3", "temp task4")
]
res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, tasks)
new_pos = self.api.queues.move_task_backward(
queue=queue, task=tasks[0], count=2
).position
self.assertEqual(new_pos, 2)
res = self.api.queues.get_by_id(queue=queue)
changed_tasks = tasks[1:] + tasks[:1]
self.assertQueueTasks(res.queue, changed_tasks)
new_pos = self.api.queues.move_task_forward(
queue=queue, task=tasks[0], count=2
# no change in position
new_pos = self.api.queues.move_task_to_front(
queue=queue, task=tasks[0]
).position
self.assertEqual(new_pos, 0)
res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, tasks)
self.assertGetNextTasks(queue, tasks)
# move backwards in the middle
new_pos = self.api.queues.move_task_backward(
queue=queue, task=tasks[0], count=2
).position
self.assertEqual(new_pos, 2)
res = self.api.queues.get_by_id(queue=queue)
changed_tasks = tasks[1:3] + [tasks[0], tasks[3]]
self.assertQueueTasks(res.queue, changed_tasks)
# move backwards beyond the end
new_pos = self.api.queues.move_task_backward(
queue=queue, task=tasks[0], count=100
).position
self.assertEqual(new_pos, 3)
res = self.api.queues.get_by_id(queue=queue)
changed_tasks = tasks[1:] + [tasks[0]]
self.assertQueueTasks(res.queue, changed_tasks)
# move forwards in the middle
new_pos = self.api.queues.move_task_forward(
queue=queue, task=tasks[0], count=2
).position
self.assertEqual(new_pos, 1)
res = self.api.queues.get_by_id(queue=queue)
changed_tasks = [tasks[1], tasks[0]] + tasks[2:]
self.assertQueueTasks(res.queue, changed_tasks)
# move forwards beyond the beginning
new_pos = self.api.queues.move_task_forward(
queue=queue, task=tasks[0], count=100
).position
self.assertEqual(new_pos, 0)
res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, tasks)
# move to back
new_pos = self.api.queues.move_task_to_back(
queue=queue, task=tasks[0]
).position
self.assertEqual(new_pos, 3)
res = self.api.queues.get_by_id(queue=queue)
changed_tasks = tasks[1:] + [tasks[0]]
self.assertQueueTasks(res.queue, changed_tasks)
# move to front
new_pos = self.api.queues.move_task_to_front(
queue=queue, task=tasks[0]
).position
self.assertEqual(new_pos, 0)
res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, tasks)
def test_get_all_ex(self):
queue_name = "TestTempQueue1"