Add Task.init() argument continue_last_task to continue a previously used Task (issue #160)

This commit is contained in:
allegroai 2020-07-14 23:36:03 +03:00
parent 876513d195
commit 925450c7ef
2 changed files with 78 additions and 38 deletions

View File

@ -238,7 +238,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# check latest version # check latest version
from ...utilities.check_updates import CheckPackageUpdates from ...utilities.check_updates import CheckPackageUpdates
latest_version = CheckPackageUpdates.check_new_package_available(only_once=True) latest_version = CheckPackageUpdates.check_new_package_available(only_once=True)
if latest_version and not SUPPRESS_UPDATE_MESSAGE_ENV_VAR.get(default=config.get('development.suppress_update_message', False)): if latest_version and not SUPPRESS_UPDATE_MESSAGE_ENV_VAR.get(
default=config.get('development.suppress_update_message', False)):
if not latest_version[1]: if not latest_version[1]:
sep = os.linesep sep = os.linesep
self.get_logger().report_text( self.get_logger().report_text(

View File

@ -175,7 +175,8 @@ class Task(_Task):
project_name=None, # type: Optional[str] project_name=None, # type: Optional[str]
task_name=None, # type: Optional[str] task_name=None, # type: Optional[str]
task_type=TaskTypes.training, # type: Task.TaskTypes task_type=TaskTypes.training, # type: Task.TaskTypes
reuse_last_task_id=True, # type: bool reuse_last_task_id=True, # type: Union[bool, str]
continue_last_task=False, # type: Union[bool, str]
output_uri=None, # type: Optional[str] output_uri=None, # type: Optional[str]
auto_connect_arg_parser=True, # type: Union[bool, Mapping[str, bool]] auto_connect_arg_parser=True, # type: Union[bool, Mapping[str, bool]]
auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]] auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]]
@ -241,18 +242,37 @@ class Task(_Task):
- ``TaskTypes.qc`` - ``TaskTypes.qc``
- ``TaskTypes.custom`` - ``TaskTypes.custom``
:param bool reuse_last_task_id: Force a new Task (experiment) with a new Task ID, but :param bool reuse_last_task_id: Force a new Task (experiment) with a previously used Task ID,
the same project and Task names. and the same project and Task name.
.. note:: .. note::
Trains creates the new Task ID using the previous Id, which is stored in the data cache folder. If the previously executed Task has artifacts or models, it will not be reused (overwritten)
and a new Task will be created.
When a Task is reused, the previous execution outputs are deleted, including console outputs and logs.
The values are: The values are:
- ``True`` - Reuse the last Task ID. (default) - ``True`` - Reuse the last Task ID. (default)
- ``False`` - Force a new Task (experiment). - ``False`` - Force a new Task (experiment).
- A string - In addition to a boolean, you can use a string to set a specific value for Task ID - A string - You can also specify a Task ID (string) to be reused,
(instead of the system generated UUID). instead of the cached ID based on the project/name combination.
:param bool continue_last_task: Continue the execution of a previously executed Task (experiment)
.. note::
When continuing the executing of a previously executed Task,
all previous artifacts / models/ logs are intact.
New logs will continue iteration/step based on the previous-execution maximum iteration value.
For example:
The last train/loss scalar reported was iteration 100, the next report will be iteration 101.
The values are:
- ``True`` - Continue the the last Task ID.
specified explicitly by reuse_last_task_id or implicitly with the same logic as reuse_last_task_id
- ``False`` - Overwrite the execution of previous Task (default).
- A string - You can also specify a Task ID (string) to be continued.
This is equivalent to `continue_last_task=True` and `reuse_last_task_id=a_task_id_string`.
:param str output_uri: The default location for output models and other artifacts. In the default location, :param str output_uri: The default location for output models and other artifacts. In the default location,
Trains creates a subfolder for the output. The subfolder structure is the following: Trains creates a subfolder for the output. The subfolder structure is the following:
@ -416,12 +436,14 @@ class Task(_Task):
# if this is the main process, create the task # if this is the main process, create the task
if not is_sub_process_task_id: if not is_sub_process_task_id:
task = cls._create_dev_task( task = cls._create_dev_task(
project_name, default_project_name=project_name,
task_name, default_task_name=task_name,
task_type, default_task_type=task_type,
reuse_last_task_id, reuse_last_task_id=reuse_last_task_id,
detect_repo=False if (isinstance(auto_connect_frameworks, dict) and continue_last_task=continue_last_task,
not auto_connect_frameworks.get('detect_repository', True)) else True detect_repo=False if (
isinstance(auto_connect_frameworks, dict) and
not auto_connect_frameworks.get('detect_repository', True)) else True
) )
# set defaults # set defaults
if output_uri: if output_uri:
@ -1667,7 +1689,8 @@ class Task(_Task):
@classmethod @classmethod
def _create_dev_task( def _create_dev_task(
cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id, detect_repo=True cls, default_project_name, default_task_name, default_task_type,
reuse_last_task_id, continue_last_task=False, detect_repo=True,
): ):
if not default_project_name or not default_task_name: if not default_project_name or not default_task_name:
# get project name and task name from repository name and entry_point # get project name and task name from repository name and entry_point
@ -1686,8 +1709,13 @@ class Task(_Task):
except Exception: except Exception:
pass pass
# conform reuse_last_task_id and continue_last_task
if continue_last_task and isinstance(continue_last_task, str):
reuse_last_task_id = continue_last_task
continue_last_task = True
# if we force no task reuse from os environment # if we force no task reuse from os environment
if DEV_TASK_NO_REUSE.get() or not reuse_last_task_id: if DEV_TASK_NO_REUSE.get() or not reuse_last_task_id or isinstance(reuse_last_task_id, str):
default_task = None default_task = None
else: else:
# if we have a previous session to use, get the task id from it # if we have a previous session to use, get the task id from it
@ -1717,29 +1745,37 @@ class Task(_Task):
task_id=default_task_id, task_id=default_task_id,
log_to_backend=True, log_to_backend=True,
) )
task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
task_artifacts = task.data.execution.artifacts \ # instead of resting the previously used task we are continuing the training with it.
if hasattr(task.data.execution, 'artifacts') else None if task and continue_last_task:
if ((str(task._status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed))) task.reload()
or task.output_model_id or (ARCHIVED_TAG in task_tags) task.mark_started(force=True)
or (cls._development_tag not in task_tags) task.set_initial_iteration(task.get_last_iteration()+1)
or task_artifacts):
# If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
# If the task is archived, or already has an output model,
# we shouldn't use it in development mode either
default_task_id = None
task = None
else: else:
with task._edit_lock: task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
# from now on, there is no need to reload, we just clear stuff, task_artifacts = task.data.execution.artifacts \
# this flag will be cleared off once we actually refresh at the end of the function if hasattr(task.data.execution, 'artifacts') else None
task._reload_skip_flag = True if ((str(task._status) in (
# reset the task, so we can update it str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
task.reset(set_started_on_success=False, force=False) or task.output_model_id or (ARCHIVED_TAG in task_tags)
# clear the heaviest stuff first or (cls._development_tag not in task_tags)
task._clear_task( or task_artifacts):
system_tags=[cls._development_tag], # If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
comment=make_message('Auto-generated at %(time)s by %(user)s@%(host)s')) # If the task is archived, or already has an output model,
# we shouldn't use it in development mode either
default_task_id = None
task = None
else:
with task._edit_lock:
# from now on, there is no need to reload, we just clear stuff,
# this flag will be cleared off once we actually refresh at the end of the function
task._reload_skip_flag = True
# reset the task, so we can update it
task.reset(set_started_on_success=False, force=False)
# clear the heaviest stuff first
task._clear_task(
system_tags=[cls._development_tag],
comment=make_message('Auto-generated at %(time)s by %(user)s@%(host)s'))
except (Exception, ValueError): except (Exception, ValueError):
# we failed reusing task, create a new one # we failed reusing task, create a new one
@ -1779,8 +1815,11 @@ class Task(_Task):
if closed_old_task: if closed_old_task:
logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id'))) logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
# print warning, reusing/creating a task # print warning, reusing/creating a task
if default_task_id: if default_task_id and not continue_last_task:
logger.report_text('TRAINS Task: overwriting (reusing) task id=%s' % task.id) logger.report_text('TRAINS Task: overwriting (reusing) task id=%s' % task.id)
if default_task_id and continue_last_task:
logger.report_text('TRAINS Task: continuing previous task id=%s '
'Notice this run will not be reproducible!' % task.id)
else: else:
logger.report_text('TRAINS Task: created new task id=%s' % task.id) logger.report_text('TRAINS Task: created new task id=%s' % task.id)