feat(scheduler): Adds task_id_function to TaskScheduler.add_task()

Adds new functionality to the task scheduler by adding a new parameter task_id_function to TaskScheduler.add_task() that takes a callable which has an expected return of a task_id. This task_id_function function is run at runtime (when the task scheduler would normally execute the scheduled task) and uses the task_id returned by the function + the other parameters from .add_task() as the scheduled task.
This commit is contained in:
Nathan B 2024-02-23 17:49:19 +01:00
parent beb0f5e607
commit 3a8f77f6d3

View File

@ -19,6 +19,7 @@ class BaseScheduleJob(object):
name = attrib(type=str, default=None)
base_task_id = attrib(type=str, default=None)
base_function = attrib(type=Callable, default=None)
base_task_id_function = attrib(type=Callable, default=None)
queue = attrib(type=str, default=None)
target_project = attrib(type=str, default=None)
single_instance = attrib(type=bool, default=False)
@ -43,9 +44,13 @@ class BaseScheduleJob(object):
if self.base_function and not self.name:
raise ValueError("Entry 'name' must be supplied for function scheduling")
if self.base_task_id and not self.queue:
raise ValueError("Target 'queue' must be provided for function scheduling")
if not self.base_function and not self.base_task_id:
raise ValueError("Either schedule function or task-id must be provided")
raise ValueError("Target 'queue' must be provided for task scheduling")
if self.base_task_id_function and not self.queue:
raise ValueError("Target 'queue' must be provided for task-id function scheduling")
if not self.base_function and not self.base_task_id and not self.base_task_id_function:
raise ValueError("Either schedule function, task-id, or task-id function must be provided")
if self.base_task_id and self.base_task_id_function:
raise ValueError("Only one of task-id or task-id function must be provided")
def get_last_executed_task_id(self):
# type: () -> Optional[str]
@ -60,11 +65,12 @@ class BaseScheduleJob(object):
self._executed_instances.append(str(task_id))
def get_resolved_target_project(self):
if not self.base_task_id or not self.target_project:
if not self.base_task_id or not self.target_project or not self.base_task_id_function:
return self.target_project
# noinspection PyBroadException
try:
task = Task.get_task(task_id=self.base_task_id)
base_task_id = self.get_task_id()
task = Task.get_task(task_id=base_task_id)
# noinspection PyProtectedMember
if (
PipelineController._tag in task.get_system_tags()
@ -76,6 +82,19 @@ class BaseScheduleJob(object):
pass
return self.target_project
def get_task_id(self):
# type: () -> str
if self.base_task_id_function:
# validate retrevial of a valid task id
try:
base_task_id = self.base_task_id_function()
except Exception as ex:
raise ValueError("Failed to retrieve task id from function: {}".format(ex))
return base_task_id
return self.base_task_id
@attrs
class ScheduleJob(BaseScheduleJob):
@ -459,9 +478,10 @@ class BaseScheduler(object):
job.run(None)
return None
# actually run the job
# actually run the job'
base_task_id = job.get_task_id()
task_job = ClearmlJob(
base_task_id=job.base_task_id,
base_task_id=base_task_id,
parameter_override=task_parameters or job.task_parameters,
task_overrides=job.task_overrides,
disable_clone_task=not job.clone_task,
@ -469,10 +489,10 @@ class BaseScheduler(object):
target_project=job.get_resolved_target_project(),
tags=[add_tags] if add_tags and isinstance(add_tags, str) else add_tags,
)
self._log("Scheduling Job {}, Task {} on queue {}.".format(job.name, task_job.task_id(), job.queue))
self._log('Scheduling Job {}, Task {} on queue {}.'.format(job.name, base_task_id, job.queue))
if task_job.launch(queue_name=job.queue):
# mark as run
job.run(task_job.task_id())
job.run(base_task_id)
return task_job
def _launch_job_function(self, job, func_args=None):
@ -553,6 +573,7 @@ class TaskScheduler(BaseScheduler):
self,
schedule_task_id=None, # type: Union[str, Task]
schedule_function=None, # type: Callable
task_id_function=None, # type: Callable
queue=None, # type: str
name=None, # type: Optional[str]
target_project=None, # type: Optional[str]
@ -606,6 +627,10 @@ class TaskScheduler(BaseScheduler):
:param schedule_function: Optional, instead of providing Task ID to be scheduled,
provide a function to be called. Notice the function is called from the scheduler context
(i.e. running on the same machine as the scheduler)
:param task_id_function: Optional, instead of providing Task ID to be scheduled,
provide a function to be called that returns the Task ID. That task ID will be used to schedule the task at runtime.
Notice the function is called from the scheduler context
:param queue: Name or ID of queue to put the Task into (i.e. schedule)
:param name: Name or description for the cron Task (should be unique if provided, otherwise randomly generated)
:param target_project: Specify target project to put the cloned scheduled Task in.
@ -634,14 +659,15 @@ class TaskScheduler(BaseScheduler):
:return: True if job is successfully added to the scheduling list
"""
mutually_exclusive(schedule_function=schedule_function, schedule_task_id=schedule_task_id)
task_id = schedule_task_id.id if isinstance(schedule_task_id, Task) else str(schedule_task_id or "")
mutually_exclusive(schedule_function=schedule_function, schedule_task_id=schedule_task_id, task_id_function=task_id_function)
task_id = schedule_task_id.id if isinstance(schedule_task_id, Task) else str(schedule_task_id or '')
# noinspection PyProtectedMember
job = ScheduleJob(
name=name or task_id,
base_task_id=task_id,
base_function=schedule_function,
base_task_iÍd_function=task_id_function,
queue=queue,
target_project=target_project,
execution_limit_hours=limit_execution_time,