From 65003a168a3e0489d2843df027b065d8087ab024 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 8 Aug 2020 12:48:23 +0300 Subject: [PATCH] Fix hyper-parameter legacy mode (type=='legacy') Add type/description to TF_DEFINES Cast hyper-parameters to string (if not None) --- trains/backend_interface/task/args.py | 70 +++++++++++++++++------ trains/backend_interface/task/task.py | 82 +++++++++++++++------------ trains/binding/absl_bind.py | 22 +++++-- 3 files changed, 117 insertions(+), 57 deletions(-) diff --git a/trains/backend_interface/task/args.py b/trains/backend_interface/task/args.py index 6465acfc..6a6235ef 100644 --- a/trains/backend_interface/task/args.py +++ b/trains/backend_interface/task/args.py @@ -4,6 +4,7 @@ from six import PY2 from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, SUPPRESS from copy import copy +from ...backend_api import Session from ...utilities.args import call_original_argparser @@ -59,21 +60,24 @@ class _Arguments(object): self._exclude_parser_args = excluded_args or {} def set_defaults(self, *dicts, **kwargs): + prefix = self._prefix_args if Session.check_min_api_version('2.9') else None # noinspection PyProtectedMember - self._task._set_parameters(*dicts, __parameters_prefix=self._prefix_args, **kwargs) + self._task._set_parameters(*dicts, __parameters_prefix=prefix, **kwargs) def add_argument(self, option_strings, type=None, default=None, help=None): if not option_strings: raise Exception('Expected at least one argument name (option string)') name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t') - name = self._prefix_args + name + if Session.check_min_api_version('2.9'): + name = self._prefix_args + name self._task.set_parameter(name=name, value=default, description=help) def connect(self, parser): self._task.connect_argparse(parser) @classmethod - def _add_to_defaults(cls, a_parser, defaults, descriptions, a_args=None, a_namespace=None, a_parsed_args=None): + def _add_to_defaults(cls, a_parser, defaults, descriptions, arg_types, + a_args=None, a_namespace=None, a_parsed_args=None): actions = [ a for a in a_parser._actions if isinstance(a, _StoreAction) or isinstance(a, _StoreConstAction) @@ -100,6 +104,8 @@ class _Arguments(object): desc_ = {a.dest: a.help for a in actions} descriptions.update(desc_) + types_ = {a.dest: (a.type or None) for a in actions} + arg_types.update(types_) full_args_dict = copy(defaults) full_args_dict.update(args_dict) @@ -115,19 +121,23 @@ class _Arguments(object): defaults[sub_parser.dest] = full_args_dict.get(sub_parser.dest) or '' for choice in sub_parser.choices.values(): # recursively parse - defaults, descriptions = cls._add_to_defaults( + defaults, descriptions, arg_types = cls._add_to_defaults( a_parser=choice, defaults=defaults, descriptions=descriptions, + arg_types=arg_types, a_parsed_args=a_parsed_args or full_args_dict ) - return defaults, descriptions + return defaults, descriptions, arg_types def copy_defaults_from_argparse(self, parser, args=None, namespace=None, parsed_args=None): task_defaults = {} task_defaults_descriptions = {} - self._add_to_defaults(parser, task_defaults, task_defaults_descriptions, args, namespace, parsed_args) + task_defaults_types = {} + self._add_to_defaults(parser, task_defaults, task_defaults_descriptions, task_defaults_types, + args, namespace, parsed_args) + # Make sure we didn't miss anything if parsed_args: @@ -150,12 +160,23 @@ class _Arguments(object): del task_defaults[k] # Skip excluded arguments, Add prefix. - task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items() - if self._exclude_parser_args.get(k, True)]) - task_defaults_descriptions = dict([(self._prefix_args + k, v) for k, v in task_defaults_descriptions.items() - if self._exclude_parser_args.get(k, True)]) + if Session.check_min_api_version('2.9'): + task_defaults = dict( + [(self._prefix_args + k, v) for k, v in task_defaults.items() + if self._exclude_parser_args.get(k, True)]) + task_defaults_descriptions = dict( + [(self._prefix_args + k, v) for k, v in task_defaults_descriptions.items() + if self._exclude_parser_args.get(k, True)]) + task_defaults_types = dict( + [(self._prefix_args + k, v) for k, v in task_defaults_types.items() + if self._exclude_parser_args.get(k, True)]) + # Store to task - self._task.update_parameters(task_defaults, __parameters_descriptions=task_defaults_descriptions) + self._task.update_parameters( + task_defaults, + __parameters_descriptions=task_defaults_descriptions, + __parameters_types=task_defaults_types + ) @classmethod def _find_parser_action(cls, a_parser, name): @@ -176,9 +197,10 @@ class _Arguments(object): def copy_to_parser(self, parser, parsed_args): # Change to argparse prefix only - task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items() - if k.startswith(self._prefix_args) and - self._exclude_parser_args.get(k[len(self._prefix_args):], True)]) + prefix = self._prefix_args if Session.check_min_api_version('2.9') else '' + task_arguments = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items() + if k.startswith(prefix) and + self._exclude_parser_args.get(k[len(prefix):], True)]) arg_parser_argeuments = {} for k, v in task_arguments.items(): # python2 unicode support @@ -319,15 +341,29 @@ class _Arguments(object): setattr(parsed_args, k, v) parser.set_defaults(**arg_parser_argeuments) - def copy_from_dict(self, dictionary, prefix=None): + def copy_from_dict(self, dictionary, prefix=None, descriptions=None, param_types=None): # add dict prefix prefix = prefix # or self._prefix_dict if prefix: prefix = prefix.strip(self._prefix_sep) + self._prefix_sep + if descriptions: + descriptions = dict((prefix+k, v) for k, v in descriptions.items()) + if param_types: + param_types = dict((prefix+k, v) for k, v in param_types.items()) # this will only set the specific section - self._task.set_parameters(dictionary, __parameters_prefix=prefix) + self._task.set_parameters( + dictionary, + __parameters_prefix=prefix, + __parameters_descriptions=descriptions, + __parameters_types=param_types, + ) else: - self._task.update_parameters(dictionary) + self._task.update_parameters( + dictionary, + __parameters_prefix=prefix, + __parameters_descriptions=descriptions, + __parameters_types=param_types, + ) if not isinstance(dictionary, self._ProxyDictWrite): return self._ProxyDictWrite(self, prefix, **dictionary) return dictionary diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index fe71fd58..181babbd 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -5,6 +5,7 @@ import logging import os import sys import re +from copy import copy from enum import Enum from tempfile import gettempdir from multiprocessing import RLock @@ -785,17 +786,20 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): if not Session.check_min_api_version('2.9'): return self._get_task_property('execution.parameters') - # API will makes sure we get old parameters under _legacy + # API will makes sure we get old parameters with type legacy on top level (instead of nested in General) parameters = dict() - hyperparams = self._get_task_property('hyperparams', default=dict()) - if len(hyperparams) == 1 and '_legacy' in hyperparams and backwards_compatibility: - for section in ('_legacy', ): - for key, section_param in hyperparams[section].items(): - parameters['{}'.format(key)] = section_param.value - else: + hyperparams = self._get_task_property('hyperparams') or {} + if not backwards_compatibility: for section in hyperparams: for key, section_param in hyperparams[section].items(): parameters['{}/{}'.format(section, key)] = section_param.value + else: + for section in hyperparams: + for key, section_param in hyperparams[section].items(): + if section_param.type == 'legacy' and section in (self._default_configuration_section_name, ): + parameters['{}'.format(key)] = section_param.value + else: + parameters['{}/{}'.format(section, key)] = section_param.value return parameters @@ -856,10 +860,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): # if we have a specific prefix and we use hyperparameters, and we use set. # overwrite only the prefix, leave the rest as is. if not update and prefix: - parameters = dict(**(self.get_parameters() or {})) + parameters = copy(self.get_parameters() or {}) parameters = dict((k, v) for k, v in parameters.items() if not k.startswith(prefix+'/')) elif update: - parameters = dict(**(self.get_parameters() or {})) + parameters = copy(self.get_parameters() or {}) else: parameters = dict() @@ -872,34 +876,40 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): # build nested dict from flat parameters dict: org_hyperparams = self.data.hyperparams or {} hyperparams = dict() + # if the task is a legacy task, we should put everything back under General/key with legacy type + legacy_name = self._default_configuration_section_name + org_legacy_section = org_hyperparams.get(legacy_name, dict()) + + for k, v in parameters.items(): + # legacy variable + if org_legacy_section.get(k, tasks.ParamsItem()).type == 'legacy': + section = hyperparams.get(legacy_name, dict()) + section[k] = copy(org_legacy_section[k]) + section[k].value = str(v) if v else v + description = descriptions.get(k) + if description: + section[k].description = description + hyperparams[legacy_name] = section + continue + + org_k = k + if '/' not in k: + k = '{}/{}'.format(self._default_configuration_section_name, k) + section_name, key = k.split('/', 1) + section = hyperparams.get(section_name, dict()) + org_param = org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()) + param_type = params_types[org_k] if org_k in params_types else org_param.type + if param_type and not isinstance(param_type, str): + param_type = param_type.__name__ if hasattr(param_type, '__name__') else str(param_type) + + section[key] = tasks.ParamsItem( + section=section_name, name=key, + value=str(v) if v else v, + description=descriptions[org_k] if org_k in descriptions else org_param.description, + type=param_type, + ) + hyperparams[section_name] = section - # if the task is a legacy task, we should put everything back under _legacy - if self.data.hyperparams and '_legacy' in self.data.hyperparams: - for k, v in parameters.items(): - section_name, key = '_legacy', k - section = hyperparams.get(section_name, dict()) - description = \ - descriptions.get(k) or \ - org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description - param_type = \ - params_types.get(k) or \ - org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).type - section[key] = tasks.ParamsItem( - section=section_name, name=key, value=v, description=description, type=str(param_type)) - hyperparams[section_name] = section - else: - for k, v in parameters.items(): - org_k = k - if '/' not in k: - k = '{}/{}'.format(self._default_configuration_section_name, k) - section_name, key = k.split('/', 1) - section = hyperparams.get(section_name, dict()) - description = \ - descriptions.get(org_k) or \ - org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description - section[key] = tasks.ParamsItem\ - (section=section_name, name=key, value=v, description=description) - hyperparams[section_name] = section self._edit(hyperparams=hyperparams) else: execution = self.data.execution diff --git a/trains/binding/absl_bind.py b/trains/binding/absl_bind.py index 006eb964..8228699c 100644 --- a/trains/binding/absl_bind.py +++ b/trains/binding/absl_bind.py @@ -73,8 +73,8 @@ class PatchAbsl(object): # noinspection PyBroadException try: if param_name and flag: - param_dict = PatchAbsl._task._arguments.copy_to_dict({param_name: flag.value}, - prefix=_Arguments._prefix_tf_defines) + param_dict = PatchAbsl._task._arguments.copy_to_dict( + {param_name: flag.value}, prefix=_Arguments._prefix_tf_defines) flag.value = param_dict.get(param_name, flag.value) except Exception: pass @@ -82,7 +82,7 @@ class PatchAbsl(object): else: if flag and param_name: value = flag.value - PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: value}) + PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: value}, ) ret = PatchAbsl._original_DEFINE_flag(*args, **kwargs) return ret @@ -117,6 +117,20 @@ class PatchAbsl(object): else: # clear previous parameters parameters = dict([(k, FLAGS[k].value) for k in FLAGS]) - cls._task._arguments.copy_from_dict(parameters, prefix=_Arguments._prefix_tf_defines) + # noinspection PyBroadException + try: + descriptions = dict([(k, FLAGS[k].help or None) for k in FLAGS]) + except Exception: + descriptions = None + # noinspection PyBroadException + try: + param_types = dict([(k, FLAGS[k].flag_type() or None) for k in FLAGS]) + except Exception: + param_types = None + cls._task._arguments.copy_from_dict( + parameters, + prefix=_Arguments._prefix_tf_defines, + descriptions=descriptions, param_types=param_types, + ) except Exception: pass