Fix server was not updated with the defaults from the code when running remotely and configuration section is missing

This commit is contained in:
allegroai 2020-12-06 11:26:32 +02:00
parent 5e70a9e6eb
commit fda00e2e1b

View File

@ -1001,7 +1001,7 @@ class Task(_Task):
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__) raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
def connect_configuration(self, configuration, name=None, description=None): def connect_configuration(self, configuration, name=None, description=None):
# type: (Union[Mapping, Path, str], Optional[str], Optional[str]) -> Union[Mapping, Path, str] # type: (Union[Mapping, Path, str], Optional[str], Optional[str]) -> Union[dict, Path, str]
""" """
Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object. Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object.
This method should be called before reading the configuration file. This method should be called before reading the configuration file.
@ -1087,6 +1087,11 @@ class Task(_Task):
LoggerRoot.get_base_logger().warning( LoggerRoot.get_base_logger().warning(
"Could not retrieve remote configuration named \'{}\'\n" "Could not retrieve remote configuration named \'{}\'\n"
"Using default configuration: {}".format(name, str(configuration))) "Using default configuration: {}".format(name, str(configuration)))
# update back configuration section
if multi_config_support:
self._set_configuration(
name=name, description=description,
config_type='dictionary', config_dict=configuration)
return configuration return configuration
configuration.clear() configuration.clear()
@ -1118,6 +1123,24 @@ class Task(_Task):
else: else:
configuration_text = self._get_configuration_text(name=name) if multi_config_support \ configuration_text = self._get_configuration_text(name=name) if multi_config_support \
else self._get_model_config_text() else self._get_model_config_text()
if configuration_text is None:
LoggerRoot.get_base_logger().warning(
"Could not retrieve remote configuration named \'{}\'\n"
"Using default configuration: {}".format(name, str(configuration)))
# update back configuration section
if multi_config_support:
configuration_path = Path(configuration)
if configuration_path.is_file():
with open(configuration_path.as_posix(), 'rt') as f:
configuration_text = f.read()
self._set_configuration(
name=name, description=description,
config_type=configuration_path.suffixes[-1].lstrip('.')
if configuration_path.suffixes and configuration_path.suffixes[-1] else 'file',
config_text=configuration_text)
return configuration
configuration_path = Path(configuration) configuration_path = Path(configuration)
fd, local_filename = mkstemp(prefix='trains_task_config_', fd, local_filename = mkstemp(prefix='trains_task_config_',
suffix=configuration_path.suffixes[-1] if suffix=configuration_path.suffixes[-1] if