Add "queue watched" indication to pipelines.start_pipeline

This commit is contained in:
allegroai 2024-01-10 15:15:43 +02:00
parent 49084a9c49
commit 91ce140901
4 changed files with 64 additions and 30 deletions

View File

@ -18,8 +18,4 @@ class StartPipelineRequest(models.Base):
task = fields.StringField(required=True)
queue = fields.StringField(required=True)
args = ListField(Arg)
class StartPipelineResponse(models.Base):
pipeline = fields.StringField(required=True)
enqueued = fields.BoolField(required=True)
verify_watched_queue = fields.BoolField(default=False)

View File

@ -79,4 +79,15 @@ start_pipeline {
}
}
}
"999.0": ${start_pipeline."2.17"} {
request.properties.verify_watched_queue {
description: If passed then check wheter there are any workers watiching the queue
type: boolean
default: false
}
response.properties.queue_watched {
description: Returns true if there are workers or autscalers working with the queue
type: boolean
}
}
}

View File

@ -5,22 +5,24 @@ import attr
from apiserver.apierrors.errors.bad_request import CannotRemoveAllRuns
from apiserver.apimodels.pipelines import (
StartPipelineResponse,
StartPipelineRequest,
DeleteRunsRequest,
)
from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL
from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import enqueue_task, delete_task
from apiserver.bll.util import run_batch_operation
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskType
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities.dicts import nested_get
org_bll = OrgBLL()
project_bll = ProjectBLL()
task_bll = TaskBLL()
queue_bll = QueueBLL()
def _update_task_name(task: Task):
@ -79,9 +81,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
call.result.data = dict(succeeded=succeeded, failed=failures)
@endpoint(
"pipelines.start_pipeline", response_data_model=StartPipelineResponse,
)
@endpoint("pipelines.start_pipeline")
def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
hyperparams = None
if request.args:
@ -113,5 +113,14 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
status_message="Starting pipeline",
status_reason="",
)
extra = {}
if request.verify_watched_queue and queued:
res_queue = nested_get(res, ("fields", "execution.queue"))
if res_queue:
extra["queue_watched"] = queue_bll.check_for_workers(company_id, res_queue)
return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))
call.result.data = dict(
pipeline=task.id,
enqueued=bool(queued),
**extra,
)

View File

@ -37,29 +37,44 @@ class TestPipelines(TestService):
res = self.api.pipelines.start_pipeline(task=task, queue=queue, args=args)
pipeline_task = res.pipeline
try:
self.assertTrue(res.enqueued)
pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0]
self.assertTrue(pipeline.name.startswith(task_name))
self.assertEqual(pipeline.status, "queued")
self.assertEqual(pipeline.project.id, project)
self.assertEqual(
pipeline.hyperparams.Args,
{
a["name"]: {
"section": "Args",
"name": a["name"],
"value": a["value"],
}
for a in args
},
)
finally:
self.api.tasks.delete(task=pipeline_task, force=True)
self.assertTrue(res.enqueued)
pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0]
self.assertTrue(pipeline.name.startswith(task_name))
self.assertEqual(pipeline.status, "queued")
self.assertEqual(pipeline.project.id, project)
self.assertEqual(
pipeline.hyperparams.Args,
{
a["name"]: {
"section": "Args",
"name": a["name"],
"value": a["value"],
}
for a in args
},
)
# watched queue
queue = self._temp_queue("test pipelines")
project, task = self._temp_project_and_task(name="pipelines test1")
res = self.api.pipelines.start_pipeline(
task=task, queue=queue, verify_watched_queue=True
)
self.assertEqual(res.queue_watched, False)
self.api.workers.register(worker="test pipelines", queues=[queue])
project, task = self._temp_project_and_task(name="pipelines test2")
res = self.api.pipelines.start_pipeline(
task=task, queue=queue, verify_watched_queue=True
)
self.assertEqual(res.queue_watched, True)
def _temp_project_and_task(self, name) -> Tuple[str, str]:
project = self.create_temp(
"projects", name=name, description="test", delete_params=dict(force=True),
"projects",
name=name,
description="test",
delete_params=dict(force=True, delete_contents=True),
)
return (
@ -72,3 +87,6 @@ class TestPipelines(TestService):
system_tags=["pipeline"],
),
)
def _temp_queue(self, queue_name, **kwargs):
return self.create_temp("queues", name=queue_name, **kwargs)