From 91ce140901ebef0cfa13a9d8f81805f67072161f Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 10 Jan 2024 15:15:43 +0200 Subject: [PATCH] Add "queue watched" indication to pipelines.start_pipeline --- apiserver/apimodels/pipelines.py | 6 +-- apiserver/schema/services/pipelines.conf | 11 ++++ apiserver/services/pipelines.py | 19 +++++-- apiserver/tests/automated/test_pipelines.py | 58 ++++++++++++++------- 4 files changed, 64 insertions(+), 30 deletions(-) diff --git a/apiserver/apimodels/pipelines.py b/apiserver/apimodels/pipelines.py index 49c7f51..797a65e 100644 --- a/apiserver/apimodels/pipelines.py +++ b/apiserver/apimodels/pipelines.py @@ -18,8 +18,4 @@ class StartPipelineRequest(models.Base): task = fields.StringField(required=True) queue = fields.StringField(required=True) args = ListField(Arg) - - -class StartPipelineResponse(models.Base): - pipeline = fields.StringField(required=True) - enqueued = fields.BoolField(required=True) + verify_watched_queue = fields.BoolField(default=False) diff --git a/apiserver/schema/services/pipelines.conf b/apiserver/schema/services/pipelines.conf index 45caedc..21af5e1 100644 --- a/apiserver/schema/services/pipelines.conf +++ b/apiserver/schema/services/pipelines.conf @@ -79,4 +79,15 @@ start_pipeline { } } } + "999.0": ${start_pipeline."2.17"} { + request.properties.verify_watched_queue { + description: If passed then check wheter there are any workers watiching the queue + type: boolean + default: false + } + response.properties.queue_watched { + description: Returns true if there are workers or autscalers working with the queue + type: boolean + } + } } \ No newline at end of file diff --git a/apiserver/services/pipelines.py b/apiserver/services/pipelines.py index 72decf4..57fe23c 100644 --- a/apiserver/services/pipelines.py +++ b/apiserver/services/pipelines.py @@ -5,22 +5,24 @@ import attr from apiserver.apierrors.errors.bad_request import CannotRemoveAllRuns from apiserver.apimodels.pipelines import ( - StartPipelineResponse, StartPipelineRequest, DeleteRunsRequest, ) from apiserver.bll.organization import OrgBLL from apiserver.bll.project import ProjectBLL +from apiserver.bll.queue import QueueBLL from apiserver.bll.task import TaskBLL from apiserver.bll.task.task_operations import enqueue_task, delete_task from apiserver.bll.util import run_batch_operation from apiserver.database.model.project import Project from apiserver.database.model.task.task import Task, TaskType from apiserver.service_repo import APICall, endpoint +from apiserver.utilities.dicts import nested_get org_bll = OrgBLL() project_bll = ProjectBLL() task_bll = TaskBLL() +queue_bll = QueueBLL() def _update_task_name(task: Task): @@ -79,9 +81,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest): call.result.data = dict(succeeded=succeeded, failed=failures) -@endpoint( - "pipelines.start_pipeline", response_data_model=StartPipelineResponse, -) +@endpoint("pipelines.start_pipeline") def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest): hyperparams = None if request.args: @@ -113,5 +113,14 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest status_message="Starting pipeline", status_reason="", ) + extra = {} + if request.verify_watched_queue and queued: + res_queue = nested_get(res, ("fields", "execution.queue")) + if res_queue: + extra["queue_watched"] = queue_bll.check_for_workers(company_id, res_queue) - return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued)) + call.result.data = dict( + pipeline=task.id, + enqueued=bool(queued), + **extra, + ) diff --git a/apiserver/tests/automated/test_pipelines.py b/apiserver/tests/automated/test_pipelines.py index caf38db..ac4144a 100644 --- a/apiserver/tests/automated/test_pipelines.py +++ b/apiserver/tests/automated/test_pipelines.py @@ -37,29 +37,44 @@ class TestPipelines(TestService): res = self.api.pipelines.start_pipeline(task=task, queue=queue, args=args) pipeline_task = res.pipeline - try: - self.assertTrue(res.enqueued) - pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0] - self.assertTrue(pipeline.name.startswith(task_name)) - self.assertEqual(pipeline.status, "queued") - self.assertEqual(pipeline.project.id, project) - self.assertEqual( - pipeline.hyperparams.Args, - { - a["name"]: { - "section": "Args", - "name": a["name"], - "value": a["value"], - } - for a in args - }, - ) - finally: - self.api.tasks.delete(task=pipeline_task, force=True) + self.assertTrue(res.enqueued) + pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0] + self.assertTrue(pipeline.name.startswith(task_name)) + self.assertEqual(pipeline.status, "queued") + self.assertEqual(pipeline.project.id, project) + self.assertEqual( + pipeline.hyperparams.Args, + { + a["name"]: { + "section": "Args", + "name": a["name"], + "value": a["value"], + } + for a in args + }, + ) + + # watched queue + queue = self._temp_queue("test pipelines") + project, task = self._temp_project_and_task(name="pipelines test1") + res = self.api.pipelines.start_pipeline( + task=task, queue=queue, verify_watched_queue=True + ) + self.assertEqual(res.queue_watched, False) + + self.api.workers.register(worker="test pipelines", queues=[queue]) + project, task = self._temp_project_and_task(name="pipelines test2") + res = self.api.pipelines.start_pipeline( + task=task, queue=queue, verify_watched_queue=True + ) + self.assertEqual(res.queue_watched, True) def _temp_project_and_task(self, name) -> Tuple[str, str]: project = self.create_temp( - "projects", name=name, description="test", delete_params=dict(force=True), + "projects", + name=name, + description="test", + delete_params=dict(force=True, delete_contents=True), ) return ( @@ -72,3 +87,6 @@ class TestPipelines(TestService): system_tags=["pipeline"], ), ) + + def _temp_queue(self, queue_name, **kwargs): + return self.create_temp("queues", name=queue_name, **kwargs)