Fix hyper-parameter legacy mode (type=='legacy')

Add type/description to TF_DEFINES
Cast hyper-parameter values to string (None and other falsy values are left as-is)
This commit is contained in:
allegroai 2020-08-08 12:48:23 +03:00
parent de61dbf54e
commit 65003a168a
3 changed files with 117 additions and 57 deletions

View File

@ -4,6 +4,7 @@ from six import PY2
from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, SUPPRESS
from copy import copy
from ...backend_api import Session
from ...utilities.args import call_original_argparser
@ -59,13 +60,15 @@ class _Arguments(object):
self._exclude_parser_args = excluded_args or {}
def set_defaults(self, *dicts, **kwargs):
prefix = self._prefix_args if Session.check_min_api_version('2.9') else None
# noinspection PyProtectedMember
self._task._set_parameters(*dicts, __parameters_prefix=self._prefix_args, **kwargs)
self._task._set_parameters(*dicts, __parameters_prefix=prefix, **kwargs)
def add_argument(self, option_strings, type=None, default=None, help=None):
if not option_strings:
raise Exception('Expected at least one argument name (option string)')
name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t')
if Session.check_min_api_version('2.9'):
name = self._prefix_args + name
self._task.set_parameter(name=name, value=default, description=help)
@ -73,7 +76,8 @@ class _Arguments(object):
self._task.connect_argparse(parser)
@classmethod
def _add_to_defaults(cls, a_parser, defaults, descriptions, a_args=None, a_namespace=None, a_parsed_args=None):
def _add_to_defaults(cls, a_parser, defaults, descriptions, arg_types,
a_args=None, a_namespace=None, a_parsed_args=None):
actions = [
a for a in a_parser._actions
if isinstance(a, _StoreAction) or isinstance(a, _StoreConstAction)
@ -100,6 +104,8 @@ class _Arguments(object):
desc_ = {a.dest: a.help for a in actions}
descriptions.update(desc_)
types_ = {a.dest: (a.type or None) for a in actions}
arg_types.update(types_)
full_args_dict = copy(defaults)
full_args_dict.update(args_dict)
@ -115,19 +121,23 @@ class _Arguments(object):
defaults[sub_parser.dest] = full_args_dict.get(sub_parser.dest) or ''
for choice in sub_parser.choices.values():
# recursively parse
defaults, descriptions = cls._add_to_defaults(
defaults, descriptions, arg_types = cls._add_to_defaults(
a_parser=choice,
defaults=defaults,
descriptions=descriptions,
arg_types=arg_types,
a_parsed_args=a_parsed_args or full_args_dict
)
return defaults, descriptions
return defaults, descriptions, arg_types
def copy_defaults_from_argparse(self, parser, args=None, namespace=None, parsed_args=None):
task_defaults = {}
task_defaults_descriptions = {}
self._add_to_defaults(parser, task_defaults, task_defaults_descriptions, args, namespace, parsed_args)
task_defaults_types = {}
self._add_to_defaults(parser, task_defaults, task_defaults_descriptions, task_defaults_types,
args, namespace, parsed_args)
# Make sure we didn't miss anything
if parsed_args:
@ -150,12 +160,23 @@ class _Arguments(object):
del task_defaults[k]
# Skip excluded arguments, Add prefix.
task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items()
if Session.check_min_api_version('2.9'):
task_defaults = dict(
[(self._prefix_args + k, v) for k, v in task_defaults.items()
if self._exclude_parser_args.get(k, True)])
task_defaults_descriptions = dict([(self._prefix_args + k, v) for k, v in task_defaults_descriptions.items()
task_defaults_descriptions = dict(
[(self._prefix_args + k, v) for k, v in task_defaults_descriptions.items()
if self._exclude_parser_args.get(k, True)])
task_defaults_types = dict(
[(self._prefix_args + k, v) for k, v in task_defaults_types.items()
if self._exclude_parser_args.get(k, True)])
# Store to task
self._task.update_parameters(task_defaults, __parameters_descriptions=task_defaults_descriptions)
self._task.update_parameters(
task_defaults,
__parameters_descriptions=task_defaults_descriptions,
__parameters_types=task_defaults_types
)
@classmethod
def _find_parser_action(cls, a_parser, name):
@ -176,9 +197,10 @@ class _Arguments(object):
def copy_to_parser(self, parser, parsed_args):
# Change to argparse prefix only
task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items()
if k.startswith(self._prefix_args) and
self._exclude_parser_args.get(k[len(self._prefix_args):], True)])
prefix = self._prefix_args if Session.check_min_api_version('2.9') else ''
task_arguments = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items()
if k.startswith(prefix) and
self._exclude_parser_args.get(k[len(prefix):], True)])
arg_parser_argeuments = {}
for k, v in task_arguments.items():
# python2 unicode support
@ -319,15 +341,29 @@ class _Arguments(object):
setattr(parsed_args, k, v)
parser.set_defaults(**arg_parser_argeuments)
def copy_from_dict(self, dictionary, prefix=None):
def copy_from_dict(self, dictionary, prefix=None, descriptions=None, param_types=None):
# add dict prefix
prefix = prefix # or self._prefix_dict
if prefix:
prefix = prefix.strip(self._prefix_sep) + self._prefix_sep
if descriptions:
descriptions = dict((prefix+k, v) for k, v in descriptions.items())
if param_types:
param_types = dict((prefix+k, v) for k, v in param_types.items())
# this will only set the specific section
self._task.set_parameters(dictionary, __parameters_prefix=prefix)
self._task.set_parameters(
dictionary,
__parameters_prefix=prefix,
__parameters_descriptions=descriptions,
__parameters_types=param_types,
)
else:
self._task.update_parameters(dictionary)
self._task.update_parameters(
dictionary,
__parameters_prefix=prefix,
__parameters_descriptions=descriptions,
__parameters_types=param_types,
)
if not isinstance(dictionary, self._ProxyDictWrite):
return self._ProxyDictWrite(self, prefix, **dictionary)
return dictionary

View File

@ -5,6 +5,7 @@ import logging
import os
import sys
import re
from copy import copy
from enum import Enum
from tempfile import gettempdir
from multiprocessing import RLock
@ -785,16 +786,19 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if not Session.check_min_api_version('2.9'):
return self._get_task_property('execution.parameters')
# The API makes sure we get old parameters under _legacy
# The API makes sure we get old parameters with type 'legacy' at the top level (instead of nested in General)
parameters = dict()
hyperparams = self._get_task_property('hyperparams', default=dict())
if len(hyperparams) == 1 and '_legacy' in hyperparams and backwards_compatibility:
for section in ('_legacy', ):
hyperparams = self._get_task_property('hyperparams') or {}
if not backwards_compatibility:
for section in hyperparams:
for key, section_param in hyperparams[section].items():
parameters['{}'.format(key)] = section_param.value
parameters['{}/{}'.format(section, key)] = section_param.value
else:
for section in hyperparams:
for key, section_param in hyperparams[section].items():
if section_param.type == 'legacy' and section in (self._default_configuration_section_name, ):
parameters['{}'.format(key)] = section_param.value
else:
parameters['{}/{}'.format(section, key)] = section_param.value
return parameters
@ -856,10 +860,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# if we have a specific prefix and we use hyperparameters, and we use set.
# overwrite only the prefix, leave the rest as is.
if not update and prefix:
parameters = dict(**(self.get_parameters() or {}))
parameters = copy(self.get_parameters() or {})
parameters = dict((k, v) for k, v in parameters.items() if not k.startswith(prefix+'/'))
elif update:
parameters = dict(**(self.get_parameters() or {}))
parameters = copy(self.get_parameters() or {})
else:
parameters = dict()
@ -872,34 +876,40 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# build nested dict from flat parameters dict:
org_hyperparams = self.data.hyperparams or {}
hyperparams = dict()
# if the task is a legacy task, we should put everything back under General/key with legacy type
legacy_name = self._default_configuration_section_name
org_legacy_section = org_hyperparams.get(legacy_name, dict())
# if the task is a legacy task, we should put everything back under _legacy
if self.data.hyperparams and '_legacy' in self.data.hyperparams:
for k, v in parameters.items():
section_name, key = '_legacy', k
section = hyperparams.get(section_name, dict())
description = \
descriptions.get(k) or \
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description
param_type = \
params_types.get(k) or \
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).type
section[key] = tasks.ParamsItem(
section=section_name, name=key, value=v, description=description, type=str(param_type))
hyperparams[section_name] = section
else:
for k, v in parameters.items():
# legacy variable
if org_legacy_section.get(k, tasks.ParamsItem()).type == 'legacy':
section = hyperparams.get(legacy_name, dict())
section[k] = copy(org_legacy_section[k])
section[k].value = str(v) if v else v
description = descriptions.get(k)
if description:
section[k].description = description
hyperparams[legacy_name] = section
continue
org_k = k
if '/' not in k:
k = '{}/{}'.format(self._default_configuration_section_name, k)
section_name, key = k.split('/', 1)
section = hyperparams.get(section_name, dict())
description = \
descriptions.get(org_k) or \
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description
section[key] = tasks.ParamsItem\
(section=section_name, name=key, value=v, description=description)
org_param = org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem())
param_type = params_types[org_k] if org_k in params_types else org_param.type
if param_type and not isinstance(param_type, str):
param_type = param_type.__name__ if hasattr(param_type, '__name__') else str(param_type)
section[key] = tasks.ParamsItem(
section=section_name, name=key,
value=str(v) if v else v,
description=descriptions[org_k] if org_k in descriptions else org_param.description,
type=param_type,
)
hyperparams[section_name] = section
self._edit(hyperparams=hyperparams)
else:
execution = self.data.execution

View File

@ -73,8 +73,8 @@ class PatchAbsl(object):
# noinspection PyBroadException
try:
if param_name and flag:
param_dict = PatchAbsl._task._arguments.copy_to_dict({param_name: flag.value},
prefix=_Arguments._prefix_tf_defines)
param_dict = PatchAbsl._task._arguments.copy_to_dict(
{param_name: flag.value}, prefix=_Arguments._prefix_tf_defines)
flag.value = param_dict.get(param_name, flag.value)
except Exception:
pass
@ -82,7 +82,7 @@ class PatchAbsl(object):
else:
if flag and param_name:
value = flag.value
PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: value})
PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: value}, )
ret = PatchAbsl._original_DEFINE_flag(*args, **kwargs)
return ret
@ -117,6 +117,20 @@ class PatchAbsl(object):
else:
# clear previous parameters
parameters = dict([(k, FLAGS[k].value) for k in FLAGS])
cls._task._arguments.copy_from_dict(parameters, prefix=_Arguments._prefix_tf_defines)
# noinspection PyBroadException
try:
descriptions = dict([(k, FLAGS[k].help or None) for k in FLAGS])
except Exception:
descriptions = None
# noinspection PyBroadException
try:
param_types = dict([(k, FLAGS[k].flag_type() or None) for k in FLAGS])
except Exception:
param_types = None
cls._task._arguments.copy_from_dict(
parameters,
prefix=_Arguments._prefix_tf_defines,
descriptions=descriptions, param_types=param_types,
)
except Exception:
pass