From 925450c7effc1c5b8a097fbea8620c3b4689a30d Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 14 Jul 2020 23:36:03 +0300 Subject: [PATCH] Add Task.init() argument continue_last_task to continue a previously used Task (issue #160) --- trains/backend_interface/task/task.py | 3 +- trains/task.py | 113 +++++++++++++++++--------- 2 files changed, 78 insertions(+), 38 deletions(-) diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index 34908964..74aff7ae 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -238,7 +238,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): # check latest version from ...utilities.check_updates import CheckPackageUpdates latest_version = CheckPackageUpdates.check_new_package_available(only_once=True) - if latest_version and not SUPPRESS_UPDATE_MESSAGE_ENV_VAR.get(default=config.get('development.suppress_update_message', False)): + if latest_version and not SUPPRESS_UPDATE_MESSAGE_ENV_VAR.get( + default=config.get('development.suppress_update_message', False)): if not latest_version[1]: sep = os.linesep self.get_logger().report_text( diff --git a/trains/task.py b/trains/task.py index 817531fd..9ec1d2de 100644 --- a/trains/task.py +++ b/trains/task.py @@ -175,7 +175,8 @@ class Task(_Task): project_name=None, # type: Optional[str] task_name=None, # type: Optional[str] task_type=TaskTypes.training, # type: Task.TaskTypes - reuse_last_task_id=True, # type: bool + reuse_last_task_id=True, # type: Union[bool, str] + continue_last_task=False, # type: Union[bool, str] output_uri=None, # type: Optional[str] auto_connect_arg_parser=True, # type: Union[bool, Mapping[str, bool]] auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]] @@ -241,18 +242,37 @@ class Task(_Task): - ``TaskTypes.qc`` - ``TaskTypes.custom`` - :param bool reuse_last_task_id: Force a new Task (experiment) with a new Task ID, but - the same project and Task names. 
+ :param bool reuse_last_task_id: Force a new Task (experiment) with a previously used Task ID, + and the same project and Task name. .. note:: - Trains creates the new Task ID using the previous Id, which is stored in the data cache folder. + If the previously executed Task has artifacts or models, it will not be reused (overwritten) + and a new Task will be created. + When a Task is reused, the previous execution outputs are deleted, including console outputs and logs. The values are: - ``True`` - Reuse the last Task ID. (default) - ``False`` - Force a new Task (experiment). - - A string - In addition to a boolean, you can use a string to set a specific value for Task ID - (instead of the system generated UUID). + - A string - You can also specify a Task ID (string) to be reused, + instead of the cached ID based on the project/name combination. + + :param bool continue_last_task: Continue the execution of a previously executed Task (experiment) + + .. note:: + When continuing the execution of a previously executed Task, + all previous artifacts / models / logs are left intact. + New logs will continue iteration/step based on the previous-execution maximum iteration value. + For example: + The last train/loss scalar reported was iteration 100, the next report will be iteration 101. + + The values are: + + - ``True`` - Continue the last Task ID, + specified explicitly by reuse_last_task_id or implicitly with the same logic as reuse_last_task_id + - ``False`` - Overwrite the execution of previous Task (default). + - A string - You can also specify a Task ID (string) to be continued. + This is equivalent to `continue_last_task=True` and `reuse_last_task_id=a_task_id_string`. :param str output_uri: The default location for output models and other artifacts. In the default location, Trains creates a subfolder for the output. 
The subfolder structure is the following: @@ -416,12 +436,14 @@ class Task(_Task): # if this is the main process, create the task if not is_sub_process_task_id: task = cls._create_dev_task( - project_name, - task_name, - task_type, - reuse_last_task_id, - detect_repo=False if (isinstance(auto_connect_frameworks, dict) and - not auto_connect_frameworks.get('detect_repository', True)) else True + default_project_name=project_name, + default_task_name=task_name, + default_task_type=task_type, + reuse_last_task_id=reuse_last_task_id, + continue_last_task=continue_last_task, + detect_repo=False if ( + isinstance(auto_connect_frameworks, dict) and + not auto_connect_frameworks.get('detect_repository', True)) else True ) # set defaults if output_uri: @@ -1667,7 +1689,8 @@ class Task(_Task): @classmethod def _create_dev_task( - cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id, detect_repo=True + cls, default_project_name, default_task_name, default_task_type, + reuse_last_task_id, continue_last_task=False, detect_repo=True, ): if not default_project_name or not default_task_name: # get project name and task name from repository name and entry_point @@ -1686,8 +1709,13 @@ class Task(_Task): except Exception: pass + # conform reuse_last_task_id and continue_last_task + if continue_last_task and isinstance(continue_last_task, str): + reuse_last_task_id = continue_last_task + continue_last_task = True + # if we force no task reuse from os environment - if DEV_TASK_NO_REUSE.get() or not reuse_last_task_id: + if DEV_TASK_NO_REUSE.get() or not reuse_last_task_id or isinstance(reuse_last_task_id, str): default_task = None else: # if we have a previous session to use, get the task id from it @@ -1717,29 +1745,37 @@ class Task(_Task): task_id=default_task_id, log_to_backend=True, ) - task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags - task_artifacts = task.data.execution.artifacts \ - if 
hasattr(task.data.execution, 'artifacts') else None - if ((str(task._status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed))) - or task.output_model_id or (ARCHIVED_TAG in task_tags) - or (cls._development_tag not in task_tags) - or task_artifacts): - # If the task is published or closed, we shouldn't reset it so we can't use it in dev mode - # If the task is archived, or already has an output model, - # we shouldn't use it in development mode either - default_task_id = None - task = None + + # instead of resetting the previously used task we are continuing the training with it. + if task and continue_last_task: + task.reload() + task.mark_started(force=True) + task.set_initial_iteration(task.get_last_iteration()+1) else: - with task._edit_lock: - # from now on, there is no need to reload, we just clear stuff, - # this flag will be cleared off once we actually refresh at the end of the function - task._reload_skip_flag = True - # reset the task, so we can update it - task.reset(set_started_on_success=False, force=False) - # clear the heaviest stuff first - task._clear_task( - system_tags=[cls._development_tag], - comment=make_message('Auto-generated at %(time)s by %(user)s@%(host)s')) + task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags + task_artifacts = task.data.execution.artifacts \ + if hasattr(task.data.execution, 'artifacts') else None + if ((str(task._status) in ( + str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed))) + or task.output_model_id or (ARCHIVED_TAG in task_tags) + or (cls._development_tag not in task_tags) + or task_artifacts): + # If the task is published or closed, we shouldn't reset it so we can't use it in dev mode + # If the task is archived, or already has an output model, + # we shouldn't use it in development mode either + default_task_id = None + task = None + else: + with task._edit_lock: + # from now on, there is no need to reload, we just clear stuff, 
+ # this flag will be cleared off once we actually refresh at the end of the function + task._reload_skip_flag = True + # reset the task, so we can update it + task.reset(set_started_on_success=False, force=False) + # clear the heaviest stuff first + task._clear_task( + system_tags=[cls._development_tag], + comment=make_message('Auto-generated at %(time)s by %(user)s@%(host)s')) except (Exception, ValueError): # we failed reusing task, create a new one @@ -1779,8 +1815,11 @@ class Task(_Task): if closed_old_task: logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id'))) # print warning, reusing/creating a task - if default_task_id: + if default_task_id and not continue_last_task: logger.report_text('TRAINS Task: overwriting (reusing) task id=%s' % task.id) + if default_task_id and continue_last_task: + logger.report_text('TRAINS Task: continuing previous task id=%s ' + 'Notice this run will not be reproducible!' % task.id) else: logger.report_text('TRAINS Task: created new task id=%s' % task.id)