Make sure Task.connect() returns the same value it is passed

This commit is contained in:
allegroai 2020-02-18 11:26:52 +02:00
parent 14588e6dec
commit b5168010e9

View File

@ -1081,6 +1081,7 @@ class Task(_Task):
def _connect_output_model(self, model): def _connect_output_model(self, model):
assert isinstance(model, OutputModel) assert isinstance(model, OutputModel)
model.connect(self) model.connect(self)
return model
def _save_output_model(self, model): def _save_output_model(self, model):
""" """
@ -1117,6 +1118,7 @@ class Task(_Task):
return return
self._last_input_model_id = model.id self._last_input_model_id = model.id
model.connect(self) model.connect(self)
return model
def _try_set_connected_parameter_type(self, option): def _try_set_connected_parameter_type(self, option):
# """ Raise an error if current value is not None and not equal to the provided option value """ # """ Raise an error if current value is not None and not equal to the provided option value """
@ -1146,7 +1148,7 @@ class Task(_Task):
from IPython import get_ipython from IPython import get_ipython
ip = get_ipython() ip = get_ipython()
if ip is not None and 'IPKernelApp' in ip.config: if ip is not None and 'IPKernelApp' in ip.config:
return return parser
except Exception: except Exception:
pass pass
@ -1168,6 +1170,7 @@ class Task(_Task):
self._arguments.copy_to_parser(parser, parsed_args) self._arguments.copy_to_parser(parser, parsed_args)
else: else:
self._arguments.copy_defaults_from_argparse(parser, args=args, namespace=namespace, parsed_args=parsed_args) self._arguments.copy_defaults_from_argparse(parser, args=args, namespace=namespace, parsed_args=parsed_args)
return parser
def _connect_dictionary(self, dictionary): def _connect_dictionary(self, dictionary):
def _update_args_dict(task, config_dict): def _update_args_dict(task, config_dict):
@ -1200,6 +1203,7 @@ class Task(_Task):
attr_class.update_from_dict(self.get_parameters()) attr_class.update_from_dict(self.get_parameters())
else: else:
self.set_parameters(attr_class.to_dict()) self.set_parameters(attr_class.to_dict())
return attr_class
def _validate(self, check_output_dest_credentials=False): def _validate(self, check_output_dest_credentials=False):
if running_remotely(): if running_remotely():