Add alias for task type, train/inference

This commit is contained in:
allegroai 2019-10-04 01:31:57 +03:00
parent 4a42345561
commit 6aa22e449e

View File

@ -136,7 +136,8 @@ class Task(_Task):
:param project_name: project to create the task in (if project doesn't exist, it will be created) :param project_name: project to create the task in (if project doesn't exist, it will be created)
:param task_name: task name to be created (in development mode, not when running remotely) :param task_name: task name to be created (in development mode, not when running remotely)
:param task_type: task type to be created (in development mode, not when running remotely) :param task_type: task type to be created, Default: TaskTypes.training
Options are: 'testing', 'training' or 'train', 'inference'
:param reuse_last_task_id: start with the previously used task id (stored in the data cache folder). :param reuse_last_task_id: start with the previously used task id (stored in the data cache folder).
if False every time we call the function we create a new task with the same name if False every time we call the function we create a new task with the same name
Notice! The reused task will be reset. (when running remotely, the usual behaviour applies) Notice! The reused task will be reset. (when running remotely, the usual behaviour applies)
@ -181,6 +182,8 @@ class Task(_Task):
# if this is a subprocess, regardless of what the init was called for, # if this is a subprocess, regardless of what the init was called for,
# we have to fix the main task hooks and stdout bindings # we have to fix the main task hooks and stdout bindings
if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid(): if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
if task_type is None:
task_type = cls.__main_task.task_type
# make sure we only do it once per process # make sure we only do it once per process
cls.__forked_proc_main_pid = os.getpid() cls.__forked_proc_main_pid = os.getpid()
# make sure we do not wait for the repo detect thread # make sure we do not wait for the repo detect thread
@ -222,6 +225,13 @@ class Task(_Task):
# Backwards compatibility: if called from Task.current_task and task_type # Backwards compatibility: if called from Task.current_task and task_type
# was not specified, keep legacy default value of TaskTypes.training # was not specified, keep legacy default value of TaskTypes.training
task_type = cls.TaskTypes.training task_type = cls.TaskTypes.training
elif isinstance(task_type, six.string_types):
task_type_lookup = {'testing': cls.TaskTypes.testing, 'inference': cls.TaskTypes.testing,
'train': cls.TaskTypes.training, 'training': cls.TaskTypes.training,}
if task_type not in task_type_lookup:
raise ValueError("Task type '{}' not supported, options are: {}".format(task_type,
list(task_type_lookup.keys())))
task_type = task_type_lookup[task_type]
try: try:
if not running_remotely(): if not running_remotely():