mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Fix hyper-parameter legacy mode (type=='legacy')
Add type/description to TF_DEFINES Cast hyper-parameters to string (if not None)
This commit is contained in:
parent
de61dbf54e
commit
65003a168a
@ -4,6 +4,7 @@ from six import PY2
|
||||
from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, SUPPRESS
|
||||
from copy import copy
|
||||
|
||||
from ...backend_api import Session
|
||||
from ...utilities.args import call_original_argparser
|
||||
|
||||
|
||||
@ -59,13 +60,15 @@ class _Arguments(object):
|
||||
self._exclude_parser_args = excluded_args or {}
|
||||
|
||||
def set_defaults(self, *dicts, **kwargs):
|
||||
prefix = self._prefix_args if Session.check_min_api_version('2.9') else None
|
||||
# noinspection PyProtectedMember
|
||||
self._task._set_parameters(*dicts, __parameters_prefix=self._prefix_args, **kwargs)
|
||||
self._task._set_parameters(*dicts, __parameters_prefix=prefix, **kwargs)
|
||||
|
||||
def add_argument(self, option_strings, type=None, default=None, help=None):
|
||||
if not option_strings:
|
||||
raise Exception('Expected at least one argument name (option string)')
|
||||
name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t')
|
||||
if Session.check_min_api_version('2.9'):
|
||||
name = self._prefix_args + name
|
||||
self._task.set_parameter(name=name, value=default, description=help)
|
||||
|
||||
@ -73,7 +76,8 @@ class _Arguments(object):
|
||||
self._task.connect_argparse(parser)
|
||||
|
||||
@classmethod
|
||||
def _add_to_defaults(cls, a_parser, defaults, descriptions, a_args=None, a_namespace=None, a_parsed_args=None):
|
||||
def _add_to_defaults(cls, a_parser, defaults, descriptions, arg_types,
|
||||
a_args=None, a_namespace=None, a_parsed_args=None):
|
||||
actions = [
|
||||
a for a in a_parser._actions
|
||||
if isinstance(a, _StoreAction) or isinstance(a, _StoreConstAction)
|
||||
@ -100,6 +104,8 @@ class _Arguments(object):
|
||||
|
||||
desc_ = {a.dest: a.help for a in actions}
|
||||
descriptions.update(desc_)
|
||||
types_ = {a.dest: (a.type or None) for a in actions}
|
||||
arg_types.update(types_)
|
||||
|
||||
full_args_dict = copy(defaults)
|
||||
full_args_dict.update(args_dict)
|
||||
@ -115,19 +121,23 @@ class _Arguments(object):
|
||||
defaults[sub_parser.dest] = full_args_dict.get(sub_parser.dest) or ''
|
||||
for choice in sub_parser.choices.values():
|
||||
# recursively parse
|
||||
defaults, descriptions = cls._add_to_defaults(
|
||||
defaults, descriptions, arg_types = cls._add_to_defaults(
|
||||
a_parser=choice,
|
||||
defaults=defaults,
|
||||
descriptions=descriptions,
|
||||
arg_types=arg_types,
|
||||
a_parsed_args=a_parsed_args or full_args_dict
|
||||
)
|
||||
|
||||
return defaults, descriptions
|
||||
return defaults, descriptions, arg_types
|
||||
|
||||
def copy_defaults_from_argparse(self, parser, args=None, namespace=None, parsed_args=None):
|
||||
task_defaults = {}
|
||||
task_defaults_descriptions = {}
|
||||
self._add_to_defaults(parser, task_defaults, task_defaults_descriptions, args, namespace, parsed_args)
|
||||
task_defaults_types = {}
|
||||
self._add_to_defaults(parser, task_defaults, task_defaults_descriptions, task_defaults_types,
|
||||
args, namespace, parsed_args)
|
||||
|
||||
|
||||
# Make sure we didn't miss anything
|
||||
if parsed_args:
|
||||
@ -150,12 +160,23 @@ class _Arguments(object):
|
||||
del task_defaults[k]
|
||||
|
||||
# Skip excluded arguments, Add prefix.
|
||||
task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items()
|
||||
if Session.check_min_api_version('2.9'):
|
||||
task_defaults = dict(
|
||||
[(self._prefix_args + k, v) for k, v in task_defaults.items()
|
||||
if self._exclude_parser_args.get(k, True)])
|
||||
task_defaults_descriptions = dict([(self._prefix_args + k, v) for k, v in task_defaults_descriptions.items()
|
||||
task_defaults_descriptions = dict(
|
||||
[(self._prefix_args + k, v) for k, v in task_defaults_descriptions.items()
|
||||
if self._exclude_parser_args.get(k, True)])
|
||||
task_defaults_types = dict(
|
||||
[(self._prefix_args + k, v) for k, v in task_defaults_types.items()
|
||||
if self._exclude_parser_args.get(k, True)])
|
||||
|
||||
# Store to task
|
||||
self._task.update_parameters(task_defaults, __parameters_descriptions=task_defaults_descriptions)
|
||||
self._task.update_parameters(
|
||||
task_defaults,
|
||||
__parameters_descriptions=task_defaults_descriptions,
|
||||
__parameters_types=task_defaults_types
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _find_parser_action(cls, a_parser, name):
|
||||
@ -176,9 +197,10 @@ class _Arguments(object):
|
||||
|
||||
def copy_to_parser(self, parser, parsed_args):
|
||||
# Change to argparse prefix only
|
||||
task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items()
|
||||
if k.startswith(self._prefix_args) and
|
||||
self._exclude_parser_args.get(k[len(self._prefix_args):], True)])
|
||||
prefix = self._prefix_args if Session.check_min_api_version('2.9') else ''
|
||||
task_arguments = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items()
|
||||
if k.startswith(prefix) and
|
||||
self._exclude_parser_args.get(k[len(prefix):], True)])
|
||||
arg_parser_argeuments = {}
|
||||
for k, v in task_arguments.items():
|
||||
# python2 unicode support
|
||||
@ -319,15 +341,29 @@ class _Arguments(object):
|
||||
setattr(parsed_args, k, v)
|
||||
parser.set_defaults(**arg_parser_argeuments)
|
||||
|
||||
def copy_from_dict(self, dictionary, prefix=None):
|
||||
def copy_from_dict(self, dictionary, prefix=None, descriptions=None, param_types=None):
|
||||
# add dict prefix
|
||||
prefix = prefix # or self._prefix_dict
|
||||
if prefix:
|
||||
prefix = prefix.strip(self._prefix_sep) + self._prefix_sep
|
||||
if descriptions:
|
||||
descriptions = dict((prefix+k, v) for k, v in descriptions.items())
|
||||
if param_types:
|
||||
param_types = dict((prefix+k, v) for k, v in param_types.items())
|
||||
# this will only set the specific section
|
||||
self._task.set_parameters(dictionary, __parameters_prefix=prefix)
|
||||
self._task.set_parameters(
|
||||
dictionary,
|
||||
__parameters_prefix=prefix,
|
||||
__parameters_descriptions=descriptions,
|
||||
__parameters_types=param_types,
|
||||
)
|
||||
else:
|
||||
self._task.update_parameters(dictionary)
|
||||
self._task.update_parameters(
|
||||
dictionary,
|
||||
__parameters_prefix=prefix,
|
||||
__parameters_descriptions=descriptions,
|
||||
__parameters_types=param_types,
|
||||
)
|
||||
if not isinstance(dictionary, self._ProxyDictWrite):
|
||||
return self._ProxyDictWrite(self, prefix, **dictionary)
|
||||
return dictionary
|
||||
|
@ -5,6 +5,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
from copy import copy
|
||||
from enum import Enum
|
||||
from tempfile import gettempdir
|
||||
from multiprocessing import RLock
|
||||
@ -785,16 +786,19 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
if not Session.check_min_api_version('2.9'):
|
||||
return self._get_task_property('execution.parameters')
|
||||
|
||||
# API will makes sure we get old parameters under _legacy
|
||||
# API will makes sure we get old parameters with type legacy on top level (instead of nested in General)
|
||||
parameters = dict()
|
||||
hyperparams = self._get_task_property('hyperparams', default=dict())
|
||||
if len(hyperparams) == 1 and '_legacy' in hyperparams and backwards_compatibility:
|
||||
for section in ('_legacy', ):
|
||||
hyperparams = self._get_task_property('hyperparams') or {}
|
||||
if not backwards_compatibility:
|
||||
for section in hyperparams:
|
||||
for key, section_param in hyperparams[section].items():
|
||||
parameters['{}'.format(key)] = section_param.value
|
||||
parameters['{}/{}'.format(section, key)] = section_param.value
|
||||
else:
|
||||
for section in hyperparams:
|
||||
for key, section_param in hyperparams[section].items():
|
||||
if section_param.type == 'legacy' and section in (self._default_configuration_section_name, ):
|
||||
parameters['{}'.format(key)] = section_param.value
|
||||
else:
|
||||
parameters['{}/{}'.format(section, key)] = section_param.value
|
||||
|
||||
return parameters
|
||||
@ -856,10 +860,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
# if we have a specific prefix and we use hyperparameters, and we use set.
|
||||
# overwrite only the prefix, leave the rest as is.
|
||||
if not update and prefix:
|
||||
parameters = dict(**(self.get_parameters() or {}))
|
||||
parameters = copy(self.get_parameters() or {})
|
||||
parameters = dict((k, v) for k, v in parameters.items() if not k.startswith(prefix+'/'))
|
||||
elif update:
|
||||
parameters = dict(**(self.get_parameters() or {}))
|
||||
parameters = copy(self.get_parameters() or {})
|
||||
else:
|
||||
parameters = dict()
|
||||
|
||||
@ -872,34 +876,40 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
# build nested dict from flat parameters dict:
|
||||
org_hyperparams = self.data.hyperparams or {}
|
||||
hyperparams = dict()
|
||||
# if the task is a legacy task, we should put everything back under General/key with legacy type
|
||||
legacy_name = self._default_configuration_section_name
|
||||
org_legacy_section = org_hyperparams.get(legacy_name, dict())
|
||||
|
||||
# if the task is a legacy task, we should put everything back under _legacy
|
||||
if self.data.hyperparams and '_legacy' in self.data.hyperparams:
|
||||
for k, v in parameters.items():
|
||||
section_name, key = '_legacy', k
|
||||
section = hyperparams.get(section_name, dict())
|
||||
description = \
|
||||
descriptions.get(k) or \
|
||||
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description
|
||||
param_type = \
|
||||
params_types.get(k) or \
|
||||
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).type
|
||||
section[key] = tasks.ParamsItem(
|
||||
section=section_name, name=key, value=v, description=description, type=str(param_type))
|
||||
hyperparams[section_name] = section
|
||||
else:
|
||||
for k, v in parameters.items():
|
||||
# legacy variable
|
||||
if org_legacy_section.get(k, tasks.ParamsItem()).type == 'legacy':
|
||||
section = hyperparams.get(legacy_name, dict())
|
||||
section[k] = copy(org_legacy_section[k])
|
||||
section[k].value = str(v) if v else v
|
||||
description = descriptions.get(k)
|
||||
if description:
|
||||
section[k].description = description
|
||||
hyperparams[legacy_name] = section
|
||||
continue
|
||||
|
||||
org_k = k
|
||||
if '/' not in k:
|
||||
k = '{}/{}'.format(self._default_configuration_section_name, k)
|
||||
section_name, key = k.split('/', 1)
|
||||
section = hyperparams.get(section_name, dict())
|
||||
description = \
|
||||
descriptions.get(org_k) or \
|
||||
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description
|
||||
section[key] = tasks.ParamsItem\
|
||||
(section=section_name, name=key, value=v, description=description)
|
||||
org_param = org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem())
|
||||
param_type = params_types[org_k] if org_k in params_types else org_param.type
|
||||
if param_type and not isinstance(param_type, str):
|
||||
param_type = param_type.__name__ if hasattr(param_type, '__name__') else str(param_type)
|
||||
|
||||
section[key] = tasks.ParamsItem(
|
||||
section=section_name, name=key,
|
||||
value=str(v) if v else v,
|
||||
description=descriptions[org_k] if org_k in descriptions else org_param.description,
|
||||
type=param_type,
|
||||
)
|
||||
hyperparams[section_name] = section
|
||||
|
||||
self._edit(hyperparams=hyperparams)
|
||||
else:
|
||||
execution = self.data.execution
|
||||
|
@ -73,8 +73,8 @@ class PatchAbsl(object):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if param_name and flag:
|
||||
param_dict = PatchAbsl._task._arguments.copy_to_dict({param_name: flag.value},
|
||||
prefix=_Arguments._prefix_tf_defines)
|
||||
param_dict = PatchAbsl._task._arguments.copy_to_dict(
|
||||
{param_name: flag.value}, prefix=_Arguments._prefix_tf_defines)
|
||||
flag.value = param_dict.get(param_name, flag.value)
|
||||
except Exception:
|
||||
pass
|
||||
@ -82,7 +82,7 @@ class PatchAbsl(object):
|
||||
else:
|
||||
if flag and param_name:
|
||||
value = flag.value
|
||||
PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: value})
|
||||
PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: value}, )
|
||||
ret = PatchAbsl._original_DEFINE_flag(*args, **kwargs)
|
||||
return ret
|
||||
|
||||
@ -117,6 +117,20 @@ class PatchAbsl(object):
|
||||
else:
|
||||
# clear previous parameters
|
||||
parameters = dict([(k, FLAGS[k].value) for k in FLAGS])
|
||||
cls._task._arguments.copy_from_dict(parameters, prefix=_Arguments._prefix_tf_defines)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
descriptions = dict([(k, FLAGS[k].help or None) for k in FLAGS])
|
||||
except Exception:
|
||||
descriptions = None
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
param_types = dict([(k, FLAGS[k].flag_type() or None) for k in FLAGS])
|
||||
except Exception:
|
||||
param_types = None
|
||||
cls._task._arguments.copy_from_dict(
|
||||
parameters,
|
||||
prefix=_Arguments._prefix_tf_defines,
|
||||
descriptions=descriptions, param_types=param_types,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user