Fix hyper-parameter legacy mode (type=='legacy')

Add type/description to TF_DEFINES
Cast hyper-parameters to string (if not None)
This commit is contained in:
allegroai 2020-08-08 12:48:23 +03:00
parent de61dbf54e
commit 65003a168a
3 changed files with 117 additions and 57 deletions

View File

@ -4,6 +4,7 @@ from six import PY2
from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, SUPPRESS from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, SUPPRESS
from copy import copy from copy import copy
from ...backend_api import Session
from ...utilities.args import call_original_argparser from ...utilities.args import call_original_argparser
@ -59,21 +60,24 @@ class _Arguments(object):
self._exclude_parser_args = excluded_args or {} self._exclude_parser_args = excluded_args or {}
def set_defaults(self, *dicts, **kwargs): def set_defaults(self, *dicts, **kwargs):
prefix = self._prefix_args if Session.check_min_api_version('2.9') else None
# noinspection PyProtectedMember # noinspection PyProtectedMember
self._task._set_parameters(*dicts, __parameters_prefix=self._prefix_args, **kwargs) self._task._set_parameters(*dicts, __parameters_prefix=prefix, **kwargs)
def add_argument(self, option_strings, type=None, default=None, help=None): def add_argument(self, option_strings, type=None, default=None, help=None):
if not option_strings: if not option_strings:
raise Exception('Expected at least one argument name (option string)') raise Exception('Expected at least one argument name (option string)')
name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t') name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t')
name = self._prefix_args + name if Session.check_min_api_version('2.9'):
name = self._prefix_args + name
self._task.set_parameter(name=name, value=default, description=help) self._task.set_parameter(name=name, value=default, description=help)
def connect(self, parser): def connect(self, parser):
self._task.connect_argparse(parser) self._task.connect_argparse(parser)
@classmethod @classmethod
def _add_to_defaults(cls, a_parser, defaults, descriptions, a_args=None, a_namespace=None, a_parsed_args=None): def _add_to_defaults(cls, a_parser, defaults, descriptions, arg_types,
a_args=None, a_namespace=None, a_parsed_args=None):
actions = [ actions = [
a for a in a_parser._actions a for a in a_parser._actions
if isinstance(a, _StoreAction) or isinstance(a, _StoreConstAction) if isinstance(a, _StoreAction) or isinstance(a, _StoreConstAction)
@ -100,6 +104,8 @@ class _Arguments(object):
desc_ = {a.dest: a.help for a in actions} desc_ = {a.dest: a.help for a in actions}
descriptions.update(desc_) descriptions.update(desc_)
types_ = {a.dest: (a.type or None) for a in actions}
arg_types.update(types_)
full_args_dict = copy(defaults) full_args_dict = copy(defaults)
full_args_dict.update(args_dict) full_args_dict.update(args_dict)
@ -115,19 +121,23 @@ class _Arguments(object):
defaults[sub_parser.dest] = full_args_dict.get(sub_parser.dest) or '' defaults[sub_parser.dest] = full_args_dict.get(sub_parser.dest) or ''
for choice in sub_parser.choices.values(): for choice in sub_parser.choices.values():
# recursively parse # recursively parse
defaults, descriptions = cls._add_to_defaults( defaults, descriptions, arg_types = cls._add_to_defaults(
a_parser=choice, a_parser=choice,
defaults=defaults, defaults=defaults,
descriptions=descriptions, descriptions=descriptions,
arg_types=arg_types,
a_parsed_args=a_parsed_args or full_args_dict a_parsed_args=a_parsed_args or full_args_dict
) )
return defaults, descriptions return defaults, descriptions, arg_types
def copy_defaults_from_argparse(self, parser, args=None, namespace=None, parsed_args=None): def copy_defaults_from_argparse(self, parser, args=None, namespace=None, parsed_args=None):
task_defaults = {} task_defaults = {}
task_defaults_descriptions = {} task_defaults_descriptions = {}
self._add_to_defaults(parser, task_defaults, task_defaults_descriptions, args, namespace, parsed_args) task_defaults_types = {}
self._add_to_defaults(parser, task_defaults, task_defaults_descriptions, task_defaults_types,
args, namespace, parsed_args)
# Make sure we didn't miss anything # Make sure we didn't miss anything
if parsed_args: if parsed_args:
@ -150,12 +160,23 @@ class _Arguments(object):
del task_defaults[k] del task_defaults[k]
# Skip excluded arguments, Add prefix. # Skip excluded arguments, Add prefix.
task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items() if Session.check_min_api_version('2.9'):
if self._exclude_parser_args.get(k, True)]) task_defaults = dict(
task_defaults_descriptions = dict([(self._prefix_args + k, v) for k, v in task_defaults_descriptions.items() [(self._prefix_args + k, v) for k, v in task_defaults.items()
if self._exclude_parser_args.get(k, True)]) if self._exclude_parser_args.get(k, True)])
task_defaults_descriptions = dict(
[(self._prefix_args + k, v) for k, v in task_defaults_descriptions.items()
if self._exclude_parser_args.get(k, True)])
task_defaults_types = dict(
[(self._prefix_args + k, v) for k, v in task_defaults_types.items()
if self._exclude_parser_args.get(k, True)])
# Store to task # Store to task
self._task.update_parameters(task_defaults, __parameters_descriptions=task_defaults_descriptions) self._task.update_parameters(
task_defaults,
__parameters_descriptions=task_defaults_descriptions,
__parameters_types=task_defaults_types
)
@classmethod @classmethod
def _find_parser_action(cls, a_parser, name): def _find_parser_action(cls, a_parser, name):
@ -176,9 +197,10 @@ class _Arguments(object):
def copy_to_parser(self, parser, parsed_args): def copy_to_parser(self, parser, parsed_args):
# Change to argparse prefix only # Change to argparse prefix only
task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items() prefix = self._prefix_args if Session.check_min_api_version('2.9') else ''
if k.startswith(self._prefix_args) and task_arguments = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items()
self._exclude_parser_args.get(k[len(self._prefix_args):], True)]) if k.startswith(prefix) and
self._exclude_parser_args.get(k[len(prefix):], True)])
arg_parser_argeuments = {} arg_parser_argeuments = {}
for k, v in task_arguments.items(): for k, v in task_arguments.items():
# python2 unicode support # python2 unicode support
@ -319,15 +341,29 @@ class _Arguments(object):
setattr(parsed_args, k, v) setattr(parsed_args, k, v)
parser.set_defaults(**arg_parser_argeuments) parser.set_defaults(**arg_parser_argeuments)
def copy_from_dict(self, dictionary, prefix=None): def copy_from_dict(self, dictionary, prefix=None, descriptions=None, param_types=None):
# add dict prefix # add dict prefix
prefix = prefix # or self._prefix_dict prefix = prefix # or self._prefix_dict
if prefix: if prefix:
prefix = prefix.strip(self._prefix_sep) + self._prefix_sep prefix = prefix.strip(self._prefix_sep) + self._prefix_sep
if descriptions:
descriptions = dict((prefix+k, v) for k, v in descriptions.items())
if param_types:
param_types = dict((prefix+k, v) for k, v in param_types.items())
# this will only set the specific section # this will only set the specific section
self._task.set_parameters(dictionary, __parameters_prefix=prefix) self._task.set_parameters(
dictionary,
__parameters_prefix=prefix,
__parameters_descriptions=descriptions,
__parameters_types=param_types,
)
else: else:
self._task.update_parameters(dictionary) self._task.update_parameters(
dictionary,
__parameters_prefix=prefix,
__parameters_descriptions=descriptions,
__parameters_types=param_types,
)
if not isinstance(dictionary, self._ProxyDictWrite): if not isinstance(dictionary, self._ProxyDictWrite):
return self._ProxyDictWrite(self, prefix, **dictionary) return self._ProxyDictWrite(self, prefix, **dictionary)
return dictionary return dictionary

View File

@ -5,6 +5,7 @@ import logging
import os import os
import sys import sys
import re import re
from copy import copy
from enum import Enum from enum import Enum
from tempfile import gettempdir from tempfile import gettempdir
from multiprocessing import RLock from multiprocessing import RLock
@ -785,17 +786,20 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if not Session.check_min_api_version('2.9'): if not Session.check_min_api_version('2.9'):
return self._get_task_property('execution.parameters') return self._get_task_property('execution.parameters')
# API will make sure we get old parameters under _legacy # API will make sure we get old parameters with type legacy on top level (instead of nested in General)
parameters = dict() parameters = dict()
hyperparams = self._get_task_property('hyperparams', default=dict()) hyperparams = self._get_task_property('hyperparams') or {}
if len(hyperparams) == 1 and '_legacy' in hyperparams and backwards_compatibility: if not backwards_compatibility:
for section in ('_legacy', ):
for key, section_param in hyperparams[section].items():
parameters['{}'.format(key)] = section_param.value
else:
for section in hyperparams: for section in hyperparams:
for key, section_param in hyperparams[section].items(): for key, section_param in hyperparams[section].items():
parameters['{}/{}'.format(section, key)] = section_param.value parameters['{}/{}'.format(section, key)] = section_param.value
else:
for section in hyperparams:
for key, section_param in hyperparams[section].items():
if section_param.type == 'legacy' and section in (self._default_configuration_section_name, ):
parameters['{}'.format(key)] = section_param.value
else:
parameters['{}/{}'.format(section, key)] = section_param.value
return parameters return parameters
@ -856,10 +860,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# if we have a specific prefix and we use hyperparameters, and we use set. # if we have a specific prefix and we use hyperparameters, and we use set.
# overwrite only the prefix, leave the rest as is. # overwrite only the prefix, leave the rest as is.
if not update and prefix: if not update and prefix:
parameters = dict(**(self.get_parameters() or {})) parameters = copy(self.get_parameters() or {})
parameters = dict((k, v) for k, v in parameters.items() if not k.startswith(prefix+'/')) parameters = dict((k, v) for k, v in parameters.items() if not k.startswith(prefix+'/'))
elif update: elif update:
parameters = dict(**(self.get_parameters() or {})) parameters = copy(self.get_parameters() or {})
else: else:
parameters = dict() parameters = dict()
@ -872,34 +876,40 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# build nested dict from flat parameters dict: # build nested dict from flat parameters dict:
org_hyperparams = self.data.hyperparams or {} org_hyperparams = self.data.hyperparams or {}
hyperparams = dict() hyperparams = dict()
# if the task is a legacy task, we should put everything back under General/key with legacy type
legacy_name = self._default_configuration_section_name
org_legacy_section = org_hyperparams.get(legacy_name, dict())
for k, v in parameters.items():
# legacy variable
if org_legacy_section.get(k, tasks.ParamsItem()).type == 'legacy':
section = hyperparams.get(legacy_name, dict())
section[k] = copy(org_legacy_section[k])
section[k].value = str(v) if v else v
description = descriptions.get(k)
if description:
section[k].description = description
hyperparams[legacy_name] = section
continue
org_k = k
if '/' not in k:
k = '{}/{}'.format(self._default_configuration_section_name, k)
section_name, key = k.split('/', 1)
section = hyperparams.get(section_name, dict())
org_param = org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem())
param_type = params_types[org_k] if org_k in params_types else org_param.type
if param_type and not isinstance(param_type, str):
param_type = param_type.__name__ if hasattr(param_type, '__name__') else str(param_type)
section[key] = tasks.ParamsItem(
section=section_name, name=key,
value=str(v) if v else v,
description=descriptions[org_k] if org_k in descriptions else org_param.description,
type=param_type,
)
hyperparams[section_name] = section
# if the task is a legacy task, we should put everything back under _legacy
if self.data.hyperparams and '_legacy' in self.data.hyperparams:
for k, v in parameters.items():
section_name, key = '_legacy', k
section = hyperparams.get(section_name, dict())
description = \
descriptions.get(k) or \
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description
param_type = \
params_types.get(k) or \
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).type
section[key] = tasks.ParamsItem(
section=section_name, name=key, value=v, description=description, type=str(param_type))
hyperparams[section_name] = section
else:
for k, v in parameters.items():
org_k = k
if '/' not in k:
k = '{}/{}'.format(self._default_configuration_section_name, k)
section_name, key = k.split('/', 1)
section = hyperparams.get(section_name, dict())
description = \
descriptions.get(org_k) or \
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description
section[key] = tasks.ParamsItem\
(section=section_name, name=key, value=v, description=description)
hyperparams[section_name] = section
self._edit(hyperparams=hyperparams) self._edit(hyperparams=hyperparams)
else: else:
execution = self.data.execution execution = self.data.execution

View File

@ -73,8 +73,8 @@ class PatchAbsl(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if param_name and flag: if param_name and flag:
param_dict = PatchAbsl._task._arguments.copy_to_dict({param_name: flag.value}, param_dict = PatchAbsl._task._arguments.copy_to_dict(
prefix=_Arguments._prefix_tf_defines) {param_name: flag.value}, prefix=_Arguments._prefix_tf_defines)
flag.value = param_dict.get(param_name, flag.value) flag.value = param_dict.get(param_name, flag.value)
except Exception: except Exception:
pass pass
@ -82,7 +82,7 @@ class PatchAbsl(object):
else: else:
if flag and param_name: if flag and param_name:
value = flag.value value = flag.value
PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: value}) PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: value}, )
ret = PatchAbsl._original_DEFINE_flag(*args, **kwargs) ret = PatchAbsl._original_DEFINE_flag(*args, **kwargs)
return ret return ret
@ -117,6 +117,20 @@ class PatchAbsl(object):
else: else:
# clear previous parameters # clear previous parameters
parameters = dict([(k, FLAGS[k].value) for k in FLAGS]) parameters = dict([(k, FLAGS[k].value) for k in FLAGS])
cls._task._arguments.copy_from_dict(parameters, prefix=_Arguments._prefix_tf_defines) # noinspection PyBroadException
try:
descriptions = dict([(k, FLAGS[k].help or None) for k in FLAGS])
except Exception:
descriptions = None
# noinspection PyBroadException
try:
param_types = dict([(k, FLAGS[k].flag_type() or None) for k in FLAGS])
except Exception:
param_types = None
cls._task._arguments.copy_from_dict(
parameters,
prefix=_Arguments._prefix_tf_defines,
descriptions=descriptions, param_types=param_types,
)
except Exception: except Exception:
pass pass