Allow overriding the initial iteration offset using an environment variable (CLEARML_SET_ITERATION_OFFSET) or Task.init(continue_last_task=<offset>) (issue #496)

This commit is contained in:
allegroai 2021-11-30 21:15:03 +02:00
parent 297f33703f
commit fc0305728c

View File

@ -27,7 +27,7 @@ from pathlib2 import Path
from .backend_config.defs import get_active_config_file, get_config_file from .backend_config.defs import get_active_config_file, get_config_file
from .backend_api.services import tasks, projects from .backend_api.services import tasks, projects
from .backend_api.session.session import ( from .backend_api.session.session import (
Session, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_HOST, ENV_WEB_HOST, ENV_FILES_HOST) Session, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_HOST, ENV_WEB_HOST, ENV_FILES_HOST, )
from .backend_interface.metrics import Metrics from .backend_interface.metrics import Metrics
from .backend_interface.model import Model as BackendModel from .backend_interface.model import Model as BackendModel
from .backend_interface.task import Task as _Task from .backend_interface.task import Task as _Task
@ -51,7 +51,7 @@ from .binding.hydra_bind import PatchHydra
from .binding.click_bind import PatchClick from .binding.click_bind import PatchClick
from .config import ( from .config import (
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI, config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI,
deferred_config, ) deferred_config, TASK_SET_ITERATION_OFFSET, )
from .config import running_remotely, get_remote_task_id from .config import running_remotely, get_remote_task_id
from .config.cache import SessionCache from .config.cache import SessionCache
from .debugging.log import LoggerRoot from .debugging.log import LoggerRoot
@ -60,11 +60,13 @@ from .logger import Logger
from .model import Model, InputModel, OutputModel from .model import Model, InputModel, OutputModel
from .task_parameters import TaskParameters from .task_parameters import TaskParameters
from .utilities.config import verify_basic_value from .utilities.config import verify_basic_value
from .binding.args import argparser_parseargs_called, get_argparser_last_args, \ from .binding.args import (
argparser_update_currenttask argparser_parseargs_called, get_argparser_last_args,
argparser_update_currenttask, )
from .utilities.dicts import ReadOnlyDict, merge_dicts from .utilities.dicts import ReadOnlyDict, merge_dicts
from .utilities.proxy_object import ProxyDictPreWrite, ProxyDictPostWrite, flatten_dictionary, \ from .utilities.proxy_object import (
nested_from_flat_dictionary, naive_nested_from_flat_dictionary ProxyDictPreWrite, ProxyDictPostWrite, flatten_dictionary,
nested_from_flat_dictionary, naive_nested_from_flat_dictionary, )
from .utilities.resource_monitor import ResourceMonitor from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic from .utilities.seed import make_deterministic
from .utilities.lowlevel.threads import get_current_thread_id from .utilities.lowlevel.threads import get_current_thread_id
@ -72,7 +74,6 @@ from .utilities.process.mp import BackgroundMonitor, leave_process
# noinspection PyProtectedMember # noinspection PyProtectedMember
from .backend_interface.task.args import _Arguments from .backend_interface.task.args import _Arguments
if TYPE_CHECKING: if TYPE_CHECKING:
import pandas import pandas
import numpy import numpy
@ -193,18 +194,18 @@ class Task(_Task):
@classmethod @classmethod
def init( def init(
cls, cls,
project_name=None, # type: Optional[str] project_name=None, # type: Optional[str]
task_name=None, # type: Optional[str] task_name=None, # type: Optional[str]
task_type=TaskTypes.training, # type: Task.TaskTypes task_type=TaskTypes.training, # type: Task.TaskTypes
tags=None, # type: Optional[Sequence[str]] tags=None, # type: Optional[Sequence[str]]
reuse_last_task_id=True, # type: Union[bool, str] reuse_last_task_id=True, # type: Union[bool, str]
continue_last_task=False, # type: Union[bool, str] continue_last_task=False, # type: Union[bool, str, int]
output_uri=None, # type: Optional[Union[str, bool]] output_uri=None, # type: Optional[Union[str, bool]]
auto_connect_arg_parser=True, # type: Union[bool, Mapping[str, bool]] auto_connect_arg_parser=True, # type: Union[bool, Mapping[str, bool]]
auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]] auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]]
auto_resource_monitoring=True, # type: bool auto_resource_monitoring=True, # type: bool
auto_connect_streams=True, # type: Union[bool, Mapping[str, bool]] auto_connect_streams=True, # type: Union[bool, Mapping[str, bool]]
): ):
# type: (...) -> "Task" # type: (...) -> "Task"
""" """
@ -297,6 +298,9 @@ class Task(_Task):
- ``False`` - Overwrite the execution of previous Task (default). - ``False`` - Overwrite the execution of previous Task (default).
- A string - You can also specify a Task ID (string) to be continued. - A string - You can also specify a Task ID (string) to be continued.
This is equivalent to `continue_last_task=True` and `reuse_last_task_id=a_task_id_string`. This is equivalent to `continue_last_task=True` and `reuse_last_task_id=a_task_id_string`.
- An integer - Specify the initial iteration offset (overrides the automatic last_iteration_offset).
Pass 0 to disable the automatic last_iteration_offset, or specify a different initial offset.
You can specify a Task ID to be used with `reuse_last_task_id='task_id_here'`
:param str output_uri: The default location for output models and other artifacts. :param str output_uri: The default location for output models and other artifacts.
If True is passed, the default files_server will be used for model storage. If True is passed, the default files_server will be used for model storage.
@ -641,23 +645,23 @@ class Task(_Task):
@classmethod @classmethod
def create( def create(
cls, cls,
project_name=None, # type: Optional[str] project_name=None, # type: Optional[str]
task_name=None, # type: Optional[str] task_name=None, # type: Optional[str]
task_type=None, # type: Optional[str] task_type=None, # type: Optional[str]
repo=None, # type: Optional[str] repo=None, # type: Optional[str]
branch=None, # type: Optional[str] branch=None, # type: Optional[str]
commit=None, # type: Optional[str] commit=None, # type: Optional[str]
script=None, # type: Optional[str] script=None, # type: Optional[str]
working_directory=None, # type: Optional[str] working_directory=None, # type: Optional[str]
packages=None, # type: Optional[Union[bool, Sequence[str]]] packages=None, # type: Optional[Union[bool, Sequence[str]]]
requirements_file=None, # type: Optional[Union[str, Path]] requirements_file=None, # type: Optional[Union[str, Path]]
docker=None, # type: Optional[str] docker=None, # type: Optional[str]
docker_args=None, # type: Optional[str] docker_args=None, # type: Optional[str]
docker_bash_setup_script=None, # type: Optional[str] docker_bash_setup_script=None, # type: Optional[str]
argparse_args=None, # type: Optional[Sequence[Tuple[str, str]]] argparse_args=None, # type: Optional[Sequence[Tuple[str, str]]]
base_task_id=None, # type: Optional[str] base_task_id=None, # type: Optional[str]
add_task_init_call=True, # type: bool add_task_init_call=True, # type: bool
): ):
# type: (...) -> Task # type: (...) -> Task
""" """
@ -1619,14 +1623,14 @@ class Task(_Task):
return self._artifacts_manager.registered_artifacts return self._artifacts_manager.registered_artifacts
def upload_artifact( def upload_artifact(
self, self,
name, # type: str name, # type: str
artifact_object, # type: Union[str, Mapping, pandas.DataFrame, numpy.ndarray, Image.Image, Any] artifact_object, # type: Union[str, Mapping, pandas.DataFrame, numpy.ndarray, Image.Image, Any]
metadata=None, # type: Optional[Mapping] metadata=None, # type: Optional[Mapping]
delete_after_upload=False, # type: bool delete_after_upload=False, # type: bool
auto_pickle=True, # type: bool auto_pickle=True, # type: bool
preview=None, # type: Any preview=None, # type: Any
wait_on_upload=False, # type: bool wait_on_upload=False, # type: bool
): ):
# type: (...) -> bool # type: (...) -> bool
""" """
@ -2138,7 +2142,7 @@ class Task(_Task):
"Task enqueuing itself must exit the process afterwards.") "Task enqueuing itself must exit the process afterwards.")
# make sure we analyze the process # make sure we analyze the process
if self.status in (Task.TaskStatusEnum.in_progress, ): if self.status in (Task.TaskStatusEnum.in_progress,):
if clone: if clone:
# wait for repository detection (5 minutes should be reasonable time to detect all packages) # wait for repository detection (5 minutes should be reasonable time to detect all packages)
self.flush(wait_for_uploads=True) self.flush(wait_for_uploads=True)
@ -2638,8 +2642,8 @@ class Task(_Task):
@classmethod @classmethod
def _create_dev_task( def _create_dev_task(
cls, default_project_name, default_task_name, default_task_type, tags, cls, default_project_name, default_task_name, default_task_type, tags,
reuse_last_task_id, continue_last_task=False, detect_repo=True, auto_connect_streams=True reuse_last_task_id, continue_last_task=False, detect_repo=True, auto_connect_streams=True
): ):
if not default_project_name or not default_task_name: if not default_project_name or not default_task_name:
# get project name and task name from repository name and entry_point # get project name and task name from repository name and entry_point
@ -2662,6 +2666,12 @@ class Task(_Task):
if continue_last_task and isinstance(continue_last_task, str): if continue_last_task and isinstance(continue_last_task, str):
reuse_last_task_id = continue_last_task reuse_last_task_id = continue_last_task
continue_last_task = True continue_last_task = True
elif isinstance(continue_last_task, int) and continue_last_task is not True:
# allow initial offset environment override
continue_last_task = continue_last_task
if TASK_SET_ITERATION_OFFSET.get() is not None:
continue_last_task = TASK_SET_ITERATION_OFFSET.get()
# if we force no task reuse from os environment # if we force no task reuse from os environment
if DEV_TASK_NO_REUSE.get() or not reuse_last_task_id or isinstance(reuse_last_task_id, str): if DEV_TASK_NO_REUSE.get() or not reuse_last_task_id or isinstance(reuse_last_task_id, str):
@ -2696,10 +2706,15 @@ class Task(_Task):
) )
# instead of resetting the previously used task we are continuing the training with it. # instead of resetting the previously used task we are continuing the training with it.
if task and continue_last_task: if task and (continue_last_task or isinstance(continue_last_task, int)):
task.reload() task.reload()
task.mark_started(force=True) task.mark_started(force=True)
task.set_initial_iteration(task.get_last_iteration()+1) # allow to disable the automatic last_iteration_offset
if continue_last_task is True:
task.set_initial_iteration(task.get_last_iteration() + 1)
else:
task.set_initial_iteration(continue_last_task)
else: else:
task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
task_artifacts = task.data.execution.artifacts \ task_artifacts = task.data.execution.artifacts \
@ -2782,10 +2797,10 @@ class Task(_Task):
import traceback import traceback
stack = traceback.extract_stack(limit=10) stack = traceback.extract_stack(limit=10)
# NOTICE WE ARE ALWAYS 3 down from caller in stack! # NOTICE WE ARE ALWAYS 3 down from caller in stack!
for i in range(len(stack)-1, 0, -1): for i in range(len(stack) - 1, 0, -1):
# look for the Task.init call, then the one above it is the callee module # look for the Task.init call, then the one above it is the callee module
if stack[i].name == 'init': if stack[i].name == 'init':
task._calling_filename = os.path.abspath(stack[i-1].filename) task._calling_filename = os.path.abspath(stack[i - 1].filename)
break break
except Exception: except Exception:
pass pass
@ -2818,7 +2833,7 @@ class Task(_Task):
if not self._logger: if not self._logger:
# do not recreate logger after task was closed/quit # do not recreate logger after task was closed/quit
if self._at_exit_called and self._at_exit_called in (True, get_current_thread_id(), ): if self._at_exit_called and self._at_exit_called in (True, get_current_thread_id(),):
raise ValueError("Cannot use Task Logger after task was closed") raise ValueError("Cannot use Task Logger after task was closed")
# Get a logger object # Get a logger object
self._logger = Logger( self._logger = Logger(
@ -2955,7 +2970,7 @@ class Task(_Task):
attr_class.update_from_dict(parameters) attr_class.update_from_dict(parameters)
else: else:
attr_class.update_from_dict( attr_class.update_from_dict(
dict((k[len(name)+1:], v) for k, v in parameters.items() if k.startswith('{}/'.format(name)))) dict((k[len(name) + 1:], v) for k, v in parameters.items() if k.startswith('{}/'.format(name))))
else: else:
self.set_parameters(attr_class.to_dict(), __parameters_prefix=name) self.set_parameters(attr_class.to_dict(), __parameters_prefix=name)
return attr_class return attr_class
@ -3153,7 +3168,7 @@ class Task(_Task):
if (not running_remotely() or DEBUG_SIMULATE_REMOTE_TASK.get()) \ if (not running_remotely() or DEBUG_SIMULATE_REMOTE_TASK.get()) \
and self.is_main_task() and not is_sub_process: and self.is_main_task() and not is_sub_process:
# check if we crashed, or the signal is not interrupt (manual break) # check if we crashed, or the signal is not interrupt (manual break)
task_status = ('stopped', ) task_status = ('stopped',)
if self.__exit_hook: if self.__exit_hook:
is_exception = self.__exit_hook.exception is_exception = self.__exit_hook.exception
# check if we are running inside a debugger # check if we are running inside a debugger
@ -3182,9 +3197,9 @@ class Task(_Task):
wait_for_uploads = (self.__exit_hook.remote_user_aborted or self.__exit_hook.signal is None) wait_for_uploads = (self.__exit_hook.remote_user_aborted or self.__exit_hook.signal is None)
if not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal is None and \ if not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal is None and \
not is_exception: not is_exception:
task_status = ('completed', ) task_status = ('completed',)
else: else:
task_status = ('stopped', ) task_status = ('stopped',)
# user aborted. do not bother flushing the stdout logs # user aborted. do not bother flushing the stdout logs
wait_for_std_log = self.__exit_hook.signal is not None wait_for_std_log = self.__exit_hook.signal is not None