diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index be8dd670..d886d212 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -124,6 +124,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): self._app_server = None self._files_server = None self._initial_iteration_offset = 0 + self._reload_skip_flag = False if not task_id: # generate a new task @@ -456,6 +457,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): def _reload(self): """ Reload the task object from the backend """ with self._edit_lock: + if self._reload_skip_flag and self._data: + return self._data res = self.send(tasks.GetByIdRequest(task=self.id)) return res.response.task @@ -464,6 +467,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): self.send(tasks.ResetRequest(task=self.id)) if set_started_on_success: self.started() + elif self._data: + # if not started, make sure the current cached state is synced + self._data.status = self.TaskStatusEnum.created + self.reload() def started(self, ignore_errors=True): @@ -1055,6 +1062,29 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): except Exception: return None + def _clear_task(self, system_tags=None, comment=None): + self._data.script = tasks.Script( + binary='', repository='', tag='', branch='', version_num='', entry_point='', + working_dir='', requirements={}, diff='', + ) + self._data.execution = tasks.Execution( + artifacts=[], dataviews=[], model='', model_desc={}, model_labels={}, parameters={}, docker_cmd='') + self._data.comment = str(comment) + + self._storage_uri = None + self._data.output.destination = self._storage_uri + + self._update_requirements('') + + if Session.check_min_api_version('2.3'): + self._set_task_property("system_tags", system_tags) + self._edit(system_tags=self._data.system_tags, comment=self._data.comment, + script=self._data.script, execution=self._data.execution, output_dest='') + else: + self._set_task_property("tags", system_tags) + self._edit(tags=self._data.tags, comment=self._data.comment, + script=self._data.script, execution=self._data.execution, output_dest=None) + @classmethod def _get_api_server(cls): return Session.get_api_server_host() @@ -1067,8 +1097,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): def _edit(self, **kwargs): with self._edit_lock: # Since we ae using forced update, make sure he task status is valid - if not self._data or (str(self.data.status) not in (str(tasks.TaskStatusEnum.created), - str(tasks.TaskStatusEnum.in_progress))): + status = self._data.status if self._data and self._reload_skip_flag else self.data.status + if status not in (tasks.TaskStatusEnum.created, tasks.TaskStatusEnum.in_progress): # the exception being name/comment that we can always change. if kwargs and all(k in ('name', 'comment') for k in kwargs.keys()): pass diff --git a/trains/task.py b/trains/task.py index 69004815..1a0f6294 100644 --- a/trains/task.py +++ b/trains/task.py @@ -1489,7 +1489,7 @@ class Task(_Task): task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags task_artifacts = task.data.execution.artifacts \ if hasattr(task.data.execution, 'artifacts') else None - if ((str(task.status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed))) + if ((str(task._status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed))) or task.output_model_id or (ARCHIVED_TAG in task_tags) or (cls._development_tag not in task_tags) or task_artifacts): @@ -1499,20 +1499,16 @@ class Task(_Task): default_task_id = None task = None else: - # reset the task, so we can update it - task.reset(set_started_on_success=False, force=False) - # set development tags - task.set_system_tags([cls._development_tag]) - # clear task parameters, they are not cleared by the Task reset - task.set_parameters({}, __update=False) - # clear the comment, it is not cleared on reset - task.set_comment(make_message('Auto-generated at %(time)s by %(user)s@%(host)s')) - # clear the input model (and task model design/labels) - task.set_input_model(model_id='', update_task_design=False, update_task_labels=False) - task._set_model_config(config_text='') - task.set_model_label_enumeration({}) - task.set_artifacts([]) - task._set_storage_uri(None) + with task._edit_lock: + # from now on, there is no need to reload, we just clear stuff, + # this flag will be cleared off once we actually refresh at the end of the function + task._reload_skip_flag = True + # reset the task, so we can update it + task.reset(set_started_on_success=False, force=False) + # clear the heaviest stuff first + task._clear_task( + system_tags=[cls._development_tag], + comment=make_message('Auto-generated at %(time)s by %(user)s@%(host)s')) except (Exception, ValueError): # we failed reusing task, create a new one @@ -1527,6 +1523,8 @@ class Task(_Task): task_type=default_task_type, log_to_backend=True, ) + # no need to reload yet, we clear this before the end of the function + task._reload_skip_flag = True if in_dev_mode: # update this session, for later use @@ -1536,6 +1534,10 @@ class Task(_Task): # mark the task as started task.started() + # reload, making sure we are synced + task._reload_skip_flag = False + task.reload() + # force update of base logger to this current task (this is the main logger task) task._setup_log(replace_existing=True) logger = task.get_logger() @@ -1556,12 +1558,11 @@ class Task(_Task): else: task._update_repository() - # make sure everything is in sync - task.reload() # make sure we see something in the UI thread = threading.Thread(target=LoggerRoot.flush) thread.daemon = True thread.start() + return task def _get_logger(self, flush_period=NotSet):