diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index 40e425fd..06fd40ca 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -1546,6 +1546,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): self._app_server = Session.get_app_server_host() return self._app_server + def _is_remote_main_task(self): + # type: () -> bool + """ + :return: return True if running remotely and this Task is the registered main task + """ + return running_remotely() and get_remote_task_id() == self.id + def _edit(self, **kwargs): # type: (**Any) -> Any with self._edit_lock: diff --git a/trains/model.py b/trains/model.py index 396cbabe..d774394a 100644 --- a/trains/model.py +++ b/trains/model.py @@ -747,7 +747,8 @@ class InputModel(Model): """ self._set_task(task) - if running_remotely() and task.input_model and task.is_main_task(): + # noinspection PyProtectedMember + if running_remotely() and task.input_model and (task.is_main_task() or task._is_remote_main_task()): self._base_model = task.input_model self._base_model_id = task.input_model.id else: @@ -996,7 +997,8 @@ class OutputModel(BaseModel): if self._task != task: raise ValueError('Can only connect preexisting model to task, but this is a fresh model') - if running_remotely() and task.is_main_task(): + # noinspection PyProtectedMember + if running_remotely() and (task.is_main_task() or task._is_remote_main_task()): if self._floating_data: # noinspection PyProtectedMember self._floating_data.design = _Model._wrap_design(self._task._get_model_config_text()) or \ diff --git a/trains/task.py b/trains/task.py index 8ba692b3..027a36b3 100644 --- a/trains/task.py +++ b/trains/task.py @@ -1048,7 +1048,7 @@ class Task(_Task): # noinspection PyProtectedMember task._set_model_config(config_dict=config_dict) - if not running_remotely() or not self.is_main_task(): + if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()): if multi_config_support: self._set_configuration( name=name, description=description, config_type='dictionary', config_dict=configuration) @@ -1075,7 +1075,7 @@ class Task(_Task): return configuration # it is a path to a local file - if not running_remotely() or not self.is_main_task(): + if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()): # check if not absolute path configuration_path = Path(configuration) if not configuration_path.is_file(): @@ -1132,7 +1132,7 @@ class Task(_Task): raise ValueError("connect_label_enumeration supports only `dict` type, " "{} is not supported".format(type(enumeration))) - if not running_remotely() or not self.is_main_task(): + if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()): self.set_model_label_enumeration(enumeration) else: # pop everything @@ -2402,7 +2402,7 @@ class Task(_Task): if parsed_args is None and parser == _parser: parsed_args = _parsed_args - if running_remotely() and (self.is_main_task() or self.id == get_remote_task_id()): + if running_remotely() and (self.is_main_task() or self._is_remote_main_task()): self._arguments.copy_to_parser(parser, parsed_args) else: self._arguments.copy_defaults_from_argparse( @@ -2425,7 +2425,7 @@ class Task(_Task): self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary) - if not running_remotely() or not self.is_main_task(): + if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()): self._arguments.copy_from_dict(flatten_dictionary(dictionary), prefix=name) dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary) else: @@ -2439,7 +2439,7 @@ class Task(_Task): def _connect_task_parameters(self, attr_class, name=None): self._try_set_connected_parameter_type(self._ConnectedParametersType.task_parameters) - if running_remotely() and self.is_main_task(): + if running_remotely() and (self.is_main_task() or self._is_remote_main_task()): parameters = self.get_parameters() if not name: attr_class.update_from_dict(parameters) @@ -2462,7 +2462,7 @@ class Task(_Task): return False a_dict = {k: v for k, v in an_object.__dict__.items() if verify_type(k, v)} - if running_remotely() and self.is_main_task(): + if running_remotely() and (self.is_main_task() or self._is_remote_main_task()): a_dict = self._connect_dictionary(a_dict, name) for k, v in a_dict.items(): if getattr(an_object, k, None) != a_dict[k]: