Fix: using target_project with TaskScheduler.add_task() corrupts the project when used with pipelines (#1137)

This commit is contained in:
allegroai 2023-12-13 16:42:31 +02:00
parent 8cb7d14abb
commit 1076d20808
2 changed files with 17 additions and 3 deletions

View File

@ -69,6 +69,7 @@ class PipelineController(object):
_final_failure = {} # Node.name: bool
_task_template_header = CreateFromFunction.default_task_template_header
_default_pipeline_version = "1.0.0"
_project_section = ".pipelines"
valid_job_status = ["failed", "cached", "completed", "aborted", "queued", "running", "skipped", "pending"]
@ -305,7 +306,7 @@ class PipelineController(object):
if not self._task:
task_name = name or project or '{}'.format(datetime.now())
if self._pipeline_as_sub_project:
parent_project = "{}.pipelines".format(project+'/' if project else '')
parent_project = (project + "/" if project else "") + self._pipeline_section
project_name = "{}/{}".format(parent_project, task_name)
else:
parent_project = None
@ -1422,7 +1423,7 @@ class PipelineController(object):
mutually_exclusive(pipeline_id=pipeline_id, pipeline_project=pipeline_project, _require_at_least_one=False)
mutually_exclusive(pipeline_id=pipeline_id, pipeline_name=pipeline_name, _require_at_least_one=False)
if not pipeline_id:
pipeline_project_hidden = "{}/.pipelines/{}".format(pipeline_project, pipeline_name)
pipeline_project_hidden = "{}/{}/{}".format(pipeline_project, cls._pipeline_section, pipeline_name)
name_with_runtime_number_regex = r"^{}( #[0-9]+)*$".format(re.escape(pipeline_name))
pipelines = Task._query_tasks(
pipeline_project=[pipeline_project_hidden],

View File

@ -9,6 +9,7 @@ from attr import attrs, attrib
from dateutil.relativedelta import relativedelta
from .job import ClearmlJob
from .controller import PipelineController
from ..backend_interface.util import datetime_from_isoformat, datetime_to_isoformat, mutually_exclusive
from ..task import Task
@ -59,6 +60,18 @@ class BaseScheduleJob(object):
self._executed_instances = []
self._executed_instances.append(str(task_id))
def get_resolved_target_project(self):
    """Return the project the scheduled task should actually be created in.

    When the base task is a pipeline controller (its system tags contain
    ``PipelineController._tag``) and ``target_project`` does not already
    contain the hidden pipelines sub-project segment, the target is expanded
    to ``<target_project>/<PipelineController._project_section>/<task name>``.
    In every other case ``target_project`` is returned unchanged.

    :return: The resolved target project, or ``self.target_project`` as-is
        when there is no base task, no target project, the base task is not
        a pipeline, or the base task could not be inspected.
    """
    if not self.base_task_id or not self.target_project:
        return self.target_project

    # Use the same constant the pipeline controller uses for its hidden
    # sub-project segment (`_project_section`), not an unrelated section name.
    pipelines_section = PipelineController._project_section

    # Best effort: any failure while inspecting the base task falls back to
    # the user-provided target project instead of breaking the scheduler.
    # noinspection PyBroadException
    try:
        task = Task.get_task(task_id=self.base_task_id)
        is_pipeline = PipelineController._tag in task.get_system_tags()
        already_nested = "/{}/".format(pipelines_section) in self.target_project
        if is_pipeline and not already_nested:
            return "{}/{}/{}".format(self.target_project, pipelines_section, task.name)
    except Exception:
        pass
    return self.target_project
@attrs
class ScheduleJob(BaseScheduleJob):
@ -447,7 +460,7 @@ class BaseScheduler(object):
task_overrides=job.task_overrides,
disable_clone_task=not job.clone_task,
allow_caching=False,
target_project=job.target_project,
target_project=job.get_resolved_target_project(),
tags=[add_tags] if add_tags and isinstance(add_tags, str) else add_tags,
)
self._log('Scheduling Job {}, Task {} on queue {}.'.format(