Add Task._is_remote_main_task to support secondary copy of the main Task in remote run

This commit is contained in:
allegroai 2020-11-20 00:06:43 +02:00
parent 91cbc161f4
commit 21ef615bb1
3 changed files with 18 additions and 9 deletions

View File

@ -1546,6 +1546,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._app_server = Session.get_app_server_host()
return self._app_server
def _is_remote_main_task(self):
# type: () -> bool
"""
:return: return True if running remotely and this Task is the registered main task
"""
return running_remotely() and get_remote_task_id() == self.id
def _edit(self, **kwargs):
# type: (**Any) -> Any
with self._edit_lock:

View File

@ -747,7 +747,8 @@ class InputModel(Model):
"""
self._set_task(task)
if running_remotely() and task.input_model and task.is_main_task():
# noinspection PyProtectedMember
if running_remotely() and task.input_model and (task.is_main_task() or task._is_remote_main_task()):
self._base_model = task.input_model
self._base_model_id = task.input_model.id
else:
@ -996,7 +997,8 @@ class OutputModel(BaseModel):
if self._task != task:
raise ValueError('Can only connect preexisting model to task, but this is a fresh model')
if running_remotely() and task.is_main_task():
# noinspection PyProtectedMember
if running_remotely() and (task.is_main_task() or task._is_remote_main_task()):
if self._floating_data:
# noinspection PyProtectedMember
self._floating_data.design = _Model._wrap_design(self._task._get_model_config_text()) or \

View File

@ -1048,7 +1048,7 @@ class Task(_Task):
# noinspection PyProtectedMember
task._set_model_config(config_dict=config_dict)
if not running_remotely() or not self.is_main_task():
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
if multi_config_support:
self._set_configuration(
name=name, description=description, config_type='dictionary', config_dict=configuration)
@ -1075,7 +1075,7 @@ class Task(_Task):
return configuration
# it is a path to a local file
if not running_remotely() or not self.is_main_task():
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
# check if not absolute path
configuration_path = Path(configuration)
if not configuration_path.is_file():
@ -1132,7 +1132,7 @@ class Task(_Task):
raise ValueError("connect_label_enumeration supports only `dict` type, "
"{} is not supported".format(type(enumeration)))
if not running_remotely() or not self.is_main_task():
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
self.set_model_label_enumeration(enumeration)
else:
# pop everything
@ -2402,7 +2402,7 @@ class Task(_Task):
if parsed_args is None and parser == _parser:
parsed_args = _parsed_args
if running_remotely() and (self.is_main_task() or self.id == get_remote_task_id()):
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()):
self._arguments.copy_to_parser(parser, parsed_args)
else:
self._arguments.copy_defaults_from_argparse(
@ -2425,7 +2425,7 @@ class Task(_Task):
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
if not running_remotely() or not self.is_main_task():
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
self._arguments.copy_from_dict(flatten_dictionary(dictionary), prefix=name)
dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary)
else:
@ -2439,7 +2439,7 @@ class Task(_Task):
def _connect_task_parameters(self, attr_class, name=None):
self._try_set_connected_parameter_type(self._ConnectedParametersType.task_parameters)
if running_remotely() and self.is_main_task():
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()):
parameters = self.get_parameters()
if not name:
attr_class.update_from_dict(parameters)
@ -2462,7 +2462,7 @@ class Task(_Task):
return False
a_dict = {k: v for k, v in an_object.__dict__.items() if verify_type(k, v)}
if running_remotely() and self.is_main_task():
if running_remotely() and (self.is_main_task() or self._is_remote_main_task()):
a_dict = self._connect_dictionary(a_dict, name)
for k, v in a_dict.items():
if getattr(an_object, k, None) != a_dict[k]: