mirror of
https://github.com/clearml/clearml
synced 2025-03-12 14:48:30 +00:00
Add multi configuration section support (hyperparams and configurations)
Support setting offline mode API version using TRAINS_OFFLINE_MODE env var
This commit is contained in:
parent
6d4e85de0a
commit
e378de1e41
@ -64,7 +64,11 @@ class DataModel(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _to_base_type(cls, value):
|
def _to_base_type(cls, value):
|
||||||
if isinstance(value, DataModel):
|
if isinstance(value, dict):
|
||||||
|
# Note: this should come before DataModel to handle data models that are simply a dict
|
||||||
|
# (and thus are not expected to have additional named properties)
|
||||||
|
return {k: cls._to_base_type(v) for k, v in value.items()}
|
||||||
|
elif isinstance(value, DataModel):
|
||||||
return value.to_dict()
|
return value.to_dict()
|
||||||
elif isinstance(value, enum.Enum):
|
elif isinstance(value, enum.Enum):
|
||||||
return value.value
|
return value.value
|
||||||
|
@ -62,6 +62,7 @@ class Session(TokenManager):
|
|||||||
default_files = "https://demofiles.trains.allegro.ai"
|
default_files = "https://demofiles.trains.allegro.ai"
|
||||||
default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
|
default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
|
||||||
default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
|
default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
|
||||||
|
force_max_api_version = None
|
||||||
|
|
||||||
# TODO: add requests.codes.gateway_timeout once we support async commits
|
# TODO: add requests.codes.gateway_timeout once we support async commits
|
||||||
_retry_codes = [
|
_retry_codes = [
|
||||||
@ -182,6 +183,9 @@ class Session(TokenManager):
|
|||||||
|
|
||||||
self.__class__._sessions_created += 1
|
self.__class__._sessions_created += 1
|
||||||
|
|
||||||
|
if self.force_max_api_version and self.check_min_api_version(self.force_max_api_version):
|
||||||
|
Session.api_version = str(self.force_max_api_version)
|
||||||
|
|
||||||
def _send_request(
|
def _send_request(
|
||||||
self,
|
self,
|
||||||
service,
|
service,
|
||||||
@ -539,6 +543,16 @@ class Session(TokenManager):
|
|||||||
# If no session was created, create a default one, in order to get the backend api version.
|
# If no session was created, create a default one, in order to get the backend api version.
|
||||||
if cls._sessions_created <= 0:
|
if cls._sessions_created <= 0:
|
||||||
if cls._offline_mode:
|
if cls._offline_mode:
|
||||||
|
# allow to change the offline mode version by setting ENV_OFFLINE_MODE to the required API version
|
||||||
|
if cls.api_version != cls._offline_default_version:
|
||||||
|
offline_api = ENV_OFFLINE_MODE.get(converter=lambda x: x)
|
||||||
|
if offline_api:
|
||||||
|
try:
|
||||||
|
# check cast to float, but leave original str if we pass it.
|
||||||
|
float(offline_api)
|
||||||
|
cls._offline_default_version = str(offline_api)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
cls.api_version = cls._offline_default_version
|
cls.api_version = cls._offline_default_version
|
||||||
else:
|
else:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
|
@ -10,32 +10,35 @@ from ...utilities.args import call_original_argparser
|
|||||||
class _Arguments(object):
|
class _Arguments(object):
|
||||||
_prefix_sep = '/'
|
_prefix_sep = '/'
|
||||||
# TODO: separate dict and argparse after we add UI support
|
# TODO: separate dict and argparse after we add UI support
|
||||||
_prefix_dict = 'dict' + _prefix_sep
|
_prefix_args = 'Args' + _prefix_sep
|
||||||
_prefix_args = 'argparse' + _prefix_sep
|
|
||||||
_prefix_tf_defines = 'TF_DEFINE' + _prefix_sep
|
_prefix_tf_defines = 'TF_DEFINE' + _prefix_sep
|
||||||
|
|
||||||
class _ProxyDictWrite(dict):
|
class _ProxyDictWrite(dict):
|
||||||
""" Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
|
""" Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
|
||||||
|
|
||||||
def __init__(self, arguments, *args, **kwargs):
|
def __init__(self, __arguments, __section_name, *args, **kwargs):
|
||||||
super(_Arguments._ProxyDictWrite, self).__init__(*args, **kwargs)
|
super(_Arguments._ProxyDictWrite, self).__init__(*args, **kwargs)
|
||||||
self._arguments = arguments
|
self._arguments = __arguments
|
||||||
|
self._section_name = (__section_name.strip(_Arguments._prefix_sep) + _Arguments._prefix_sep) \
|
||||||
|
if __section_name else None
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
super(_Arguments._ProxyDictWrite, self).__setitem__(key, value)
|
super(_Arguments._ProxyDictWrite, self).__setitem__(key, value)
|
||||||
if self._arguments:
|
if self._arguments:
|
||||||
self._arguments.copy_from_dict(self)
|
self._arguments.copy_from_dict(self, prefix=self._section_name)
|
||||||
|
|
||||||
class _ProxyDictReadOnly(dict):
|
class _ProxyDictReadOnly(dict):
|
||||||
""" Dictionary wrapper that prevents modifications to the dictionary """
|
""" Dictionary wrapper that prevents modifications to the dictionary """
|
||||||
|
|
||||||
def __init__(self, arguments, *args, **kwargs):
|
def __init__(self, __arguments, __section_name, *args, **kwargs):
|
||||||
super(_Arguments._ProxyDictReadOnly, self).__init__(*args, **kwargs)
|
super(_Arguments._ProxyDictReadOnly, self).__init__(*args, **kwargs)
|
||||||
self._arguments = arguments
|
self._arguments = __arguments
|
||||||
|
self._section_name = (__section_name.strip(_Arguments._prefix_sep) + _Arguments._prefix_sep) \
|
||||||
|
if __section_name else None
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
if self._arguments:
|
if self._arguments:
|
||||||
param_dict = self._arguments.copy_to_dict({key: value})
|
param_dict = self._arguments.copy_to_dict({key: value}, prefix=self._section_name)
|
||||||
value = param_dict.get(key, value)
|
value = param_dict.get(key, value)
|
||||||
super(_Arguments._ProxyDictReadOnly, self).__setitem__(key, value)
|
super(_Arguments._ProxyDictReadOnly, self).__setitem__(key, value)
|
||||||
|
|
||||||
@ -45,24 +48,32 @@ class _Arguments(object):
|
|||||||
self._exclude_parser_args = {}
|
self._exclude_parser_args = {}
|
||||||
|
|
||||||
def exclude_parser_args(self, excluded_args):
|
def exclude_parser_args(self, excluded_args):
|
||||||
|
"""
|
||||||
|
You can use a dictionary for fined grained control of connected
|
||||||
|
arguments. The dictionary keys are argparse variable names and the values are booleans.
|
||||||
|
The ``False`` value excludes the specified argument from the Task's parameter section.
|
||||||
|
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
|
||||||
|
|
||||||
|
:param excluded_args: dict
|
||||||
|
"""
|
||||||
self._exclude_parser_args = excluded_args or {}
|
self._exclude_parser_args = excluded_args or {}
|
||||||
|
|
||||||
def set_defaults(self, *dicts, **kwargs):
|
def set_defaults(self, *dicts, **kwargs):
|
||||||
self._task.set_parameters(*dicts, **kwargs)
|
# noinspection PyProtectedMember
|
||||||
|
self._task._set_parameters(*dicts, __parameters_prefix=self._prefix_args, **kwargs)
|
||||||
|
|
||||||
def add_argument(self, option_strings, type=None, default=None, help=None):
|
def add_argument(self, option_strings, type=None, default=None, help=None):
|
||||||
if not option_strings:
|
if not option_strings:
|
||||||
raise Exception('Expected at least one argument name (option string)')
|
raise Exception('Expected at least one argument name (option string)')
|
||||||
name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t')
|
name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t')
|
||||||
# TODO: add argparse prefix
|
name = self._prefix_args + name
|
||||||
# name = self._prefix_args + name
|
|
||||||
self._task.set_parameter(name=name, value=default, description=help)
|
self._task.set_parameter(name=name, value=default, description=help)
|
||||||
|
|
||||||
def connect(self, parser):
|
def connect(self, parser):
|
||||||
self._task.connect_argparse(parser)
|
self._task.connect_argparse(parser)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _add_to_defaults(cls, a_parser, defaults, a_args=None, a_namespace=None, a_parsed_args=None):
|
def _add_to_defaults(cls, a_parser, defaults, descriptions, a_args=None, a_namespace=None, a_parsed_args=None):
|
||||||
actions = [
|
actions = [
|
||||||
a for a in a_parser._actions
|
a for a in a_parser._actions
|
||||||
if isinstance(a, _StoreAction) or isinstance(a, _StoreConstAction)
|
if isinstance(a, _StoreAction) or isinstance(a, _StoreConstAction)
|
||||||
@ -87,6 +98,9 @@ class _Arguments(object):
|
|||||||
for a in actions
|
for a in actions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
desc_ = {a.dest: a.help for a in actions}
|
||||||
|
descriptions.update(desc_)
|
||||||
|
|
||||||
full_args_dict = copy(defaults)
|
full_args_dict = copy(defaults)
|
||||||
full_args_dict.update(args_dict)
|
full_args_dict.update(args_dict)
|
||||||
defaults.update(defaults_)
|
defaults.update(defaults_)
|
||||||
@ -101,17 +115,19 @@ class _Arguments(object):
|
|||||||
defaults[sub_parser.dest] = full_args_dict.get(sub_parser.dest) or ''
|
defaults[sub_parser.dest] = full_args_dict.get(sub_parser.dest) or ''
|
||||||
for choice in sub_parser.choices.values():
|
for choice in sub_parser.choices.values():
|
||||||
# recursively parse
|
# recursively parse
|
||||||
defaults = cls._add_to_defaults(
|
defaults, descriptions = cls._add_to_defaults(
|
||||||
a_parser=choice,
|
a_parser=choice,
|
||||||
defaults=defaults,
|
defaults=defaults,
|
||||||
|
descriptions=descriptions,
|
||||||
a_parsed_args=a_parsed_args or full_args_dict
|
a_parsed_args=a_parsed_args or full_args_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
return defaults
|
return defaults, descriptions
|
||||||
|
|
||||||
def copy_defaults_from_argparse(self, parser, args=None, namespace=None, parsed_args=None):
|
def copy_defaults_from_argparse(self, parser, args=None, namespace=None, parsed_args=None):
|
||||||
task_defaults = {}
|
task_defaults = {}
|
||||||
self._add_to_defaults(parser, task_defaults, args, namespace, parsed_args)
|
task_defaults_descriptions = {}
|
||||||
|
self._add_to_defaults(parser, task_defaults, task_defaults_descriptions, args, namespace, parsed_args)
|
||||||
|
|
||||||
# Make sure we didn't miss anything
|
# Make sure we didn't miss anything
|
||||||
if parsed_args:
|
if parsed_args:
|
||||||
@ -133,12 +149,13 @@ class _Arguments(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
del task_defaults[k]
|
del task_defaults[k]
|
||||||
|
|
||||||
# Skip excluded arguments, Add prefix, TODO: add argparse prefix
|
# Skip excluded arguments, Add prefix.
|
||||||
# task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items()
|
task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items()
|
||||||
# if k not in self._exclude_parser_args])
|
if self._exclude_parser_args.get(k, True)])
|
||||||
task_defaults = dict([(k, v) for k, v in task_defaults.items() if self._exclude_parser_args.get(k, True)])
|
task_defaults_descriptions = dict([(self._prefix_args + k, v) for k, v in task_defaults_descriptions.items()
|
||||||
|
if self._exclude_parser_args.get(k, True)])
|
||||||
# Store to task
|
# Store to task
|
||||||
self._task.update_parameters(task_defaults)
|
self._task.update_parameters(task_defaults, __parameters_descriptions=task_defaults_descriptions)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _find_parser_action(cls, a_parser, name):
|
def _find_parser_action(cls, a_parser, name):
|
||||||
@ -158,11 +175,10 @@ class _Arguments(object):
|
|||||||
return _actions
|
return _actions
|
||||||
|
|
||||||
def copy_to_parser(self, parser, parsed_args):
|
def copy_to_parser(self, parser, parsed_args):
|
||||||
# todo: change to argparse prefix only
|
# Change to argparse prefix only
|
||||||
# task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items()
|
task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items()
|
||||||
# if k.startswith(self._prefix_args)])
|
if k.startswith(self._prefix_args) and
|
||||||
task_arguments = dict([(k, v) for k, v in self._task.get_parameters().items()
|
self._exclude_parser_args.get(k[len(self._prefix_args):], True)])
|
||||||
if not k.startswith(self._prefix_tf_defines) and self._exclude_parser_args.get(k, True)])
|
|
||||||
arg_parser_argeuments = {}
|
arg_parser_argeuments = {}
|
||||||
for k, v in task_arguments.items():
|
for k, v in task_arguments.items():
|
||||||
# python2 unicode support
|
# python2 unicode support
|
||||||
@ -304,25 +320,24 @@ class _Arguments(object):
|
|||||||
parser.set_defaults(**arg_parser_argeuments)
|
parser.set_defaults(**arg_parser_argeuments)
|
||||||
|
|
||||||
def copy_from_dict(self, dictionary, prefix=None):
|
def copy_from_dict(self, dictionary, prefix=None):
|
||||||
# TODO: add dict prefix
|
# add dict prefix
|
||||||
prefix = prefix or '' # self._prefix_dict
|
prefix = prefix # or self._prefix_dict
|
||||||
if prefix:
|
if prefix:
|
||||||
with self._task._edit_lock:
|
prefix = prefix.strip(self._prefix_sep) + self._prefix_sep
|
||||||
prefix_dictionary = dict([(prefix + k, v) for k, v in dictionary.items()])
|
# this will only set the specific section
|
||||||
cur_params = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(prefix)])
|
self._task.set_parameters(dictionary, __parameters_prefix=prefix)
|
||||||
cur_params.update(prefix_dictionary)
|
|
||||||
self._task.set_parameters(cur_params)
|
|
||||||
else:
|
else:
|
||||||
self._task.update_parameters(dictionary)
|
self._task.update_parameters(dictionary)
|
||||||
if not isinstance(dictionary, self._ProxyDictWrite):
|
if not isinstance(dictionary, self._ProxyDictWrite):
|
||||||
return self._ProxyDictWrite(self, **dictionary)
|
return self._ProxyDictWrite(self, prefix, **dictionary)
|
||||||
return dictionary
|
return dictionary
|
||||||
|
|
||||||
def copy_to_dict(self, dictionary, prefix=None):
|
def copy_to_dict(self, dictionary, prefix=None):
|
||||||
# iterate over keys and merge values according to parameter type in dictionary
|
# iterate over keys and merge values according to parameter type in dictionary
|
||||||
# TODO: add dict prefix
|
# add dict prefix
|
||||||
prefix = prefix or '' # self._prefix_dict
|
prefix = prefix # or self._prefix_dict
|
||||||
if prefix:
|
if prefix:
|
||||||
|
prefix = prefix.strip(self._prefix_sep) + self._prefix_sep
|
||||||
parameters = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items()
|
parameters = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items()
|
||||||
if k.startswith(prefix)])
|
if k.startswith(prefix)])
|
||||||
else:
|
else:
|
||||||
@ -396,5 +411,5 @@ class _Arguments(object):
|
|||||||
# dictionary[k] = v
|
# dictionary[k] = v
|
||||||
|
|
||||||
if not isinstance(dictionary, self._ProxyDictReadOnly):
|
if not isinstance(dictionary, self._ProxyDictReadOnly):
|
||||||
return self._ProxyDictReadOnly(self, **dictionary)
|
return self._ProxyDictReadOnly(self, prefix, **dictionary)
|
||||||
return dictionary
|
return dictionary
|
||||||
|
@ -30,13 +30,15 @@ from ...backend_api import Session
|
|||||||
from ...backend_api.services import tasks, models, events, projects
|
from ...backend_api.services import tasks, models, events, projects
|
||||||
from ...backend_api.session.defs import ENV_OFFLINE_MODE
|
from ...backend_api.session.defs import ENV_OFFLINE_MODE
|
||||||
from ...utilities.pyhocon import ConfigTree, ConfigFactory
|
from ...utilities.pyhocon import ConfigTree, ConfigFactory
|
||||||
|
from ...utilities.config import config_dict_to_text, text_to_config_dict
|
||||||
|
|
||||||
from ..base import IdObjectBase, InterfaceBase
|
from ..base import IdObjectBase, InterfaceBase
|
||||||
from ..metrics import Metrics, Reporter
|
from ..metrics import Metrics, Reporter
|
||||||
from ..model import Model
|
from ..model import Model
|
||||||
from ..setupuploadmixin import SetupUploadMixin
|
from ..setupuploadmixin import SetupUploadMixin
|
||||||
from ..util import make_message, get_or_create_project, get_single_result, \
|
from ..util import (
|
||||||
exact_match_regex
|
make_message, get_or_create_project, get_single_result,
|
||||||
|
exact_match_regex, mutually_exclusive, )
|
||||||
from ...config import (
|
from ...config import (
|
||||||
get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend,
|
get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend,
|
||||||
running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR, get_offline_dir)
|
running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR, get_offline_dir)
|
||||||
@ -768,12 +770,52 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
|
|
||||||
self._edit(execution=self.data.execution)
|
self._edit(execution=self.data.execution)
|
||||||
|
|
||||||
|
def get_parameters(self, backwards_compatibility=True):
|
||||||
|
# type: (bool) -> (Optional[dict])
|
||||||
|
"""
|
||||||
|
Get the parameters for a Task. This method returns a complete group of key-value parameter pairs, but does not
|
||||||
|
support parameter descriptions (the result is a dictionary of key-value pairs).
|
||||||
|
:param backwards_compatibility: If True (default) parameters without section name
|
||||||
|
(API version < 2.9, trains-server < 0.16) will be at dict root level.
|
||||||
|
If False, parameters without section name, will be nested under "general/" key.
|
||||||
|
:return: dict of the task parameters, all flattened to key/value.
|
||||||
|
Different sections with key prefix "section/"
|
||||||
|
"""
|
||||||
|
if not Session.check_min_api_version('2.9'):
|
||||||
|
return self._get_task_property('execution.parameters')
|
||||||
|
|
||||||
|
# API will makes sure we get old parameters under _legacy
|
||||||
|
parameters = dict()
|
||||||
|
hyperparams = self._get_task_property('hyperparams', default=dict())
|
||||||
|
if len(hyperparams) == 1 and '_legacy' in hyperparams and backwards_compatibility:
|
||||||
|
for section in ('_legacy', ):
|
||||||
|
for key, section_param in hyperparams[section].items():
|
||||||
|
parameters['{}'.format(key)] = section_param.value
|
||||||
|
else:
|
||||||
|
for section in hyperparams:
|
||||||
|
for key, section_param in hyperparams[section].items():
|
||||||
|
parameters['{}/{}'.format(section, key)] = section_param.value
|
||||||
|
|
||||||
|
return parameters
|
||||||
|
|
||||||
def set_parameters(self, *args, **kwargs):
|
def set_parameters(self, *args, **kwargs):
|
||||||
# type: (*dict, **Any) -> ()
|
# type: (*dict, **Any) -> ()
|
||||||
"""
|
"""
|
||||||
Set the parameters for a Task. This method sets a complete group of key-value parameter pairs, but does not
|
Set the parameters for a Task. This method sets a complete group of key-value parameter pairs, but does not
|
||||||
support parameter descriptions (the input is a dictionary of key-value pairs).
|
support parameter descriptions (the input is a dictionary of key-value pairs).
|
||||||
|
|
||||||
|
:param args: Positional arguments, which are one or more dictionary or (key, value) iterable. They are
|
||||||
|
merged into a single key-value pair dictionary.
|
||||||
|
:param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``.
|
||||||
|
"""
|
||||||
|
return self._set_parameters(*args, __update=False, **kwargs)
|
||||||
|
|
||||||
|
def _set_parameters(self, *args, **kwargs):
|
||||||
|
# type: (*dict, **Any) -> ()
|
||||||
|
"""
|
||||||
|
Set the parameters for a Task. This method sets a complete group of key-value parameter pairs, but does not
|
||||||
|
support parameter descriptions (the input is a dictionary of key-value pairs).
|
||||||
|
|
||||||
:param args: Positional arguments, which are one or more dictionary or (key, value) iterable. They are
|
:param args: Positional arguments, which are one or more dictionary or (key, value) iterable. They are
|
||||||
merged into a single key-value pair dictionary.
|
merged into a single key-value pair dictionary.
|
||||||
:param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``.
|
:param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``.
|
||||||
@ -781,20 +823,22 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
if not all(isinstance(x, (dict, Iterable)) for x in args):
|
if not all(isinstance(x, (dict, Iterable)) for x in args):
|
||||||
raise ValueError('only dict or iterable are supported as positional arguments')
|
raise ValueError('only dict or iterable are supported as positional arguments')
|
||||||
|
|
||||||
|
prefix = kwargs.pop('__parameters_prefix', None)
|
||||||
|
descriptions = kwargs.pop('__parameters_descriptions', None) or dict()
|
||||||
|
params_types = kwargs.pop('__parameters_types', None) or dict()
|
||||||
update = kwargs.pop('__update', False)
|
update = kwargs.pop('__update', False)
|
||||||
|
|
||||||
with self._edit_lock:
|
# new parameters dict
|
||||||
self.reload()
|
new_parameters = dict(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args))
|
||||||
if update:
|
new_parameters.update(kwargs)
|
||||||
parameters = self.get_parameters()
|
if prefix:
|
||||||
else:
|
prefix = prefix.strip('/')
|
||||||
parameters = dict()
|
new_parameters = dict(('{}/{}'.format(prefix, k), v) for k, v in new_parameters.items())
|
||||||
parameters.update(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args))
|
|
||||||
parameters.update(kwargs)
|
|
||||||
|
|
||||||
|
# verify parameters type:
|
||||||
not_allowed = {
|
not_allowed = {
|
||||||
k: type(v).__name__
|
k: type(v).__name__
|
||||||
for k, v in parameters.items()
|
for k, v in new_parameters.items()
|
||||||
if not isinstance(v, self._parameters_allowed_types)
|
if not isinstance(v, self._parameters_allowed_types)
|
||||||
}
|
}
|
||||||
if not_allowed:
|
if not_allowed:
|
||||||
@ -804,9 +848,59 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
', '.join('%s=>%s' % p for p in not_allowed.items())),
|
', '.join('%s=>%s' % p for p in not_allowed.items())),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
use_hyperparams = Session.check_min_api_version('2.9')
|
||||||
|
|
||||||
|
with self._edit_lock:
|
||||||
|
self.reload()
|
||||||
|
# if we have a specific prefix and we use hyperparameters, and we use set.
|
||||||
|
# overwrite only the prefix, leave the rest as is.
|
||||||
|
if not update and prefix:
|
||||||
|
parameters = dict(**(self.get_parameters() or {}))
|
||||||
|
parameters = dict((k, v) for k, v in parameters.items() if not k.startswith(prefix+'/'))
|
||||||
|
elif update:
|
||||||
|
parameters = dict(**(self.get_parameters() or {}))
|
||||||
|
else:
|
||||||
|
parameters = dict()
|
||||||
|
|
||||||
|
parameters.update(new_parameters)
|
||||||
|
|
||||||
# force cast all variables to strings (so that we can later edit them in UI)
|
# force cast all variables to strings (so that we can later edit them in UI)
|
||||||
parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
|
parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
|
||||||
|
|
||||||
|
if use_hyperparams:
|
||||||
|
# build nested dict from flat parameters dict:
|
||||||
|
org_hyperparams = self.data.hyperparams or {}
|
||||||
|
hyperparams = dict()
|
||||||
|
|
||||||
|
# if the task is a legacy task, we should put everything back under _legacy
|
||||||
|
if self.data.hyperparams and '_legacy' in self.data.hyperparams:
|
||||||
|
for k, v in parameters.items():
|
||||||
|
section_name, key = '_legacy', k
|
||||||
|
section = hyperparams.get(section_name, dict())
|
||||||
|
description = \
|
||||||
|
descriptions.get(k) or \
|
||||||
|
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description
|
||||||
|
param_type = \
|
||||||
|
params_types.get(k) or \
|
||||||
|
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).type
|
||||||
|
section[key] = tasks.ParamsItem(
|
||||||
|
section=section_name, name=key, value=v, description=description, type=str(param_type))
|
||||||
|
hyperparams[section_name] = section
|
||||||
|
else:
|
||||||
|
for k, v in parameters.items():
|
||||||
|
org_k = k
|
||||||
|
if '/' not in k:
|
||||||
|
k = 'General/{}'.format(k)
|
||||||
|
section_name, key = k.split('/', 1)
|
||||||
|
section = hyperparams.get(section_name, dict())
|
||||||
|
description = \
|
||||||
|
descriptions.get(org_k) or \
|
||||||
|
org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description
|
||||||
|
section[key] = tasks.ParamsItem\
|
||||||
|
(section=section_name, name=key, value=v, description=description)
|
||||||
|
hyperparams[section_name] = section
|
||||||
|
self._edit(hyperparams=hyperparams)
|
||||||
|
else:
|
||||||
execution = self.data.execution
|
execution = self.data.execution
|
||||||
if execution is None:
|
if execution is None:
|
||||||
execution = tasks.Execution(
|
execution = tasks.Execution(
|
||||||
@ -816,23 +910,26 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
execution.parameters = parameters
|
execution.parameters = parameters
|
||||||
self._edit(execution=execution)
|
self._edit(execution=execution)
|
||||||
|
|
||||||
def set_parameter(self, name, value, description=None):
|
def set_parameter(self, name, value, description=None, value_type=None):
|
||||||
# type: (str, str, Optional[str]) -> ()
|
# type: (str, str, Optional[str], Optional[Any]) -> ()
|
||||||
"""
|
"""
|
||||||
Set a single Task parameter. This overrides any previous value for this parameter.
|
Set a single Task parameter. This overrides any previous value for this parameter.
|
||||||
|
|
||||||
:param name: The parameter name.
|
:param name: The parameter name.
|
||||||
:param value: The parameter value.
|
:param value: The parameter value.
|
||||||
:param description: The parameter description.
|
:param description: The parameter description.
|
||||||
|
:param value_type: The type of the parameters (cast to string and store)
|
||||||
.. note::
|
|
||||||
The ``description`` is not yet in use.
|
|
||||||
"""
|
"""
|
||||||
|
if not Session.check_min_api_version('2.9'):
|
||||||
# not supported yet
|
# not supported yet
|
||||||
if description:
|
|
||||||
# noinspection PyUnusedLocal
|
|
||||||
description = None
|
description = None
|
||||||
self.set_parameters({name: value}, __update=True)
|
value_type = None
|
||||||
|
|
||||||
|
self._set_parameters(
|
||||||
|
{name: value}, __update=True,
|
||||||
|
__parameters_descriptions={name: description},
|
||||||
|
__parameters_types={name: value_type}
|
||||||
|
)
|
||||||
|
|
||||||
def get_parameter(self, name, default=None):
|
def get_parameter(self, name, default=None):
|
||||||
# type: (str, Any) -> Any
|
# type: (str, Any) -> Any
|
||||||
@ -856,7 +953,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
merged into a single key-value pair dictionary.
|
merged into a single key-value pair dictionary.
|
||||||
:param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``.
|
:param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``.
|
||||||
"""
|
"""
|
||||||
self.set_parameters(__update=True, *args, **kwargs)
|
self._set_parameters(*args, __update=True, **kwargs)
|
||||||
|
|
||||||
def set_model_label_enumeration(self, enumeration=None):
|
def set_model_label_enumeration(self, enumeration=None):
|
||||||
# type: (Mapping[str, int]) -> ()
|
# type: (Mapping[str, int]) -> ()
|
||||||
@ -1316,7 +1413,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
|
|
||||||
self._update_requirements('')
|
self._update_requirements('')
|
||||||
|
|
||||||
if Session.check_min_api_version('2.3'):
|
if Session.check_min_api_version('2.9'):
|
||||||
|
self._set_task_property("system_tags", system_tags)
|
||||||
|
self._edit(system_tags=self._data.system_tags, comment=self._data.comment,
|
||||||
|
script=self._data.script, execution=self._data.execution, output_dest='',
|
||||||
|
hyperparams=dict(), configuration=dict())
|
||||||
|
elif Session.check_min_api_version('2.3'):
|
||||||
self._set_task_property("system_tags", system_tags)
|
self._set_task_property("system_tags", system_tags)
|
||||||
self._edit(system_tags=self._data.system_tags, comment=self._data.comment,
|
self._edit(system_tags=self._data.system_tags, comment=self._data.comment,
|
||||||
script=self._data.script, execution=self._data.execution, output_dest='')
|
script=self._data.script, execution=self._data.execution, output_dest='')
|
||||||
@ -1386,6 +1488,67 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
self.data.script = script
|
self.data.script = script
|
||||||
self._edit(script=script)
|
self._edit(script=script)
|
||||||
|
|
||||||
|
def _set_configuration(self, name, description=None, config_type=None, config_text=None, config_dict=None):
|
||||||
|
# type: (str, Optional[str], Optional[str], Optional[str], Optional[Mapping]) -> None
|
||||||
|
"""
|
||||||
|
Set Task configuration text/dict. Multiple configurations are supported.
|
||||||
|
|
||||||
|
:param str name: Configuration name.
|
||||||
|
:param str description: Configuration section description.
|
||||||
|
:param config_text: model configuration (unconstrained text string). usually the content
|
||||||
|
of a configuration file. If `config_text` is not None, `config_dict` must not be provided.
|
||||||
|
:param config_dict: model configuration parameters dictionary.
|
||||||
|
If `config_dict` is not None, `config_text` must not be provided.
|
||||||
|
"""
|
||||||
|
# make sure we have wither dict or text
|
||||||
|
mutually_exclusive(config_dict=config_dict, config_text=config_text)
|
||||||
|
|
||||||
|
if not Session.check_min_api_version('2.9'):
|
||||||
|
raise ValueError("Multiple configurations are not supported with the current 'trains-server', "
|
||||||
|
"please upgrade to the latest version")
|
||||||
|
|
||||||
|
if description:
|
||||||
|
description = str(description)
|
||||||
|
config = config_dict_to_text(config_text or config_dict)
|
||||||
|
with self._edit_lock:
|
||||||
|
self.reload()
|
||||||
|
configuration = self.data.configuration or {}
|
||||||
|
configuration[name] = tasks.ConfigurationItem(
|
||||||
|
name=name, value=config, description=description or None, type=config_type or None)
|
||||||
|
self._edit(configuration=configuration)
|
||||||
|
|
||||||
|
def _get_configuration_text(self, name):
|
||||||
|
# type: (str) -> Optional[str]
|
||||||
|
"""
|
||||||
|
Get Task configuration section as text
|
||||||
|
|
||||||
|
:param str name: Configuration name.
|
||||||
|
:return: The Task configuration as text (unconstrained text string).
|
||||||
|
return None if configuration name is not valid.
|
||||||
|
"""
|
||||||
|
if not Session.check_min_api_version('2.9'):
|
||||||
|
raise ValueError("Multiple configurations are not supported with the current 'trains-server', "
|
||||||
|
"please upgrade to the latest version")
|
||||||
|
|
||||||
|
configuration = self.data.configuration or {}
|
||||||
|
if not configuration.get(name):
|
||||||
|
return None
|
||||||
|
return configuration[name].value
|
||||||
|
|
||||||
|
def _get_configuration_dict(self, name):
|
||||||
|
# type: (str) -> Optional[dict]
|
||||||
|
"""
|
||||||
|
Get Task configuration section as dictionary
|
||||||
|
|
||||||
|
:param str name: Configuration name.
|
||||||
|
:return: The Task configuration as dictionary.
|
||||||
|
return None if configuration name is not valid.
|
||||||
|
"""
|
||||||
|
config_text = self._get_configuration_text(name)
|
||||||
|
if not config_text:
|
||||||
|
return None
|
||||||
|
return text_to_config_dict(config_text)
|
||||||
|
|
||||||
def get_offline_mode_folder(self):
|
def get_offline_mode_folder(self):
|
||||||
# type: () -> (Optional[Path])
|
# type: () -> (Optional[Path])
|
||||||
"""
|
"""
|
||||||
|
@ -339,7 +339,7 @@ class BaseModel(object):
|
|||||||
def _config_dict_to_text(config):
|
def _config_dict_to_text(config):
|
||||||
if not isinstance(config, six.string_types) and not isinstance(config, dict):
|
if not isinstance(config, six.string_types) and not isinstance(config, dict):
|
||||||
raise ValueError("Model configuration only supports dictionary or string objects")
|
raise ValueError("Model configuration only supports dictionary or string objects")
|
||||||
return config_dict_to_text
|
return config_dict_to_text(config)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _text_to_config_dict(text):
|
def _text_to_config_dict(text):
|
||||||
|
@ -124,6 +124,7 @@ class Task(_Task):
|
|||||||
__task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
|
__task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
|
||||||
__detect_repo_async = config.get('development.vcs_repo_detect_async', False)
|
__detect_repo_async = config.get('development.vcs_repo_detect_async', False)
|
||||||
__default_output_uri = config.get('development.default_output_uri', None)
|
__default_output_uri = config.get('development.default_output_uri', None)
|
||||||
|
__default_configuration_name = 'General'
|
||||||
|
|
||||||
class _ConnectedParametersType(object):
|
class _ConnectedParametersType(object):
|
||||||
argparse = "argument_parser"
|
argparse = "argument_parser"
|
||||||
@ -904,8 +905,8 @@ class Task(_Task):
|
|||||||
self.data.tags.extend(tags)
|
self.data.tags.extend(tags)
|
||||||
self._edit(tags=list(set(self.data.tags)))
|
self._edit(tags=list(set(self.data.tags)))
|
||||||
|
|
||||||
def connect(self, mutable):
|
def connect(self, mutable, name=None):
|
||||||
# type: (Any) -> Any
|
# type: (Any, Optional[str]) -> Any
|
||||||
"""
|
"""
|
||||||
Connect an object to a Task object. This connects an experiment component (part of an experiment) to the
|
Connect an object to a Task object. This connects an experiment component (part of an experiment) to the
|
||||||
experiment. For example, connect hyperparameters or models.
|
experiment. For example, connect hyperparameters or models.
|
||||||
@ -918,6 +919,12 @@ class Task(_Task):
|
|||||||
- TaskParameters - A TaskParameters object.
|
- TaskParameters - A TaskParameters object.
|
||||||
- model - A model object for initial model warmup, or for model update/snapshot uploading.
|
- model - A model object for initial model warmup, or for model update/snapshot uploading.
|
||||||
|
|
||||||
|
:param str name: A section name associated with the connected object. Default: 'General'
|
||||||
|
Currently only supported for `dict` / `TaskParameter` objects
|
||||||
|
Examples:
|
||||||
|
name='General' will put the connected dictionary under the General section in the hyper-parameters
|
||||||
|
name='Train' will put the connected dictionary under the Train section in the hyper-parameters
|
||||||
|
|
||||||
:return: The result returned when connecting the object, if supported.
|
:return: The result returned when connecting the object, if supported.
|
||||||
|
|
||||||
:raise: Raise an exception on unsupported objects.
|
:raise: Raise an exception on unsupported objects.
|
||||||
@ -931,14 +938,22 @@ class Task(_Task):
|
|||||||
(TaskParameters, self._connect_task_parameters),
|
(TaskParameters, self._connect_task_parameters),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
multi_config_support = Session.check_min_api_version('2.9')
|
||||||
|
if multi_config_support and not name:
|
||||||
|
name = self.__default_configuration_name
|
||||||
|
|
||||||
|
if not multi_config_support and name and name != self.__default_configuration_name:
|
||||||
|
raise ValueError("Multiple configurations are not supported with the current 'trains-server', "
|
||||||
|
"please upgrade to the latest version")
|
||||||
|
|
||||||
for mutable_type, method in dispatch:
|
for mutable_type, method in dispatch:
|
||||||
if isinstance(mutable, mutable_type):
|
if isinstance(mutable, mutable_type):
|
||||||
return method(mutable)
|
return method(mutable, name=name)
|
||||||
|
|
||||||
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
|
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
|
||||||
|
|
||||||
def connect_configuration(self, configuration):
|
def connect_configuration(self, configuration, name=None, description=None):
|
||||||
# type: (Union[Mapping, Path, str]) -> Union[Mapping, Path, str]
|
# type: (Union[Mapping, Path, str], Optional[str], Optional[str]) -> Union[Mapping, Path, str]
|
||||||
"""
|
"""
|
||||||
Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object.
|
Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object.
|
||||||
This method should be called before reading the configuration file.
|
This method should be called before reading the configuration file.
|
||||||
@ -968,25 +983,54 @@ class Task(_Task):
|
|||||||
A local path must be relative path. When executing a Task remotely in a worker, the contents brought
|
A local path must be relative path. When executing a Task remotely in a worker, the contents brought
|
||||||
from the **Trains Server** (backend) overwrites the contents of the file.
|
from the **Trains Server** (backend) overwrites the contents of the file.
|
||||||
|
|
||||||
|
:param str name: Configuration section name. default: 'General'
|
||||||
|
Allowing users to store multiple configuration dicts/files
|
||||||
|
|
||||||
|
:param str description: Configuration section description (text). default: None
|
||||||
|
|
||||||
:return: If a dictionary is specified, then a dictionary is returned. If pathlib2.Path / string is
|
:return: If a dictionary is specified, then a dictionary is returned. If pathlib2.Path / string is
|
||||||
specified, then a path to a local configuration file is returned. Configuration object.
|
specified, then a path to a local configuration file is returned. Configuration object.
|
||||||
"""
|
"""
|
||||||
|
pathlib_Path = None
|
||||||
if not isinstance(configuration, (dict, Path, six.string_types)):
|
if not isinstance(configuration, (dict, Path, six.string_types)):
|
||||||
|
try:
|
||||||
|
from pathlib import Path as pathlib_Path
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
if not pathlib_Path or not isinstance(configuration, pathlib_Path):
|
||||||
raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, "
|
raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, "
|
||||||
"{} is not supported".format(type(configuration)))
|
"{} is not supported".format(type(configuration)))
|
||||||
|
|
||||||
|
multi_config_support = Session.check_min_api_version('2.9')
|
||||||
|
if multi_config_support and not name:
|
||||||
|
name = self.__default_configuration_name
|
||||||
|
|
||||||
|
if not multi_config_support and name and name != self.__default_configuration_name:
|
||||||
|
raise ValueError("Multiple configurations are not supported with the current 'trains-server', "
|
||||||
|
"please upgrade to the latest version")
|
||||||
|
|
||||||
# parameter dictionary
|
# parameter dictionary
|
||||||
if isinstance(configuration, dict):
|
if isinstance(configuration, dict):
|
||||||
def _update_config_dict(task, config_dict):
|
def _update_config_dict(task, config_dict):
|
||||||
|
if multi_config_support:
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
task._set_configuration(
|
||||||
|
name=name, description=description, config_type='dictionary', config_dict=config_dict)
|
||||||
|
else:
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
task._set_model_config(config_dict=config_dict)
|
task._set_model_config(config_dict=config_dict)
|
||||||
|
|
||||||
if not running_remotely() or not self.is_main_task():
|
if not running_remotely() or not self.is_main_task():
|
||||||
|
if multi_config_support:
|
||||||
|
self._set_configuration(
|
||||||
|
name=name, description=description, config_type='dictionary', config_dict=configuration)
|
||||||
|
else:
|
||||||
self._set_model_config(config_dict=configuration)
|
self._set_model_config(config_dict=configuration)
|
||||||
configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration)
|
configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration)
|
||||||
else:
|
else:
|
||||||
configuration.clear()
|
configuration.clear()
|
||||||
configuration.update(self._get_model_config_dict())
|
configuration.update(self._get_configuration_dict(name=name) if multi_config_support
|
||||||
|
else self._get_model_config_dict())
|
||||||
configuration = ProxyDictPreWrite(False, False, **configuration)
|
configuration = ProxyDictPreWrite(False, False, **configuration)
|
||||||
return configuration
|
return configuration
|
||||||
|
|
||||||
@ -1002,16 +1046,26 @@ class Task(_Task):
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError("Could not connect configuration file {}, file could not be read".format(
|
raise ValueError("Could not connect configuration file {}, file could not be read".format(
|
||||||
configuration_path.as_posix()))
|
configuration_path.as_posix()))
|
||||||
|
if multi_config_support:
|
||||||
|
self._set_configuration(
|
||||||
|
name=name, description=description,
|
||||||
|
config_type=configuration_path.suffixes[-1].lstrip('.')
|
||||||
|
if configuration_path.suffixes and configuration_path.suffixes[-1] else 'file',
|
||||||
|
config_text=configuration_text)
|
||||||
|
else:
|
||||||
self._set_model_config(config_text=configuration_text)
|
self._set_model_config(config_text=configuration_text)
|
||||||
return configuration
|
return configuration
|
||||||
else:
|
else:
|
||||||
configuration_text = self._get_model_config_text()
|
configuration_text = self._get_configuration_text(name=name) if multi_config_support \
|
||||||
|
else self._get_model_config_text()
|
||||||
configuration_path = Path(configuration)
|
configuration_path = Path(configuration)
|
||||||
fd, local_filename = mkstemp(prefix='trains_task_config_',
|
fd, local_filename = mkstemp(prefix='trains_task_config_',
|
||||||
suffix=configuration_path.suffixes[-1] if
|
suffix=configuration_path.suffixes[-1] if
|
||||||
configuration_path.suffixes else '.txt')
|
configuration_path.suffixes else '.txt')
|
||||||
os.write(fd, configuration_text.encode('utf-8'))
|
os.write(fd, configuration_text.encode('utf-8'))
|
||||||
os.close(fd)
|
os.close(fd)
|
||||||
|
if pathlib_Path:
|
||||||
|
return pathlib_Path(local_filename)
|
||||||
return Path(local_filename) if isinstance(configuration, Path) else local_filename
|
return Path(local_filename) if isinstance(configuration, Path) else local_filename
|
||||||
|
|
||||||
def connect_label_enumeration(self, enumeration):
|
def connect_label_enumeration(self, enumeration):
|
||||||
@ -1651,6 +1705,7 @@ class Task(_Task):
|
|||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
offline_folder = os.path.join(export_data.get('offline_folder', ''), 'data/')
|
offline_folder = os.path.join(export_data.get('offline_folder', ''), 'data/')
|
||||||
|
|
||||||
|
# noinspection PyProtectedMember
|
||||||
remote_url = task._get_default_report_storage_uri()
|
remote_url = task._get_default_report_storage_uri()
|
||||||
if remote_url and remote_url.endswith('/'):
|
if remote_url and remote_url.endswith('/'):
|
||||||
remote_url = remote_url[:-1]
|
remote_url = remote_url[:-1]
|
||||||
@ -1977,7 +2032,7 @@ class Task(_Task):
|
|||||||
|
|
||||||
return self._logger
|
return self._logger
|
||||||
|
|
||||||
def _connect_output_model(self, model):
|
def _connect_output_model(self, model, name=None):
|
||||||
assert isinstance(model, OutputModel)
|
assert isinstance(model, OutputModel)
|
||||||
model.connect(self)
|
model.connect(self)
|
||||||
return model
|
return model
|
||||||
@ -2001,7 +2056,7 @@ class Task(_Task):
|
|||||||
if self._connected_output_model:
|
if self._connected_output_model:
|
||||||
self.connect(self._connected_output_model)
|
self.connect(self._connected_output_model)
|
||||||
|
|
||||||
def _connect_input_model(self, model):
|
def _connect_input_model(self, model, name=None):
|
||||||
assert isinstance(model, InputModel)
|
assert isinstance(model, InputModel)
|
||||||
# we only allow for an input model to be connected once
|
# we only allow for an input model to be connected once
|
||||||
# at least until we support multiple input models
|
# at least until we support multiple input models
|
||||||
@ -2039,7 +2094,7 @@ class Task(_Task):
|
|||||||
# added support for multiple type connections through _Arguments
|
# added support for multiple type connections through _Arguments
|
||||||
return option
|
return option
|
||||||
|
|
||||||
def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None):
|
def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None, name=None):
|
||||||
# do not allow argparser to connect to jupyter notebook
|
# do not allow argparser to connect to jupyter notebook
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
@ -2079,15 +2134,15 @@ class Task(_Task):
|
|||||||
parser, args=args, namespace=namespace, parsed_args=parsed_args)
|
parser, args=args, namespace=namespace, parsed_args=parsed_args)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def _connect_dictionary(self, dictionary):
|
def _connect_dictionary(self, dictionary, name=None):
|
||||||
def _update_args_dict(task, config_dict):
|
def _update_args_dict(task, config_dict):
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
task._arguments.copy_from_dict(flatten_dictionary(config_dict))
|
task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name)
|
||||||
|
|
||||||
def _refresh_args_dict(task, config_dict):
|
def _refresh_args_dict(task, config_dict):
|
||||||
# reread from task including newly added keys
|
# reread from task including newly added keys
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict))
|
a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict), prefix=name)
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
nested_dict = config_dict._to_dict()
|
nested_dict = config_dict._to_dict()
|
||||||
config_dict.clear()
|
config_dict.clear()
|
||||||
@ -2096,23 +2151,28 @@ class Task(_Task):
|
|||||||
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
|
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
|
||||||
|
|
||||||
if not running_remotely() or not self.is_main_task():
|
if not running_remotely() or not self.is_main_task():
|
||||||
self._arguments.copy_from_dict(flatten_dictionary(dictionary))
|
self._arguments.copy_from_dict(flatten_dictionary(dictionary), prefix=name)
|
||||||
dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary)
|
dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary)
|
||||||
else:
|
else:
|
||||||
flat_dict = flatten_dictionary(dictionary)
|
flat_dict = flatten_dictionary(dictionary)
|
||||||
flat_dict = self._arguments.copy_to_dict(flat_dict)
|
flat_dict = self._arguments.copy_to_dict(flat_dict, prefix=name)
|
||||||
dictionary = nested_from_flat_dictionary(dictionary, flat_dict)
|
dictionary = nested_from_flat_dictionary(dictionary, flat_dict)
|
||||||
dictionary = ProxyDictPostWrite(self, _refresh_args_dict, **dictionary)
|
dictionary = ProxyDictPostWrite(self, _refresh_args_dict, **dictionary)
|
||||||
|
|
||||||
return dictionary
|
return dictionary
|
||||||
|
|
||||||
def _connect_task_parameters(self, attr_class):
|
def _connect_task_parameters(self, attr_class, name=None):
|
||||||
self._try_set_connected_parameter_type(self._ConnectedParametersType.task_parameters)
|
self._try_set_connected_parameter_type(self._ConnectedParametersType.task_parameters)
|
||||||
|
|
||||||
if running_remotely() and self.is_main_task():
|
if running_remotely() and self.is_main_task():
|
||||||
attr_class.update_from_dict(self.get_parameters())
|
parameters = self.get_parameters()
|
||||||
|
if not name:
|
||||||
|
attr_class.update_from_dict(parameters)
|
||||||
else:
|
else:
|
||||||
self.set_parameters(attr_class.to_dict())
|
attr_class.update_from_dict(
|
||||||
|
dict((k[len(name)+1:], v) for k, v in parameters.items() if k.startswith('{}/'.format(name))))
|
||||||
|
else:
|
||||||
|
self.set_parameters(attr_class.to_dict(), __parameters_prefix=name)
|
||||||
return attr_class
|
return attr_class
|
||||||
|
|
||||||
def _validate(self, check_output_dest_credentials=False):
|
def _validate(self, check_output_dest_credentials=False):
|
||||||
|
Loading…
Reference in New Issue
Block a user