mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add "queue watched" indication to pipelines.start_pipeline
This commit is contained in:
parent
49084a9c49
commit
91ce140901
@ -18,8 +18,4 @@ class StartPipelineRequest(models.Base):
|
||||
task = fields.StringField(required=True)
|
||||
queue = fields.StringField(required=True)
|
||||
args = ListField(Arg)
|
||||
|
||||
|
||||
class StartPipelineResponse(models.Base):
|
||||
pipeline = fields.StringField(required=True)
|
||||
enqueued = fields.BoolField(required=True)
|
||||
verify_watched_queue = fields.BoolField(default=False)
|
||||
|
@ -79,4 +79,15 @@ start_pipeline {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${start_pipeline."2.17"} {
|
||||
request.properties.verify_watched_queue {
|
||||
description: If passed then check wheter there are any workers watiching the queue
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
response.properties.queue_watched {
|
||||
description: Returns true if there are workers or autscalers working with the queue
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
@ -5,22 +5,24 @@ import attr
|
||||
|
||||
from apiserver.apierrors.errors.bad_request import CannotRemoveAllRuns
|
||||
from apiserver.apimodels.pipelines import (
|
||||
StartPipelineResponse,
|
||||
StartPipelineRequest,
|
||||
DeleteRunsRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.task_operations import enqueue_task, delete_task
|
||||
from apiserver.bll.util import run_batch_operation
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskType
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
org_bll = OrgBLL()
|
||||
project_bll = ProjectBLL()
|
||||
task_bll = TaskBLL()
|
||||
queue_bll = QueueBLL()
|
||||
|
||||
|
||||
def _update_task_name(task: Task):
|
||||
@ -79,9 +81,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
|
||||
call.result.data = dict(succeeded=succeeded, failed=failures)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"pipelines.start_pipeline", response_data_model=StartPipelineResponse,
|
||||
)
|
||||
@endpoint("pipelines.start_pipeline")
|
||||
def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
|
||||
hyperparams = None
|
||||
if request.args:
|
||||
@ -113,5 +113,14 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
|
||||
status_message="Starting pipeline",
|
||||
status_reason="",
|
||||
)
|
||||
extra = {}
|
||||
if request.verify_watched_queue and queued:
|
||||
res_queue = nested_get(res, ("fields", "execution.queue"))
|
||||
if res_queue:
|
||||
extra["queue_watched"] = queue_bll.check_for_workers(company_id, res_queue)
|
||||
|
||||
return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))
|
||||
call.result.data = dict(
|
||||
pipeline=task.id,
|
||||
enqueued=bool(queued),
|
||||
**extra,
|
||||
)
|
||||
|
@ -37,29 +37,44 @@ class TestPipelines(TestService):
|
||||
|
||||
res = self.api.pipelines.start_pipeline(task=task, queue=queue, args=args)
|
||||
pipeline_task = res.pipeline
|
||||
try:
|
||||
self.assertTrue(res.enqueued)
|
||||
pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0]
|
||||
self.assertTrue(pipeline.name.startswith(task_name))
|
||||
self.assertEqual(pipeline.status, "queued")
|
||||
self.assertEqual(pipeline.project.id, project)
|
||||
self.assertEqual(
|
||||
pipeline.hyperparams.Args,
|
||||
{
|
||||
a["name"]: {
|
||||
"section": "Args",
|
||||
"name": a["name"],
|
||||
"value": a["value"],
|
||||
}
|
||||
for a in args
|
||||
},
|
||||
)
|
||||
finally:
|
||||
self.api.tasks.delete(task=pipeline_task, force=True)
|
||||
self.assertTrue(res.enqueued)
|
||||
pipeline = self.api.tasks.get_all_ex(id=[pipeline_task]).tasks[0]
|
||||
self.assertTrue(pipeline.name.startswith(task_name))
|
||||
self.assertEqual(pipeline.status, "queued")
|
||||
self.assertEqual(pipeline.project.id, project)
|
||||
self.assertEqual(
|
||||
pipeline.hyperparams.Args,
|
||||
{
|
||||
a["name"]: {
|
||||
"section": "Args",
|
||||
"name": a["name"],
|
||||
"value": a["value"],
|
||||
}
|
||||
for a in args
|
||||
},
|
||||
)
|
||||
|
||||
# watched queue
|
||||
queue = self._temp_queue("test pipelines")
|
||||
project, task = self._temp_project_and_task(name="pipelines test1")
|
||||
res = self.api.pipelines.start_pipeline(
|
||||
task=task, queue=queue, verify_watched_queue=True
|
||||
)
|
||||
self.assertEqual(res.queue_watched, False)
|
||||
|
||||
self.api.workers.register(worker="test pipelines", queues=[queue])
|
||||
project, task = self._temp_project_and_task(name="pipelines test2")
|
||||
res = self.api.pipelines.start_pipeline(
|
||||
task=task, queue=queue, verify_watched_queue=True
|
||||
)
|
||||
self.assertEqual(res.queue_watched, True)
|
||||
|
||||
def _temp_project_and_task(self, name) -> Tuple[str, str]:
|
||||
project = self.create_temp(
|
||||
"projects", name=name, description="test", delete_params=dict(force=True),
|
||||
"projects",
|
||||
name=name,
|
||||
description="test",
|
||||
delete_params=dict(force=True, delete_contents=True),
|
||||
)
|
||||
|
||||
return (
|
||||
@ -72,3 +87,6 @@ class TestPipelines(TestService):
|
||||
system_tags=["pipeline"],
|
||||
),
|
||||
)
|
||||
|
||||
def _temp_queue(self, queue_name, **kwargs):
|
||||
return self.create_temp("queues", name=queue_name, **kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user