From 6aa22e449ee0b6ddfceab36e589a0f8d3186767f Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 4 Oct 2019 01:31:57 +0300 Subject: [PATCH] Add alias for task type, train/inference --- trains/task.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/trains/task.py b/trains/task.py index 24dde7c1..c517c165 100644 --- a/trains/task.py +++ b/trains/task.py @@ -136,7 +136,8 @@ class Task(_Task): :param project_name: project to create the task in (if project doesn't exist, it will be created) :param task_name: task name to be created (in development mode, not when running remotely) - :param task_type: task type to be created (in development mode, not when running remotely) + :param task_type: task type to be created, Default: TaskTypes.training + Options are: 'testing', 'training' or 'train', 'inference' :param reuse_last_task_id: start with the previously used task id (stored in the data cache folder). if False every time we call the function we create a new task with the same name Notice! The reused task will be reset. (when running remotely, the usual behaviour applies) @@ -181,6 +182,8 @@ class Task(_Task): # if this is a subprocess, regardless of what the init was called for, # we have to fix the main task hooks and stdout bindings if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid(): + if task_type is None: + task_type = cls.__main_task.task_type # make sure we only do it once per process cls.__forked_proc_main_pid = os.getpid() # make sure we do not wait for the repo detect thread @@ -222,6 +225,13 @@ class Task(_Task): # Backwards compatibility: if called from Task.current_task and task_type # was not specified, keep legacy default value of TaskTypes.training task_type = cls.TaskTypes.training + elif isinstance(task_type, six.string_types): + task_type_lookup = {'testing': cls.TaskTypes.testing, 'inference': cls.TaskTypes.testing, + 'train': cls.TaskTypes.training, 'training': cls.TaskTypes.training,} + if task_type not in task_type_lookup: + raise ValueError("Task type '{}' not supported, options are: {}".format(task_type, + list(task_type_lookup.keys()))) + task_type = task_type_lookup[task_type] try: if not running_remotely():