Support Task get_model_config/set_model_config legacy model configuration interface

This commit is contained in:
allegroai 2020-08-08 12:38:46 +03:00
parent f4f53902ed
commit ef83a648eb
2 changed files with 27 additions and 12 deletions

View File

@ -59,6 +59,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
_anonymous_dataview_id = '__anonymous__'
_development_tag = 'development'
_default_configuration_section_name = 'General'
_force_requirements = {}
_store_diff = config.get('development.store_uncommitted_code_diff', False)
@ -890,7 +891,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
for k, v in parameters.items():
org_k = k
if '/' not in k:
k = 'General/{}'.format(k)
k = '{}/{}'.format(self._default_configuration_section_name, k)
section_name, key = k.split('/', 1)
section = hyperparams.get(section_name, dict())
description = \
@ -1021,12 +1022,19 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# type: (str) -> ()
with self._edit_lock:
self.reload()
execution = self.data.execution
if design is not None:
# noinspection PyProtectedMember
execution.model_desc = Model._wrap_design(design)
if Session.check_min_api_version('2.9'):
configuration = self._get_task_property(
"configuration", default={}, raise_on_error=False, log_on_error=False) or {}
configuration[self._default_configuration_section_name] = tasks.ConfigurationItem(
name=self._default_configuration_section_name, value=str(design))
self._edit(configuration=configuration)
else:
execution = self.data.execution
if design is not None:
# noinspection PyProtectedMember
execution.model_desc = Model._wrap_design(design)
self._edit(execution=execution)
self._edit(execution=execution)
def get_labels_enumeration(self):
# type: () -> Mapping[str, int]
@ -1046,7 +1054,15 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:return: The model configuration as blob of text.
"""
design = self._get_task_property("execution.model_desc", default={}, raise_on_error=False, log_on_error=False)
if Session.check_min_api_version('2.9'):
design = self._get_task_property(
"configuration", default={}, raise_on_error=False, log_on_error=False) or {}
if design:
design = design.get(sorted(design.keys())[0]).value or ''
else:
design = self._get_task_property(
"execution.model_desc", default={}, raise_on_error=False, log_on_error=False)
# noinspection PyProtectedMember
return Model._unwrap_design(design)

View File

@ -124,7 +124,6 @@ class Task(_Task):
__task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
__detect_repo_async = config.get('development.vcs_repo_detect_async', False)
__default_output_uri = config.get('development.default_output_uri', None)
__default_configuration_name = 'General'
class _ConnectedParametersType(object):
argparse = "argument_parser"
@ -940,9 +939,9 @@ class Task(_Task):
multi_config_support = Session.check_min_api_version('2.9')
if multi_config_support and not name:
name = self.__default_configuration_name
name = self._default_configuration_section_name
if not multi_config_support and name and name != self.__default_configuration_name:
if not multi_config_support and name and name != self._default_configuration_section_name:
raise ValueError("Multiple configurations is not supported with the current 'trains-server', "
"please upgrade to the latest version")
@ -1003,9 +1002,9 @@ class Task(_Task):
multi_config_support = Session.check_min_api_version('2.9')
if multi_config_support and not name:
name = self.__default_configuration_name
name = self._default_configuration_section_name
if not multi_config_support and name and name != self.__default_configuration_name:
if not multi_config_support and name and name != self._default_configuration_section_name:
raise ValueError("Multiple configurations is not supported with the current 'trains-server', "
"please upgrade to the latest version")