Add new_status field to tasks.dequeue and dequeue_many endpoints

This commit is contained in:
allegroai 2023-07-26 18:55:05 +03:00
parent 4eff657810
commit 6c5f966ed4
7 changed files with 29 additions and 1 deletions

View File

@ -98,6 +98,7 @@ class UpdateRequest(TaskUpdateRequest):
class DequeueRequest(UpdateRequest): class DequeueRequest(UpdateRequest):
remove_from_all_queues = BoolField(default=False) remove_from_all_queues = BoolField(default=False)
new_status = StringField()
class EnqueueRequest(UpdateRequest): class EnqueueRequest(UpdateRequest):
@ -280,6 +281,7 @@ class StopManyRequest(TaskBatchRequest):
class DequeueManyRequest(TaskBatchRequest): class DequeueManyRequest(TaskBatchRequest):
remove_from_all_queues = BoolField(default=False) remove_from_all_queues = BoolField(default=False)
new_status = StringField()
class EnqueueManyRequest(TaskBatchRequest): class EnqueueManyRequest(TaskBatchRequest):

View File

@ -172,6 +172,7 @@ class QueueBLL(object):
status_reason="Queue was deleted", status_reason="Queue was deleted",
status_message="", status_message="",
user_id=user_id, user_id=user_id,
force=True,
).execute(enqueue_status=None) ).execute(enqueue_status=None)
except Exception as ex: except Exception as ex:
log.exception( log.exception(

View File

@ -463,6 +463,7 @@ class TaskBLL:
status_message: str, status_message: str,
status_reason: str, status_reason: str,
remove_from_all_queues=False, remove_from_all_queues=False,
new_status=None,
): ):
try: try:
cls.dequeue(task, company_id, silent_fail=True) cls.dequeue(task, company_id, silent_fail=True)
@ -478,10 +479,11 @@ class TaskBLL:
return ChangeStatusRequest( return ChangeStatusRequest(
task=task, task=task,
new_status=task.enqueue_status or TaskStatus.created, new_status=new_status or task.enqueue_status or TaskStatus.created,
status_reason=status_reason, status_reason=status_reason,
status_message=status_message, status_message=status_message,
user_id=user_id, user_id=user_id,
force=True,
).execute(enqueue_status=None) ).execute(enqueue_status=None)
@classmethod @classmethod

View File

@ -23,6 +23,7 @@ from apiserver.database.model.task.task import (
Execution, Execution,
DEFAULT_LAST_ITERATION, DEFAULT_LAST_ITERATION,
) )
from apiserver.database.utils import get_options
from apiserver.utilities.dicts import nested_set from apiserver.utilities.dicts import nested_set
log = config.logger(__file__) log = config.logger(__file__)
@ -102,7 +103,11 @@ def dequeue_task(
status_message: str, status_message: str,
status_reason: str, status_reason: str,
remove_from_all_queues: bool = False, remove_from_all_queues: bool = False,
new_status=None,
) -> Tuple[int, dict]: ) -> Tuple[int, dict]:
if new_status and new_status not in get_options(TaskStatus):
raise errors.bad_request.ValidationError(f"Invalid task status: {new_status}")
# get the task without write access to make sure that it actually exists # get the task without write access to make sure that it actually exists
task = Task.get( task = Task.get(
id=task_id, id=task_id,
@ -128,6 +133,7 @@ def dequeue_task(
status_message=status_message, status_message=status_message,
status_reason=status_reason, status_reason=status_reason,
remove_from_all_queues=remove_from_all_queues, remove_from_all_queues=remove_from_all_queues,
new_status=new_status,
) )
return 1, res return 1, res

View File

@ -1525,6 +1525,12 @@ dequeue {
default: false default: false
} }
} }
"2.26": ${dequeue."2.25"} {
request.properties.new_status {
type: string
description: The new status to assign to the task after the dequeue instead of the default one
}
}
} }
dequeue_many { dequeue_many {
"2.13": ${_definitions.change_many_request} { "2.13": ${_definitions.change_many_request} {
@ -1550,6 +1556,12 @@ dequeue_many {
default: false default: false
} }
} }
"2.26": ${dequeue_many."2.25"} {
request.properties.new_status {
type: string
description: The new status to assign to the task after the dequeue instead of the default one
}
}
} }
set_requirements { set_requirements {
"2.1" { "2.1" {

View File

@ -921,6 +921,7 @@ def dequeue(call: APICall, company_id, request: DequeueRequest):
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
remove_from_all_queues=request.remove_from_all_queues, remove_from_all_queues=request.remove_from_all_queues,
new_status=request.new_status,
) )
call.result.data_model = DequeueResponse(dequeued=dequeued, **res) call.result.data_model = DequeueResponse(dequeued=dequeued, **res)
@ -937,6 +938,7 @@ def dequeue_many(call: APICall, company_id, request: DequeueManyRequest):
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
remove_from_all_queues=request.remove_from_all_queues, remove_from_all_queues=request.remove_from_all_queues,
new_status=request.new_status,
), ),
ids=request.ids, ids=request.ids,
) )

View File

@ -84,6 +84,9 @@ class TestQueues(TestService):
res = self.api.queues.get_by_id(queue=queue) res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, [task]) self.assertQueueTasks(res.queue, [task])
self.assertTaskTags(task, system_tags=[]) self.assertTaskTags(task, system_tags=[])
self.api.tasks.dequeue(task=task, new_status="published")
res = self.api.tasks.get_by_id(task=task)
self.assertEqual(res.task.status, "published")
def test_dequeue_not_queued_task(self): def test_dequeue_not_queued_task(self):
# queue = self._temp_queue("TestTempQueue") # queue = self._temp_queue("TestTempQueue")