Fix error when using TaskScheduler with 'limit_execution_time' (#648)

This commit is contained in:
allegroai 2022-12-22 21:55:27 +02:00
parent dd79bd6197
commit 7d6eff4858

View File

@ -669,8 +669,9 @@ class TaskScheduler(BaseScheduler):
[j for j in self._schedule_jobs if j.next_run() is not None], [j for j in self._schedule_jobs if j.next_run() is not None],
key=lambda x: x.next_run() key=lambda x: x.next_run()
) )
timeout_jobs = sorted(list(self._timeout_jobs.values())) # sort by key
if not scheduled_jobs and not timeout_jobs: timeout_job_datetime = min(self._timeout_jobs, key=self._timeout_jobs.get) if self._timeout_jobs else None
if not scheduled_jobs and timeout_job_datetime is None:
# sleep and retry # sleep and retry
seconds = 60. * self._sync_frequency_minutes seconds = 60. * self._sync_frequency_minutes
self._log('Nothing to do, sleeping for {:.2f} minutes.'.format(seconds / 60.)) self._log('Nothing to do, sleeping for {:.2f} minutes.'.format(seconds / 60.))
@ -678,9 +679,10 @@ class TaskScheduler(BaseScheduler):
return False return False
next_time_stamp = scheduled_jobs[0].next_run() if scheduled_jobs else None next_time_stamp = scheduled_jobs[0].next_run() if scheduled_jobs else None
if timeout_jobs: if timeout_job_datetime is not None:
next_time_stamp = min(timeout_jobs[0], next_time_stamp) \ next_time_stamp = (
if next_time_stamp else timeout_jobs[0] min(next_time_stamp, timeout_job_datetime) if next_time_stamp else timeout_job_datetime
)
sleep_time = (next_time_stamp - datetime.utcnow()).total_seconds() sleep_time = (next_time_stamp - datetime.utcnow()).total_seconds()
if sleep_time > 0: if sleep_time > 0:
@ -691,12 +693,11 @@ class TaskScheduler(BaseScheduler):
return False return False
# check if this is a Task timeout check # check if this is a Task timeout check
if timeout_jobs and next_time_stamp == timeout_jobs[0]: if timeout_job_datetime is not None and next_time_stamp == timeout_job_datetime:
self._log('Aborting timeout job: {}'.format(timeout_jobs[0])) task_id = self._timeout_jobs[timeout_job_datetime]
# mark aborted self._log('Aborting job due to timeout: {}'.format(task_id))
task_id = [k for k, v in self._timeout_jobs.items() if v == timeout_jobs[0]][0]
self._cancel_task(task_id=task_id) self._cancel_task(task_id=task_id)
self._timeout_jobs.pop(task_id, None) self._timeout_jobs.pop(timeout_job_datetime, None)
else: else:
self._log('Launching job: {}'.format(scheduled_jobs[0])) self._log('Launching job: {}'.format(scheduled_jobs[0]))
self._launch_job(scheduled_jobs[0]) self._launch_job(scheduled_jobs[0])
@ -733,16 +734,12 @@ class TaskScheduler(BaseScheduler):
json_str = json.dumps( json_str = json.dumps(
dict( dict(
scheduled_jobs=[j.to_dict(full=True) for j in self._schedule_jobs], scheduled_jobs=[j.to_dict(full=True) for j in self._schedule_jobs],
timeout_jobs=self._timeout_jobs, timeout_jobs={datetime_to_isoformat(k): v for k, v in self._timeout_jobs.items()},
executed_jobs=[j.to_dict(full=True) for j in self._executed_jobs], executed_jobs=[j.to_dict(full=True) for j in self._executed_jobs]
), ),
default=datetime_to_isoformat default=datetime_to_isoformat
) )
self._task.upload_artifact( self._task.upload_artifact(name="state", artifact_object=json_str, preview="scheduler internal state")
name='state',
artifact_object=json_str,
preview='scheduler internal state'
)
def _deserialize_state(self): def _deserialize_state(self):
# type: () -> None # type: () -> None
@ -760,7 +757,7 @@ class TaskScheduler(BaseScheduler):
serialized_jobs_dicts=state_dict.get('scheduled_jobs', []), serialized_jobs_dicts=state_dict.get('scheduled_jobs', []),
current_jobs=self._schedule_jobs current_jobs=self._schedule_jobs
) )
self._timeout_jobs = state_dict.get('timeout_jobs') or {} self._timeout_jobs = {datetime_from_isoformat(k): v for k, v in (state_dict.get('timeout_jobs') or {}).items()}
self._executed_jobs = [ExecutedJob(**j) for j in state_dict.get('executed_jobs', [])] self._executed_jobs = [ExecutedJob(**j) for j in state_dict.get('executed_jobs', [])]
def _deserialize(self): def _deserialize(self):