Fix de-serializing from aborted state + refactor

This commit is contained in:
allegroai 2021-08-02 23:01:26 +03:00
parent 27b560bb75
commit 2ba653baf7


@@ -1,4 +1,5 @@
import json
import logging
from datetime import datetime
from threading import Thread, enumerate as enumerate_threads
from time import sleep, time
@@ -13,55 +14,73 @@ from ..task import Task
@attrs
class ScheduleJob(object):
_weekdays_ind = ('monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday')
name = attrib(type=str)
base_task_id = attrib(type=str)
class BaseScheduleJob(object):
name = attrib(type=str, default=None)
base_task_id = attrib(type=str, default=None)
base_function = attrib(type=Callable, default=None)
queue = attrib(type=str, default=None)
target_project = attrib(type=str, default=None)
execution_limit_hours = attrib(type=float, default=None)
recurring = attrib(type=bool, default=True)
starting_time = attrib(type=datetime, converter=datetime_from_isoformat, default=None)
single_instance = attrib(type=bool, default=False)
task_parameters = attrib(type=dict, default={})
task_overrides = attrib(type=dict, default={})
clone_task = attrib(type=bool, default=True)
minute = attrib(type=float, default=None)
hour = attrib(type=float, default=None)
day = attrib(default=None)
weekdays = attrib(default=None)
month = attrib(type=float, default=None)
year = attrib(type=float, default=None)
_executed_instances = attrib(type=list, default=[])
_next_run = attrib(type=datetime, converter=datetime_from_isoformat, default=None)
_last_executed = attrib(type=datetime, converter=datetime_from_isoformat, default=None)
_execution_timeout = attrib(type=datetime, converter=datetime_from_isoformat, default=None)
def to_dict(self, full=False):
return {k: v for k, v in self.__dict__.items()
if not callable(v) and (full or not str(k).startswith('_'))}
def update(self, a_job):
for k, v in a_job.to_dict().items():
if not callable(getattr(self, k, v)):
setattr(self, k, v)
converters = {a.name: a.converter for a in getattr(self, '__attrs_attrs__', [])}
for k, v in (a_job.to_dict(full=True) if not isinstance(a_job, dict) else a_job).items():
if v is not None and not callable(getattr(self, k, v)):
setattr(self, k, converters[k](v) if converters.get(k) else v)
return self
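The converter pass in update() is what makes reloading from an aborted state work: serialized datetime fields come back from JSON as ISO strings, and re-applying each attrs converter restores the original type. A minimal sketch, assuming datetime_from_isoformat passes datetime/None through and parses ISO-8601 strings:

from datetime import datetime

def datetime_from_isoformat(o):
    # assumed behavior: pass None/datetime through, parse ISO-8601 strings
    if o is None or isinstance(o, datetime):
        return o
    return datetime.fromisoformat(o)

# the per-field round-trip update() now performs on deserialized values
serialized = datetime.utcnow().isoformat()      # becomes a str in json.dumps
restored = datetime_from_isoformat(serialized)  # a datetime again
assert isinstance(restored, datetime)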
def verify(self):
# type: () -> None
if self.weekdays and self.day not in (None, 0, 1):
raise ValueError("`weekdays` and `day` combination is not valid (day must be None,0 or 1)")
if self.base_function and not self.name:
raise ValueError("Entry 'name' must be supplied for function scheduling")
if self.base_task_id and not self.queue:
raise ValueError("Target 'queue' must be provided for function scheduling")
if not (self.minute or self.hour or self.day or self.month or self.year):
raise ValueError("Schedule time/date was not provided")
if not self.base_function and not self.base_task_id:
raise ValueError("Either schedule function or task-id must be provided")
def get_last_executed_task_id(self):
# type: () -> Optional[str]
return self._executed_instances[-1] if self._executed_instances else None
def run(self, task_id):
# type: (Optional[str]) -> None
if task_id:
self._executed_instances.append(str(task_id))
@attrs
class ScheduleJob(BaseScheduleJob):
_weekdays_ind = ('monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday')
execution_limit_hours = attrib(type=float, default=None)
recurring = attrib(type=bool, default=True)
starting_time = attrib(type=datetime, converter=datetime_from_isoformat, default=None)
minute = attrib(type=float, default=None)
hour = attrib(type=float, default=None)
day = attrib(default=None)
weekdays = attrib(default=None)
month = attrib(type=float, default=None)
year = attrib(type=float, default=None)
_next_run = attrib(type=datetime, converter=datetime_from_isoformat, default=None)
_execution_timeout = attrib(type=datetime, converter=datetime_from_isoformat, default=None)
_last_executed = attrib(type=datetime, converter=datetime_from_isoformat, default=None)
def verify(self):
# type: () -> None
super(ScheduleJob, self).verify()
if self.weekdays and self.day not in (None, 0, 1):
raise ValueError("`weekdays` and `day` combination is not valid (day must be None,0 or 1)")
if not (self.minute or self.hour or self.day or self.month or self.year):
raise ValueError("Schedule time/date was not provided")
def next_run(self):
# type: () -> Optional[datetime]
return self._next_run
@@ -70,10 +89,6 @@ class ScheduleJob(object):
# type: () -> Optional[datetime]
return self._execution_timeout
def get_last_executed_task_id(self):
# type: () -> Optional[str]
return self._executed_instances[-1] if self._executed_instances else None
def next(self):
# type: () -> Optional[datetime]
"""
@@ -194,9 +209,8 @@ class ScheduleJob(object):
def run(self, task_id):
# type: (Optional[str]) -> None
super(ScheduleJob, self).run(task_id)
self._last_executed = datetime.utcnow()
if task_id:
self._executed_instances.append(str(task_id))
if self.execution_limit_hours and task_id:
self._execution_timeout = self._last_executed + relativedelta(
hours=int(self.execution_limit_hours),
@@ -226,7 +240,205 @@ class ExecutedJob(object):
return {k: v for k, v in self.__dict__.items() if full or not str(k).startswith('_')}
class TaskScheduler(object):
class BaseScheduler(object):
def __init__(self, sync_frequency_minutes=15, force_create_task_name=None, force_create_task_project=None):
# type: (float, Optional[str], Optional[str]) -> None
"""
Create a Task scheduler service
:param sync_frequency_minutes: Sync the task scheduler configuration every X minutes.
Allows changing the scheduler at runtime by editing the Task configuration object
:param force_create_task_name: Optional, force creation of a new Task Scheduler service task,
even if a main Task.init task already exists.
:param force_create_task_project: Optional, project to use when force-creating
the Task Scheduler service task.
"""
self._last_sync = 0
self._sync_frequency_minutes = sync_frequency_minutes
if force_create_task_name or not Task.current_task():
self._task = Task.init(
project_name=force_create_task_project or 'DevOps',
task_name=force_create_task_name or 'Scheduler',
task_type=Task.TaskTypes.service,
auto_resource_monitoring=False,
)
else:
self._task = Task.current_task()
def start(self):
# type: () -> None
"""
Start the TaskScheduler loop (note that this function does not return)
"""
if Task.running_locally():
self._serialize_state()
self._serialize()
else:
self._deserialize_state()
self._deserialize()
while True:
# sync with backend
try:
if time() - self._last_sync > 60. * self._sync_frequency_minutes:
self._last_sync = time()
self._deserialize()
self._update_execution_plots()
except Exception as ex:
self._log('Warning: Exception caught during deserialization: {}'.format(ex))
self._last_sync = time()
try:
if self._step():
self._serialize_state()
self._update_execution_plots()
except Exception as ex:
self._log('Warning: Exception caught during scheduling step: {}'.format(ex))
# rate control
sleep(15)
def start_remotely(self, queue='services'):
# type: (str) -> None
"""
Start the TaskScheduler loop (note that this function does not return)
:param queue: Remote queue to run the scheduler on, default 'services' queue.
"""
self._task.execute_remotely(queue_name=queue, exit_process=True)
self.start()
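For context, a hypothetical end-to-end use of this loop (the task ID and queue names are placeholders; add_task is the TaskScheduler API shown further down):

from clearml.automation import TaskScheduler

scheduler = TaskScheduler(sync_frequency_minutes=15)
scheduler.add_task(
    name='nightly-train',
    schedule_task_id='aabbcc',   # existing Task to clone and enqueue
    queue='default',
    hour=5.0, minute=30.0,       # daily at 05:30 (scheduler time is UTC)
)
# enqueue the scheduler itself on the services queue; this call never returns
scheduler.start_remotely(queue='services')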
def _update_execution_plots(self):
# type: () -> None
"""
Update the configuration and execution table plots
"""
pass
def _serialize(self):
# type: () -> None
"""
Serialize Task scheduling configuration only (no internal state)
"""
pass
def _serialize_state(self):
# type: () -> None
"""
Serialize internal state only
"""
pass
def _deserialize_state(self):
# type: () -> None
"""
Deserialize internal state only
"""
pass
def _deserialize(self):
# type: () -> None
"""
Deserialize Task scheduling configuration only
"""
pass
def _step(self):
# type: () -> bool
"""
Scheduling processing step. Returns True if a new Task was scheduled.
"""
pass
def _log(self, message, level=logging.INFO):
if self._task:
self._task.get_logger().report_text(message, level=level)
else:
print(message)
def _launch_job(self, job):
# type: (ScheduleJob) -> None
self._launch_job_task(job)
self._launch_job_function(job)
def _launch_job_task(self, job):
# type: (BaseScheduleJob) -> Optional[ClearmlJob]
# make sure this is not a function job
if job.base_function:
return None
# single-instance job: if the previous Task instance is still active, skip this launch
if job.single_instance and job.get_last_executed_task_id():
t = Task.get_task(task_id=job.get_last_executed_task_id())
if t.status in ('in_progress', 'queued'):
self._log(
'Skipping Task {} scheduling, previous Task instance {} still running'.format(
job.name, t.id
))
job.run(None)
return None
# actually run the job
task_job = ClearmlJob(
base_task_id=job.base_task_id,
parameter_override=job.task_parameters,
task_overrides=job.task_overrides,
disable_clone_task=not job.clone_task,
allow_caching=False,
target_project=job.target_project,
)
self._log('Scheduling Job {}, Task {} on queue {}.'.format(
job.name, task_job.task_id(), job.queue))
if task_job.launch(queue_name=job.queue):
# mark as run
job.run(task_job.task_id())
return task_job
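A sketch of the single-instance gate in isolation, using the statuses above. Note that the skip path still calls job.run(None), which stamps _last_executed so the schedule advances to the next slot instead of retrying on every loop tick:

previous_id = job.get_last_executed_task_id()
if job.single_instance and previous_id:
    previous = Task.get_task(task_id=previous_id)
    if previous.status in ('in_progress', 'queued'):
        job.run(None)  # skip this slot but keep the cadence
        # no new ClearmlJob is created for this slot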
def _launch_job_function(self, job):
# type: (ScheduleJob) -> Optional[Thread]
# make sure this IS a function job
if not job.base_function:
return None
# single-instance job: if the previous Thread instance is still alive, skip this launch
if job.single_instance and job.get_last_executed_task_id():
# noinspection PyBroadException
try:
a_thread = [t for t in enumerate_threads() if str(t.ident) == job.get_last_executed_task_id()]
if a_thread:
a_thread = a_thread[0]
except Exception:
a_thread = None
if a_thread and a_thread.is_alive():
self._log(
"Skipping Task '{}' scheduling, previous Thread instance '{}' still running".format(
job.name, a_thread.ident
))
job.run(None)
return None
self._log("Scheduling Job '{}', Task '{}' on background thread".format(
job.name, job.base_function))
t = Thread(target=job.base_function)
t.start()
# mark as run
job.run(t.ident)
return t
@staticmethod
def _cancel_task(task_id):
# type: (Optional[str]) -> None
if not task_id:
return
t = Task.get_task(task_id=task_id)
status = t.status
if status in ('in_progress',):
t.stopped(force=True)
elif status in ('queued',):
Task.dequeue(t)
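The added early return makes the helper safe to call straight from the timeout bookkeeping, even when no task id was recorded:

# no-ops, returning before any server call is made
BaseScheduler._cancel_task(None)
BaseScheduler._cancel_task('')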
class TaskScheduler(BaseScheduler):
"""
Task Scheduling controller.
Notice time-zone is ALWAYS UTC
@@ -245,21 +457,14 @@ class TaskScheduler(object):
:param force_create_task_project: Optional, project to use when force-creating
the Task Scheduler service task.
"""
self._last_sync = 0
self._sync_frequency_minutes = sync_frequency_minutes
super(TaskScheduler, self).__init__(
sync_frequency_minutes=sync_frequency_minutes,
force_create_task_name=force_create_task_name,
force_create_task_project=force_create_task_project
)
self._schedule_jobs = [] # List[ScheduleJob]
self._timeout_jobs = {} # Dict[datetime, str]
self._executed_jobs = [] # List[ExecutedJob]
self._thread = None
if force_create_task_name or not Task.current_task():
self._task = Task.init(
project_name=force_create_task_project or 'DevOps',
task_name=force_create_task_name or 'Scheduler',
task_type=Task.TaskTypes.service,
auto_resource_monitoring=False,
)
else:
self._task = Task.current_task()
def add_task(
self,
@@ -394,32 +599,13 @@ class TaskScheduler(object):
"""
Start the TaskScheduler loop (note that this function does not return)
"""
if Task.running_locally():
self._serialize_state()
self._serialize()
else:
self._deserialize_state()
self._deserialize()
while True:
try:
self._step()
except Exception as ex:
self._log('Warning: Exception caught during scheduling step: {}'.format(ex))
# rate control
sleep(15)
super(TaskScheduler, self).start()
def _step(self):
# type: () -> None
# type: () -> bool
"""
Scheduling processing step
"""
# sync with backend
if time() - self._last_sync > 60. * self._sync_frequency_minutes:
self._last_sync = time()
self._deserialize()
self._update_execution_plots()
# update next execution datetime
for j in self._schedule_jobs:
j.next()
@@ -435,7 +621,7 @@ class TaskScheduler(object):
seconds = 60. * self._sync_frequency_minutes
self._log('Nothing to do, sleeping for {:.2f} minutes.'.format(seconds / 60.))
sleep(seconds)
return
return False
next_time_stamp = scheduled_jobs[0].next_run() if scheduled_jobs else None
if timeout_jobs:
@@ -448,7 +634,7 @@ class TaskScheduler(object):
seconds = min(sleep_time, 60. * self._sync_frequency_minutes)
self._log('Waiting for next run, sleeping for {:.2f} minutes, until next sync.'.format(seconds / 60.))
sleep(seconds)
return
return False
# check if this is a Task timeout check
if timeout_jobs and next_time_stamp == timeout_jobs[0]:
@@ -461,7 +647,7 @@ class TaskScheduler(object):
self._log('Launching job: {}'.format(scheduled_jobs[0]))
self._launch_job(scheduled_jobs[0])
self._update_execution_plots()
return True
def start_remotely(self, queue='services'):
# type: (str) -> None
@@ -470,8 +656,7 @@ class TaskScheduler(object):
:param queue: Remote queue to run the scheduler on, default 'services' queue.
"""
self._task.execute_remotely(queue_name=queue, exit_process=True)
self.start()
super(TaskScheduler, self).start_remotely(queue=queue)
def _serialize(self):
# type: () -> None
@@ -514,10 +699,13 @@ class TaskScheduler(object):
self._task.reload()
artifact_object = self._task.artifacts.get('state')
if artifact_object is not None:
state_json_str = artifact_object.get()
state_json_str = artifact_object.get(force_download=True)
if state_json_str is not None:
state_dict = json.loads(state_json_str)
self._schedule_jobs = [ScheduleJob(**j) for j in state_dict.get('scheduled_jobs', [])]
self._schedule_jobs = self.__deserialize_scheduled_jobs(
serialized_jobs_dicts=state_dict.get('scheduled_jobs', []),
current_jobs=self._schedule_jobs
)
self._timeout_jobs = state_dict.get('timeout_jobs') or {}
self._executed_jobs = [ExecutedJob(**j) for j in state_dict.get('executed_jobs', [])]
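This is the core of the aborted-state fix: force_download=True bypasses any locally cached copy of the 'state' artifact, so a scheduler resumed after an abort reads the state last uploaded to the server instead of a stale download. A hedged sketch of the intended round-trip (artifact name and JSON layout inferred from the deserializer above; a real state dump also passes default=datetime_to_isoformat to json.dumps):

import json
from clearml import Task

task = Task.current_task()  # the scheduler's own Task

# shape of the state _serialize_state is expected to upload (left empty here)
state_json = json.dumps({'scheduled_jobs': [], 'timeout_jobs': {}, 'executed_jobs': []})
task.upload_artifact(name='state', artifact_object=state_json)

# after an abort and restart, read it back uncached
task.reload()
state = json.loads(task.artifacts.get('state').get(force_download=True))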
@@ -531,23 +719,32 @@ class TaskScheduler(object):
# noinspection PyProtectedMember
json_str = self._task._get_configuration_text(name=self._configuration_section)
try:
scheduled_jobs = [ScheduleJob(**j) for j in json.loads(json_str)]
self._schedule_jobs = self.__deserialize_scheduled_jobs(
serialized_jobs_dicts=json.loads(json_str),
current_jobs=self._schedule_jobs
)
except Exception as ex:
self._log('Failed deserializing configuration: {}'.format(ex), level='warning')
self._log('Failed deserializing configuration: {}'.format(ex), level=logging.WARN)
return
@staticmethod
def __deserialize_scheduled_jobs(serialized_jobs_dicts, current_jobs):
# type: (List[Dict], List[ScheduleJob]) -> List[ScheduleJob]
scheduled_jobs = [ScheduleJob().update(j) for j in serialized_jobs_dicts]
scheduled_jobs = {j.name: j for j in scheduled_jobs}
current_scheduled_jobs = {j.name: j for j in self._schedule_jobs}
current_scheduled_jobs = {j.name: j for j in current_jobs}
# select only valid jobs, updating their runtime state from the current in-memory jobs
self._schedule_jobs = [
new_scheduled_jobs = [
current_scheduled_jobs[name].update(j) if name in current_scheduled_jobs else j
for name, j in scheduled_jobs.items()
]
# verify all jobs
for j in self._schedule_jobs:
for j in new_scheduled_jobs:
j.verify()
return new_scheduled_jobs
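The merge-by-name semantics above, illustrated with hypothetical job names: a job that keeps its name keeps its in-memory object (and with it runtime state such as _last_executed), a removed name is dropped, and a new name is built fresh:

current = [ScheduleJob(name='nightly', base_task_id='aabb', queue='default', hour=5.0)]
incoming = [
    {'name': 'nightly', 'base_task_id': 'ccdd', 'queue': 'default', 'hour': 6.0},  # updated
    {'name': 'weekly', 'base_task_id': 'eeff', 'queue': 'default', 'day': 1},      # new
]

by_name = {j.name: j for j in current}
merged = []
for d in incoming:
    fresh = ScheduleJob().update(d)
    merged.append(by_name[fresh.name].update(fresh) if fresh.name in by_name else fresh)
# merged[0] is the same object as current[0], now with hour == 6.0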
def _serialize_schedule_into_string(self):
# type: () -> str
return json.dumps([j.to_dict() for j in self._schedule_jobs], default=datetime_to_isoformat)
@@ -619,48 +816,11 @@ class TaskScheduler(object):
table_plot=executed_table
)
def _log(self, message, level=None):
if self._task:
self._task.get_logger().report_text(message)
else:
print(message)
def _launch_job(self, job):
# type: (ScheduleJob) -> None
self._launch_job_task(job)
self._launch_job_function(job)
def _launch_job_task(self, job):
# type: (ScheduleJob) -> None
task_job = super(TaskScheduler, self)._launch_job_task(job)
# make sure this is not a function job
if job.base_function:
return
# check if this is a single instance, then we need to abort the Task
if job.single_instance and job.get_last_executed_task_id():
t = Task.get_task(task_id=job.get_last_executed_task_id())
if t.status in ('in_progress', 'queued'):
self._log(
'Skipping Task {} scheduling, previous Task instance {} still running'.format(
job.name, t.id
))
job.run(None)
return
# actually run the job
task_job = ClearmlJob(
base_task_id=job.base_task_id,
parameter_override=job.task_parameters,
task_overrides=job.task_overrides,
disable_clone_task=not job.clone_task,
allow_caching=False,
target_project=job.target_project,
)
self._log('Scheduling Job {}, Task {} on queue {}.'.format(
job.name, task_job.task_id(), job.queue))
if task_job.launch(queue_name=job.queue):
# mark as run
job.run(task_job.task_id())
if task_job:
self._executed_jobs.append(ExecutedJob(
name=job.name, task_id=task_job.task_id(), started=datetime.utcnow()))
# add timeout check
@@ -670,44 +830,9 @@ class TaskScheduler(object):
def _launch_job_function(self, job):
# type: (ScheduleJob) -> None
# make sure this IS a function job
if not job.base_function:
return
# check if this is a single instance, then we need to abort the Task
if job.single_instance and job.get_last_executed_task_id():
# noinspection PyBroadException
try:
a_thread = [t for t in enumerate_threads() if t.ident == job.get_last_executed_task_id()]
if a_thread:
a_thread = a_thread[0]
except Exception:
a_thread = None
if a_thread and a_thread.is_alive():
self._log(
"Skipping Task '{}' scheduling, previous Thread instance '{}' still running".format(
job.name, a_thread.ident
))
job.run(None)
return
self._log("Scheduling Job '{}', Task '{}' on background thread".format(
job.name, job.base_function))
t = Thread(target=job.base_function)
t.start()
# mark as run
job.run(t.ident)
self._executed_jobs.append(ExecutedJob(
name=job.name, thread_id=str(t.ident), started=datetime.utcnow()))
# execution timeout is not supported with function callbacks.
@staticmethod
def _cancel_task(task_id):
# type: (str) -> ()
t = Task.get_task(task_id=task_id)
status = t.status
if status in ('in_progress',):
t.stopped(force=True)
elif status in ('queued',):
Task.dequeue(t)
thread_job = super(TaskScheduler, self)._launch_job_function(job)
# record the execution only if a background thread was actually launched
if thread_job:
self._executed_jobs.append(ExecutedJob(
name=job.name, thread_id=str(thread_job.ident), started=datetime.utcnow()))
# execution timeout is not supported with function callbacks.
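For the function path, a hypothetical trace of what gets recorded: the background thread's ident plays the role of the task id, which is why the single-instance check in _launch_job_function compares it against the live threads:

from threading import Thread, enumerate as enumerate_threads
from time import sleep

def nightly_cleanup():  # placeholder scheduled function
    sleep(1)

t = Thread(target=nightly_cleanup)
t.start()
last_id = str(t.ident)  # job.run(t.ident) stores the ident as a string
# what the single-instance gate effectively evaluates on the next tick:
still_running = any(str(th.ident) == last_id and th.is_alive()
                    for th in enumerate_threads())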