mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Add multi configuration section support (hyperparams and configurations)
Support setting offline mode API version using TRAINS_OFFLINE_MODE env var
This commit is contained in:
		
							parent
							
								
									6d4e85de0a
								
							
						
					
					
						commit
						e378de1e41
					
				| @ -64,7 +64,11 @@ class DataModel(object): | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _to_base_type(cls, value): | ||||
|         if isinstance(value, DataModel): | ||||
|         if isinstance(value, dict): | ||||
|             # Note: this should come before DataModel to handle data models that are simply a dict | ||||
|             # (and thus are not expected to have additional named properties) | ||||
|             return {k: cls._to_base_type(v) for k, v in value.items()} | ||||
|         elif isinstance(value, DataModel): | ||||
|             return value.to_dict() | ||||
|         elif isinstance(value, enum.Enum): | ||||
|             return value.value | ||||
|  | ||||
| @ -62,6 +62,7 @@ class Session(TokenManager): | ||||
|     default_files = "https://demofiles.trains.allegro.ai" | ||||
|     default_key = "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" | ||||
|     default_secret = "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" | ||||
|     force_max_api_version = None | ||||
| 
 | ||||
|     # TODO: add requests.codes.gateway_timeout once we support async commits | ||||
|     _retry_codes = [ | ||||
| @ -182,6 +183,9 @@ class Session(TokenManager): | ||||
| 
 | ||||
|         self.__class__._sessions_created += 1 | ||||
| 
 | ||||
|         if self.force_max_api_version and self.check_min_api_version(self.force_max_api_version): | ||||
|             Session.api_version = str(self.force_max_api_version) | ||||
| 
 | ||||
|     def _send_request( | ||||
|         self, | ||||
|         service, | ||||
| @ -539,7 +543,17 @@ class Session(TokenManager): | ||||
|         # If no session was created, create a default one, in order to get the backend api version. | ||||
|         if cls._sessions_created <= 0: | ||||
|             if cls._offline_mode: | ||||
|                 cls.api_version = cls._offline_default_version | ||||
|                 # allow to change the offline mode version by setting ENV_OFFLINE_MODE to the required API version | ||||
|                 if cls.api_version != cls._offline_default_version: | ||||
|                     offline_api = ENV_OFFLINE_MODE.get(converter=lambda x: x) | ||||
|                     if offline_api: | ||||
|                         try: | ||||
|                             # check cast to float, but leave original str if we pass it. | ||||
|                             float(offline_api) | ||||
|                             cls._offline_default_version = str(offline_api) | ||||
|                         except ValueError: | ||||
|                             pass | ||||
|                     cls.api_version = cls._offline_default_version | ||||
|             else: | ||||
|                 # noinspection PyBroadException | ||||
|                 try: | ||||
|  | ||||
| @ -10,32 +10,35 @@ from ...utilities.args import call_original_argparser | ||||
| class _Arguments(object): | ||||
|     _prefix_sep = '/' | ||||
|     # TODO: separate dict and argparse after we add UI support | ||||
|     _prefix_dict = 'dict' + _prefix_sep | ||||
|     _prefix_args = 'argparse' + _prefix_sep | ||||
|     _prefix_args = 'Args' + _prefix_sep | ||||
|     _prefix_tf_defines = 'TF_DEFINE' + _prefix_sep | ||||
| 
 | ||||
|     class _ProxyDictWrite(dict): | ||||
|         """ Dictionary wrapper that updates an arguments instance on any item set in the dictionary """ | ||||
| 
 | ||||
|         def __init__(self, arguments, *args, **kwargs): | ||||
|         def __init__(self, __arguments, __section_name, *args, **kwargs): | ||||
|             super(_Arguments._ProxyDictWrite, self).__init__(*args, **kwargs) | ||||
|             self._arguments = arguments | ||||
|             self._arguments = __arguments | ||||
|             self._section_name = (__section_name.strip(_Arguments._prefix_sep) + _Arguments._prefix_sep) \ | ||||
|                 if __section_name else None | ||||
| 
 | ||||
|         def __setitem__(self, key, value): | ||||
|             super(_Arguments._ProxyDictWrite, self).__setitem__(key, value) | ||||
|             if self._arguments: | ||||
|                 self._arguments.copy_from_dict(self) | ||||
|                 self._arguments.copy_from_dict(self, prefix=self._section_name) | ||||
| 
 | ||||
|     class _ProxyDictReadOnly(dict): | ||||
|         """ Dictionary wrapper that prevents modifications to the dictionary """ | ||||
| 
 | ||||
|         def __init__(self, arguments, *args, **kwargs): | ||||
|         def __init__(self, __arguments, __section_name, *args, **kwargs): | ||||
|             super(_Arguments._ProxyDictReadOnly, self).__init__(*args, **kwargs) | ||||
|             self._arguments = arguments | ||||
|             self._arguments = __arguments | ||||
|             self._section_name = (__section_name.strip(_Arguments._prefix_sep) + _Arguments._prefix_sep) \ | ||||
|                 if __section_name else None | ||||
| 
 | ||||
|         def __setitem__(self, key, value): | ||||
|             if self._arguments: | ||||
|                 param_dict = self._arguments.copy_to_dict({key: value}) | ||||
|                 param_dict = self._arguments.copy_to_dict({key: value}, prefix=self._section_name) | ||||
|                 value = param_dict.get(key, value) | ||||
|             super(_Arguments._ProxyDictReadOnly, self).__setitem__(key, value) | ||||
| 
 | ||||
| @ -45,24 +48,32 @@ class _Arguments(object): | ||||
|         self._exclude_parser_args = {} | ||||
| 
 | ||||
|     def exclude_parser_args(self, excluded_args): | ||||
|         """ | ||||
|         You can use a dictionary for fined grained control of connected | ||||
|         arguments. The dictionary keys are argparse variable names and the values are booleans. | ||||
|         The ``False`` value excludes the specified argument from the Task's parameter section. | ||||
|         Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``. | ||||
| 
 | ||||
|         :param excluded_args: dict | ||||
|         """ | ||||
|         self._exclude_parser_args = excluded_args or {} | ||||
| 
 | ||||
|     def set_defaults(self, *dicts, **kwargs): | ||||
|         self._task.set_parameters(*dicts, **kwargs) | ||||
|         # noinspection PyProtectedMember | ||||
|         self._task._set_parameters(*dicts, __parameters_prefix=self._prefix_args, **kwargs) | ||||
| 
 | ||||
|     def add_argument(self, option_strings, type=None, default=None, help=None): | ||||
|         if not option_strings: | ||||
|             raise Exception('Expected at least one argument name (option string)') | ||||
|         name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t') | ||||
|         # TODO: add argparse prefix | ||||
|         # name = self._prefix_args + name | ||||
|         name = self._prefix_args + name | ||||
|         self._task.set_parameter(name=name, value=default, description=help) | ||||
| 
 | ||||
|     def connect(self, parser): | ||||
|         self._task.connect_argparse(parser) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _add_to_defaults(cls, a_parser, defaults, a_args=None, a_namespace=None, a_parsed_args=None): | ||||
|     def _add_to_defaults(cls, a_parser, defaults, descriptions, a_args=None, a_namespace=None, a_parsed_args=None): | ||||
|         actions = [ | ||||
|             a for a in a_parser._actions | ||||
|             if isinstance(a, _StoreAction) or isinstance(a, _StoreConstAction) | ||||
| @ -87,6 +98,9 @@ class _Arguments(object): | ||||
|                 for a in actions | ||||
|             } | ||||
| 
 | ||||
|         desc_ = {a.dest: a.help for a in actions} | ||||
|         descriptions.update(desc_) | ||||
| 
 | ||||
|         full_args_dict = copy(defaults) | ||||
|         full_args_dict.update(args_dict) | ||||
|         defaults.update(defaults_) | ||||
| @ -101,17 +115,19 @@ class _Arguments(object): | ||||
|                 defaults[sub_parser.dest] = full_args_dict.get(sub_parser.dest) or '' | ||||
|             for choice in sub_parser.choices.values(): | ||||
|                 # recursively parse | ||||
|                 defaults = cls._add_to_defaults( | ||||
|                 defaults, descriptions = cls._add_to_defaults( | ||||
|                     a_parser=choice, | ||||
|                     defaults=defaults, | ||||
|                     descriptions=descriptions, | ||||
|                     a_parsed_args=a_parsed_args or full_args_dict | ||||
|                 ) | ||||
| 
 | ||||
|         return defaults | ||||
|         return defaults, descriptions | ||||
| 
 | ||||
|     def copy_defaults_from_argparse(self, parser, args=None, namespace=None, parsed_args=None): | ||||
|         task_defaults = {} | ||||
|         self._add_to_defaults(parser, task_defaults, args, namespace, parsed_args) | ||||
|         task_defaults_descriptions = {} | ||||
|         self._add_to_defaults(parser, task_defaults, task_defaults_descriptions, args, namespace, parsed_args) | ||||
| 
 | ||||
|         # Make sure we didn't miss anything | ||||
|         if parsed_args: | ||||
| @ -133,12 +149,13 @@ class _Arguments(object): | ||||
|             except Exception: | ||||
|                 del task_defaults[k] | ||||
| 
 | ||||
|         # Skip excluded arguments, Add prefix, TODO: add argparse prefix | ||||
|         # task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items() | ||||
|         #                       if k not in self._exclude_parser_args]) | ||||
|         task_defaults = dict([(k, v) for k, v in task_defaults.items() if self._exclude_parser_args.get(k, True)]) | ||||
|         # Skip excluded arguments, Add prefix. | ||||
|         task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items() | ||||
|                               if self._exclude_parser_args.get(k, True)]) | ||||
|         task_defaults_descriptions = dict([(self._prefix_args + k, v) for k, v in task_defaults_descriptions.items() | ||||
|                                            if self._exclude_parser_args.get(k, True)]) | ||||
|         # Store to task | ||||
|         self._task.update_parameters(task_defaults) | ||||
|         self._task.update_parameters(task_defaults, __parameters_descriptions=task_defaults_descriptions) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _find_parser_action(cls, a_parser, name): | ||||
| @ -158,11 +175,10 @@ class _Arguments(object): | ||||
|         return _actions | ||||
| 
 | ||||
|     def copy_to_parser(self, parser, parsed_args): | ||||
|         # todo: change to argparse prefix only | ||||
|         # task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items() | ||||
|         #                        if k.startswith(self._prefix_args)]) | ||||
|         task_arguments = dict([(k, v) for k, v in self._task.get_parameters().items() | ||||
|                                if not k.startswith(self._prefix_tf_defines) and self._exclude_parser_args.get(k, True)]) | ||||
|         # Change to argparse prefix only | ||||
|         task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items() | ||||
|                                if k.startswith(self._prefix_args) and | ||||
|                                self._exclude_parser_args.get(k[len(self._prefix_args):], True)]) | ||||
|         arg_parser_argeuments = {} | ||||
|         for k, v in task_arguments.items(): | ||||
|             # python2 unicode support | ||||
| @ -304,25 +320,24 @@ class _Arguments(object): | ||||
|         parser.set_defaults(**arg_parser_argeuments) | ||||
| 
 | ||||
|     def copy_from_dict(self, dictionary, prefix=None): | ||||
|         # TODO: add dict prefix | ||||
|         prefix = prefix or ''  # self._prefix_dict | ||||
|         # add dict prefix | ||||
|         prefix = prefix  # or self._prefix_dict | ||||
|         if prefix: | ||||
|             with self._task._edit_lock: | ||||
|                 prefix_dictionary = dict([(prefix + k, v) for k, v in dictionary.items()]) | ||||
|                 cur_params = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(prefix)]) | ||||
|                 cur_params.update(prefix_dictionary) | ||||
|                 self._task.set_parameters(cur_params) | ||||
|             prefix = prefix.strip(self._prefix_sep) + self._prefix_sep | ||||
|             # this will only set the specific section | ||||
|             self._task.set_parameters(dictionary, __parameters_prefix=prefix) | ||||
|         else: | ||||
|             self._task.update_parameters(dictionary) | ||||
|         if not isinstance(dictionary, self._ProxyDictWrite): | ||||
|             return self._ProxyDictWrite(self, **dictionary) | ||||
|             return self._ProxyDictWrite(self, prefix, **dictionary) | ||||
|         return dictionary | ||||
| 
 | ||||
|     def copy_to_dict(self, dictionary, prefix=None): | ||||
|         # iterate over keys and merge values according to parameter type in dictionary | ||||
|         # TODO: add dict prefix | ||||
|         prefix = prefix or ''  # self._prefix_dict | ||||
|         # add dict prefix | ||||
|         prefix = prefix  # or self._prefix_dict | ||||
|         if prefix: | ||||
|             prefix = prefix.strip(self._prefix_sep) + self._prefix_sep | ||||
|             parameters = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items() | ||||
|                                if k.startswith(prefix)]) | ||||
|         else: | ||||
| @ -396,5 +411,5 @@ class _Arguments(object): | ||||
|         #         dictionary[k] = v | ||||
| 
 | ||||
|         if not isinstance(dictionary, self._ProxyDictReadOnly): | ||||
|             return self._ProxyDictReadOnly(self, **dictionary) | ||||
|             return self._ProxyDictReadOnly(self, prefix, **dictionary) | ||||
|         return dictionary | ||||
|  | ||||
| @ -30,13 +30,15 @@ from ...backend_api import Session | ||||
| from ...backend_api.services import tasks, models, events, projects | ||||
| from ...backend_api.session.defs import ENV_OFFLINE_MODE | ||||
| from ...utilities.pyhocon import ConfigTree, ConfigFactory | ||||
| from ...utilities.config import config_dict_to_text, text_to_config_dict | ||||
| 
 | ||||
| from ..base import IdObjectBase, InterfaceBase | ||||
| from ..metrics import Metrics, Reporter | ||||
| from ..model import Model | ||||
| from ..setupuploadmixin import SetupUploadMixin | ||||
| from ..util import make_message, get_or_create_project, get_single_result, \ | ||||
|     exact_match_regex | ||||
| from ..util import ( | ||||
|     make_message, get_or_create_project, get_single_result, | ||||
|     exact_match_regex, mutually_exclusive, ) | ||||
| from ...config import ( | ||||
|     get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, | ||||
|     running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR, get_offline_dir) | ||||
| @ -768,12 +770,52 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): | ||||
| 
 | ||||
|             self._edit(execution=self.data.execution) | ||||
| 
 | ||||
|     def get_parameters(self, backwards_compatibility=True): | ||||
|         # type: (bool) -> (Optional[dict]) | ||||
|         """ | ||||
|         Get the parameters for a Task. This method returns a complete group of key-value parameter pairs, but does not | ||||
|         support parameter descriptions (the result is a dictionary of key-value pairs). | ||||
|         :param backwards_compatibility: If True (default) parameters without section name | ||||
|             (API version < 2.9, trains-server < 0.16) will be at dict root level. | ||||
|             If False, parameters without section name, will be nested under "general/" key. | ||||
|         :return: dict of the task parameters, all flattened to key/value. | ||||
|             Different sections with key prefix "section/" | ||||
|         """ | ||||
|         if not Session.check_min_api_version('2.9'): | ||||
|             return self._get_task_property('execution.parameters') | ||||
| 
 | ||||
|         # API will makes sure we get old parameters under _legacy | ||||
|         parameters = dict() | ||||
|         hyperparams = self._get_task_property('hyperparams', default=dict()) | ||||
|         if len(hyperparams) == 1 and '_legacy' in hyperparams and backwards_compatibility: | ||||
|             for section in ('_legacy', ): | ||||
|                 for key, section_param in hyperparams[section].items(): | ||||
|                     parameters['{}'.format(key)] = section_param.value | ||||
|         else: | ||||
|             for section in hyperparams: | ||||
|                 for key, section_param in hyperparams[section].items(): | ||||
|                     parameters['{}/{}'.format(section, key)] = section_param.value | ||||
| 
 | ||||
|         return parameters | ||||
| 
 | ||||
|     def set_parameters(self, *args, **kwargs): | ||||
|         # type: (*dict, **Any) -> () | ||||
|         """ | ||||
|         Set the parameters for a Task. This method sets a complete group of key-value parameter pairs, but does not | ||||
|         support parameter descriptions (the input is a dictionary of key-value pairs). | ||||
| 
 | ||||
|         :param args: Positional arguments, which are one or more dictionary or (key, value) iterable. They are | ||||
|             merged into a single key-value pair dictionary. | ||||
|         :param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``. | ||||
|         """ | ||||
|         return self._set_parameters(*args, __update=False, **kwargs) | ||||
| 
 | ||||
|     def _set_parameters(self, *args, **kwargs): | ||||
|         # type: (*dict, **Any) -> () | ||||
|         """ | ||||
|         Set the parameters for a Task. This method sets a complete group of key-value parameter pairs, but does not | ||||
|         support parameter descriptions (the input is a dictionary of key-value pairs). | ||||
| 
 | ||||
|         :param args: Positional arguments, which are one or more dictionary or (key, value) iterable. They are | ||||
|             merged into a single key-value pair dictionary. | ||||
|         :param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``. | ||||
| @ -781,58 +823,113 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): | ||||
|         if not all(isinstance(x, (dict, Iterable)) for x in args): | ||||
|             raise ValueError('only dict or iterable are supported as positional arguments') | ||||
| 
 | ||||
|         prefix = kwargs.pop('__parameters_prefix', None) | ||||
|         descriptions = kwargs.pop('__parameters_descriptions', None) or dict() | ||||
|         params_types = kwargs.pop('__parameters_types', None) or dict() | ||||
|         update = kwargs.pop('__update', False) | ||||
| 
 | ||||
|         # new parameters dict | ||||
|         new_parameters = dict(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args)) | ||||
|         new_parameters.update(kwargs) | ||||
|         if prefix: | ||||
|             prefix = prefix.strip('/') | ||||
|             new_parameters = dict(('{}/{}'.format(prefix, k), v) for k, v in new_parameters.items()) | ||||
| 
 | ||||
|         # verify parameters type: | ||||
|         not_allowed = { | ||||
|             k: type(v).__name__ | ||||
|             for k, v in new_parameters.items() | ||||
|             if not isinstance(v, self._parameters_allowed_types) | ||||
|         } | ||||
|         if not_allowed: | ||||
|             raise ValueError( | ||||
|                 "Only builtin types ({}) are allowed for values (got {})".format( | ||||
|                     ', '.join(t.__name__ for t in self._parameters_allowed_types), | ||||
|                     ', '.join('%s=>%s' % p for p in not_allowed.items())), | ||||
|             ) | ||||
| 
 | ||||
|         use_hyperparams = Session.check_min_api_version('2.9') | ||||
| 
 | ||||
|         with self._edit_lock: | ||||
|             self.reload() | ||||
|             if update: | ||||
|                 parameters = self.get_parameters() | ||||
|             # if we have a specific prefix and we use hyperparameters, and we use set. | ||||
|             # overwrite only the prefix, leave the rest as is. | ||||
|             if not update and prefix: | ||||
|                 parameters = dict(**(self.get_parameters() or {})) | ||||
|                 parameters = dict((k, v) for k, v in parameters.items() if not k.startswith(prefix+'/')) | ||||
|             elif update: | ||||
|                 parameters = dict(**(self.get_parameters() or {})) | ||||
|             else: | ||||
|                 parameters = dict() | ||||
|             parameters.update(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args)) | ||||
|             parameters.update(kwargs) | ||||
| 
 | ||||
|             not_allowed = { | ||||
|                 k: type(v).__name__ | ||||
|                 for k, v in parameters.items() | ||||
|                 if not isinstance(v, self._parameters_allowed_types) | ||||
|             } | ||||
|             if not_allowed: | ||||
|                 raise ValueError( | ||||
|                     "Only builtin types ({}) are allowed for values (got {})".format( | ||||
|                         ', '.join(t.__name__ for t in self._parameters_allowed_types), | ||||
|                         ', '.join('%s=>%s' % p for p in not_allowed.items())), | ||||
|                 ) | ||||
|             parameters.update(new_parameters) | ||||
| 
 | ||||
|             # force cast all variables to strings (so that we can later edit them in UI) | ||||
|             parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()} | ||||
| 
 | ||||
|             execution = self.data.execution | ||||
|             if execution is None: | ||||
|                 execution = tasks.Execution( | ||||
|                     parameters=parameters, artifacts=[], dataviews=[], model='', | ||||
|                     model_desc={}, model_labels={}, docker_cmd='') | ||||
|             else: | ||||
|                 execution.parameters = parameters | ||||
|             self._edit(execution=execution) | ||||
|             if use_hyperparams: | ||||
|                 # build nested dict from flat parameters dict: | ||||
|                 org_hyperparams = self.data.hyperparams or {} | ||||
|                 hyperparams = dict() | ||||
| 
 | ||||
|     def set_parameter(self, name, value, description=None): | ||||
|         # type: (str, str, Optional[str]) -> () | ||||
|                 # if the task is a legacy task, we should put everything back under _legacy | ||||
|                 if self.data.hyperparams and '_legacy' in self.data.hyperparams: | ||||
|                     for k, v in parameters.items(): | ||||
|                         section_name, key = '_legacy', k | ||||
|                         section = hyperparams.get(section_name, dict()) | ||||
|                         description = \ | ||||
|                             descriptions.get(k) or \ | ||||
|                             org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description | ||||
|                         param_type = \ | ||||
|                             params_types.get(k) or \ | ||||
|                             org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).type | ||||
|                         section[key] = tasks.ParamsItem( | ||||
|                             section=section_name, name=key, value=v, description=description, type=str(param_type)) | ||||
|                         hyperparams[section_name] = section | ||||
|                 else: | ||||
|                     for k, v in parameters.items(): | ||||
|                         org_k = k | ||||
|                         if '/' not in k: | ||||
|                             k = 'General/{}'.format(k) | ||||
|                         section_name, key = k.split('/', 1) | ||||
|                         section = hyperparams.get(section_name, dict()) | ||||
|                         description = \ | ||||
|                             descriptions.get(org_k) or \ | ||||
|                             org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()).description | ||||
|                         section[key] = tasks.ParamsItem\ | ||||
|                             (section=section_name, name=key, value=v, description=description) | ||||
|                         hyperparams[section_name] = section | ||||
|                 self._edit(hyperparams=hyperparams) | ||||
|             else: | ||||
|                 execution = self.data.execution | ||||
|                 if execution is None: | ||||
|                     execution = tasks.Execution( | ||||
|                         parameters=parameters, artifacts=[], dataviews=[], model='', | ||||
|                         model_desc={}, model_labels={}, docker_cmd='') | ||||
|                 else: | ||||
|                     execution.parameters = parameters | ||||
|                 self._edit(execution=execution) | ||||
| 
 | ||||
|     def set_parameter(self, name, value, description=None, value_type=None): | ||||
|         # type: (str, str, Optional[str], Optional[Any]) -> () | ||||
|         """ | ||||
|         Set a single Task parameter. This overrides any previous value for this parameter. | ||||
| 
 | ||||
|         :param name: The parameter name. | ||||
|         :param value: The parameter value. | ||||
|         :param description: The parameter description. | ||||
| 
 | ||||
|             .. note:: | ||||
|                The ``description`` is not yet in use. | ||||
|         :param value_type: The type of the parameters (cast to string and store) | ||||
|         """ | ||||
|         # not supported yet | ||||
|         if description: | ||||
|             # noinspection PyUnusedLocal | ||||
|         if not Session.check_min_api_version('2.9'): | ||||
|             # not supported yet | ||||
|             description = None | ||||
|         self.set_parameters({name: value}, __update=True) | ||||
|             value_type = None | ||||
| 
 | ||||
|         self._set_parameters( | ||||
|             {name: value}, __update=True, | ||||
|             __parameters_descriptions={name: description}, | ||||
|             __parameters_types={name: value_type} | ||||
|         ) | ||||
| 
 | ||||
|     def get_parameter(self, name, default=None): | ||||
|         # type: (str, Any) -> Any | ||||
| @ -856,7 +953,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): | ||||
|             merged into a single key-value pair dictionary. | ||||
|         :param kwargs: Key-value pairs, merged into the parameters dictionary created from ``args``. | ||||
|         """ | ||||
|         self.set_parameters(__update=True, *args, **kwargs) | ||||
|         self._set_parameters(*args, __update=True, **kwargs) | ||||
| 
 | ||||
|     def set_model_label_enumeration(self, enumeration=None): | ||||
|         # type: (Mapping[str, int]) -> () | ||||
| @ -1316,7 +1413,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): | ||||
| 
 | ||||
|         self._update_requirements('') | ||||
| 
 | ||||
|         if Session.check_min_api_version('2.3'): | ||||
|         if Session.check_min_api_version('2.9'): | ||||
|             self._set_task_property("system_tags", system_tags) | ||||
|             self._edit(system_tags=self._data.system_tags, comment=self._data.comment, | ||||
|                        script=self._data.script, execution=self._data.execution, output_dest='', | ||||
|                        hyperparams=dict(), configuration=dict()) | ||||
|         elif Session.check_min_api_version('2.3'): | ||||
|             self._set_task_property("system_tags", system_tags) | ||||
|             self._edit(system_tags=self._data.system_tags, comment=self._data.comment, | ||||
|                        script=self._data.script, execution=self._data.execution, output_dest='') | ||||
| @ -1386,6 +1488,67 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): | ||||
|             self.data.script = script | ||||
|             self._edit(script=script) | ||||
| 
 | ||||
|     def _set_configuration(self, name, description=None, config_type=None, config_text=None, config_dict=None): | ||||
|         # type: (str, Optional[str], Optional[str], Optional[str], Optional[Mapping]) -> None | ||||
|         """ | ||||
|         Set Task configuration text/dict. Multiple configurations are supported. | ||||
| 
 | ||||
|         :param str name: Configuration name. | ||||
|         :param str description: Configuration section description. | ||||
|         :param config_text: model configuration (unconstrained text string). usually the content | ||||
|             of a configuration file. If `config_text` is not None, `config_dict` must not be provided. | ||||
|         :param config_dict: model configuration parameters dictionary. | ||||
|             If `config_dict` is not None, `config_text` must not be provided. | ||||
|         """ | ||||
|         # make sure we have wither dict or text | ||||
|         mutually_exclusive(config_dict=config_dict, config_text=config_text) | ||||
| 
 | ||||
|         if not Session.check_min_api_version('2.9'): | ||||
|             raise ValueError("Multiple configurations are not supported with the current 'trains-server', " | ||||
|                              "please upgrade to the latest version") | ||||
| 
 | ||||
|         if description: | ||||
|             description = str(description) | ||||
|         config = config_dict_to_text(config_text or config_dict) | ||||
|         with self._edit_lock: | ||||
|             self.reload() | ||||
|             configuration = self.data.configuration or {} | ||||
|             configuration[name] = tasks.ConfigurationItem( | ||||
|                 name=name, value=config, description=description or None, type=config_type or None) | ||||
|             self._edit(configuration=configuration) | ||||
| 
 | ||||
|     def _get_configuration_text(self, name): | ||||
|         # type: (str) -> Optional[str] | ||||
|         """ | ||||
|         Get Task configuration section as text | ||||
| 
 | ||||
|         :param str name: Configuration name. | ||||
|         :return: The Task configuration as text (unconstrained text string). | ||||
|             return None if configuration name is not valid. | ||||
|         """ | ||||
|         if not Session.check_min_api_version('2.9'): | ||||
|             raise ValueError("Multiple configurations are not supported with the current 'trains-server', " | ||||
|                              "please upgrade to the latest version") | ||||
| 
 | ||||
|         configuration = self.data.configuration or {} | ||||
|         if not configuration.get(name): | ||||
|             return None | ||||
|         return configuration[name].value | ||||
| 
 | ||||
|     def _get_configuration_dict(self, name): | ||||
|         # type: (str) -> Optional[dict] | ||||
|         """ | ||||
|         Get Task configuration section as dictionary | ||||
| 
 | ||||
|         :param str name: Configuration name. | ||||
|         :return: The Task configuration as dictionary. | ||||
|             return None if configuration name is not valid. | ||||
|         """ | ||||
|         config_text = self._get_configuration_text(name) | ||||
|         if not config_text: | ||||
|             return None | ||||
|         return text_to_config_dict(config_text) | ||||
| 
 | ||||
|     def get_offline_mode_folder(self): | ||||
|         # type: () -> (Optional[Path]) | ||||
|         """ | ||||
|  | ||||
| @ -339,7 +339,7 @@ class BaseModel(object): | ||||
|     def _config_dict_to_text(config): | ||||
|         if not isinstance(config, six.string_types) and not isinstance(config, dict): | ||||
|             raise ValueError("Model configuration only supports dictionary or string objects") | ||||
|         return config_dict_to_text | ||||
|         return config_dict_to_text(config) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _text_to_config_dict(text): | ||||
|  | ||||
							
								
								
									
										108
									
								
								trains/task.py
									
									
									
									
									
								
							
							
						
						
									
										108
									
								
								trains/task.py
									
									
									
									
									
								
							| @ -124,6 +124,7 @@ class Task(_Task): | ||||
|     __task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0)) | ||||
|     __detect_repo_async = config.get('development.vcs_repo_detect_async', False) | ||||
|     __default_output_uri = config.get('development.default_output_uri', None) | ||||
|     __default_configuration_name = 'General' | ||||
| 
 | ||||
|     class _ConnectedParametersType(object): | ||||
|         argparse = "argument_parser" | ||||
| @ -904,8 +905,8 @@ class Task(_Task): | ||||
|             self.data.tags.extend(tags) | ||||
|             self._edit(tags=list(set(self.data.tags))) | ||||
| 
 | ||||
|     def connect(self, mutable): | ||||
|         # type: (Any) -> Any | ||||
|     def connect(self, mutable, name=None): | ||||
|         # type: (Any, Optional[str]) -> Any | ||||
|         """ | ||||
|         Connect an object to a Task object. This connects an experiment component (part of an experiment) to the | ||||
|         experiment. For example, connect hyperparameters or models. | ||||
| @ -918,6 +919,12 @@ class Task(_Task): | ||||
|             - TaskParameters - A TaskParameters object. | ||||
|             - model - A model object for initial model warmup, or for model update/snapshot uploading. | ||||
| 
 | ||||
|         :param str name: A section name associated with the connected object. Default: 'General' | ||||
|             Currently only supported for `dict` / `TaskParameter` objects | ||||
|             Examples: | ||||
|                 name='General' will put the connected dictionary under the General section in the hyper-parameters | ||||
|                 name='Train' will put the connected dictionary under the Train section in the hyper-parameters | ||||
| 
 | ||||
|         :return: The result returned when connecting the object, if supported. | ||||
| 
 | ||||
|         :raise: Raise an exception on unsupported objects. | ||||
| @ -931,14 +938,22 @@ class Task(_Task): | ||||
|             (TaskParameters, self._connect_task_parameters), | ||||
|         ) | ||||
| 
 | ||||
|         multi_config_support = Session.check_min_api_version('2.9') | ||||
|         if multi_config_support and not name: | ||||
|             name = self.__default_configuration_name | ||||
| 
 | ||||
|         if not multi_config_support and name and name != self.__default_configuration_name: | ||||
|             raise ValueError("Multiple configurations are not supported with the current 'trains-server', " | ||||
|                              "please upgrade to the latest version") | ||||
| 
 | ||||
|         for mutable_type, method in dispatch: | ||||
|             if isinstance(mutable, mutable_type): | ||||
|                 return method(mutable) | ||||
|                 return method(mutable, name=name) | ||||
| 
 | ||||
|         raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__) | ||||
| 
 | ||||
|     def connect_configuration(self, configuration): | ||||
|         # type: (Union[Mapping, Path, str]) -> Union[Mapping, Path, str] | ||||
|     def connect_configuration(self, configuration, name=None, description=None): | ||||
|         # type: (Union[Mapping, Path, str], Optional[str], Optional[str]) -> Union[Mapping, Path, str] | ||||
|         """ | ||||
|         Connect a configuration dictionary or configuration file (pathlib.Path / str) to a Task object. | ||||
|         This method should be called before reading the configuration file. | ||||
| @ -968,25 +983,54 @@ class Task(_Task): | ||||
|               A local path must be relative path. When executing a Task remotely in a worker, the contents brought | ||||
|               from the **Trains Server** (backend) overwrites the contents of the file. | ||||
| 
 | ||||
|         :param str name: Configuration section name. default: 'General' | ||||
|             Allowing users to store multiple configuration dicts/files | ||||
| 
 | ||||
|         :param str description: Configuration section description (text). default: None | ||||
| 
 | ||||
|         :return: If a dictionary is specified, then a dictionary is returned. If pathlib2.Path / string is | ||||
|             specified, then a path to a local configuration file is returned. Configuration object. | ||||
|         """ | ||||
|         pathlib_Path = None | ||||
|         if not isinstance(configuration, (dict, Path, six.string_types)): | ||||
|             raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, " | ||||
|                              "{} is not supported".format(type(configuration))) | ||||
|             try: | ||||
|                 from pathlib import Path as pathlib_Path | ||||
|             except ImportError: | ||||
|                 pass | ||||
|             if not pathlib_Path or not isinstance(configuration, pathlib_Path): | ||||
|                 raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, " | ||||
|                                  "{} is not supported".format(type(configuration))) | ||||
| 
 | ||||
|         multi_config_support = Session.check_min_api_version('2.9') | ||||
|         if multi_config_support and not name: | ||||
|             name = self.__default_configuration_name | ||||
| 
 | ||||
|         if not multi_config_support and name and name != self.__default_configuration_name: | ||||
|             raise ValueError("Multiple configurations are not supported with the current 'trains-server', " | ||||
|                              "please upgrade to the latest version") | ||||
| 
 | ||||
|         # parameter dictionary | ||||
|         if isinstance(configuration, dict): | ||||
|             def _update_config_dict(task, config_dict): | ||||
|                 # noinspection PyProtectedMember | ||||
|                 task._set_model_config(config_dict=config_dict) | ||||
|                 if multi_config_support: | ||||
|                     # noinspection PyProtectedMember | ||||
|                     task._set_configuration( | ||||
|                         name=name, description=description, config_type='dictionary', config_dict=config_dict) | ||||
|                 else: | ||||
|                     # noinspection PyProtectedMember | ||||
|                     task._set_model_config(config_dict=config_dict) | ||||
| 
 | ||||
|             if not running_remotely() or not self.is_main_task(): | ||||
|                 self._set_model_config(config_dict=configuration) | ||||
|                 if multi_config_support: | ||||
|                     self._set_configuration( | ||||
|                         name=name, description=description, config_type='dictionary', config_dict=configuration) | ||||
|                 else: | ||||
|                     self._set_model_config(config_dict=configuration) | ||||
|                 configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration) | ||||
|             else: | ||||
|                 configuration.clear() | ||||
|                 configuration.update(self._get_model_config_dict()) | ||||
|                 configuration.update(self._get_configuration_dict(name=name) if multi_config_support | ||||
|                                      else self._get_model_config_dict()) | ||||
|                 configuration = ProxyDictPreWrite(False, False, **configuration) | ||||
|             return configuration | ||||
| 
 | ||||
| @ -1002,16 +1046,26 @@ class Task(_Task): | ||||
|             except Exception: | ||||
|                 raise ValueError("Could not connect configuration file {}, file could not be read".format( | ||||
|                     configuration_path.as_posix())) | ||||
|             self._set_model_config(config_text=configuration_text) | ||||
|             if multi_config_support: | ||||
|                 self._set_configuration( | ||||
|                     name=name, description=description, | ||||
|                     config_type=configuration_path.suffixes[-1].lstrip('.') | ||||
|                     if configuration_path.suffixes and configuration_path.suffixes[-1] else 'file', | ||||
|                     config_text=configuration_text) | ||||
|             else: | ||||
|                 self._set_model_config(config_text=configuration_text) | ||||
|             return configuration | ||||
|         else: | ||||
|             configuration_text = self._get_model_config_text() | ||||
|             configuration_text = self._get_configuration_text(name=name) if multi_config_support \ | ||||
|                 else self._get_model_config_text() | ||||
|             configuration_path = Path(configuration) | ||||
|             fd, local_filename = mkstemp(prefix='trains_task_config_', | ||||
|                                          suffix=configuration_path.suffixes[-1] if | ||||
|                                          configuration_path.suffixes else '.txt') | ||||
|             os.write(fd, configuration_text.encode('utf-8')) | ||||
|             os.close(fd) | ||||
|             if pathlib_Path: | ||||
|                 return pathlib_Path(local_filename) | ||||
|             return Path(local_filename) if isinstance(configuration, Path) else local_filename | ||||
| 
 | ||||
|     def connect_label_enumeration(self, enumeration): | ||||
| @ -1651,6 +1705,7 @@ class Task(_Task): | ||||
|             # noinspection PyProtectedMember | ||||
|             offline_folder = os.path.join(export_data.get('offline_folder', ''), 'data/') | ||||
| 
 | ||||
|             # noinspection PyProtectedMember | ||||
|             remote_url = task._get_default_report_storage_uri() | ||||
|             if remote_url and remote_url.endswith('/'): | ||||
|                 remote_url = remote_url[:-1] | ||||
| @ -1977,7 +2032,7 @@ class Task(_Task): | ||||
| 
 | ||||
|         return self._logger | ||||
| 
 | ||||
|     def _connect_output_model(self, model): | ||||
|     def _connect_output_model(self, model, name=None): | ||||
|         assert isinstance(model, OutputModel) | ||||
|         model.connect(self) | ||||
|         return model | ||||
| @ -2001,7 +2056,7 @@ class Task(_Task): | ||||
|         if self._connected_output_model: | ||||
|             self.connect(self._connected_output_model) | ||||
| 
 | ||||
|     def _connect_input_model(self, model): | ||||
|     def _connect_input_model(self, model, name=None): | ||||
|         assert isinstance(model, InputModel) | ||||
|         # we only allow for an input model to be connected once | ||||
|         # at least until we support multiple input models | ||||
| @ -2039,7 +2094,7 @@ class Task(_Task): | ||||
|         # added support for multiple type connections through _Arguments | ||||
|         return option | ||||
| 
 | ||||
|     def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None): | ||||
|     def _connect_argparse(self, parser, args=None, namespace=None, parsed_args=None, name=None): | ||||
|         # do not allow argparser to connect to jupyter notebook | ||||
|         # noinspection PyBroadException | ||||
|         try: | ||||
| @ -2079,15 +2134,15 @@ class Task(_Task): | ||||
|                 parser, args=args, namespace=namespace, parsed_args=parsed_args) | ||||
|         return parser | ||||
| 
 | ||||
|     def _connect_dictionary(self, dictionary): | ||||
|     def _connect_dictionary(self, dictionary, name=None): | ||||
|         def _update_args_dict(task, config_dict): | ||||
|             # noinspection PyProtectedMember | ||||
|             task._arguments.copy_from_dict(flatten_dictionary(config_dict)) | ||||
|             task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name) | ||||
| 
 | ||||
|         def _refresh_args_dict(task, config_dict): | ||||
|             # reread from task including newly added keys | ||||
|             # noinspection PyProtectedMember | ||||
|             a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict)) | ||||
|             a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict), prefix=name) | ||||
|             # noinspection PyProtectedMember | ||||
|             nested_dict = config_dict._to_dict() | ||||
|             config_dict.clear() | ||||
| @ -2096,23 +2151,28 @@ class Task(_Task): | ||||
|         self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary) | ||||
| 
 | ||||
|         if not running_remotely() or not self.is_main_task(): | ||||
|             self._arguments.copy_from_dict(flatten_dictionary(dictionary)) | ||||
|             self._arguments.copy_from_dict(flatten_dictionary(dictionary), prefix=name) | ||||
|             dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary) | ||||
|         else: | ||||
|             flat_dict = flatten_dictionary(dictionary) | ||||
|             flat_dict = self._arguments.copy_to_dict(flat_dict) | ||||
|             flat_dict = self._arguments.copy_to_dict(flat_dict, prefix=name) | ||||
|             dictionary = nested_from_flat_dictionary(dictionary, flat_dict) | ||||
|             dictionary = ProxyDictPostWrite(self, _refresh_args_dict, **dictionary) | ||||
| 
 | ||||
|         return dictionary | ||||
| 
 | ||||
|     def _connect_task_parameters(self, attr_class): | ||||
|     def _connect_task_parameters(self, attr_class, name=None): | ||||
|         self._try_set_connected_parameter_type(self._ConnectedParametersType.task_parameters) | ||||
| 
 | ||||
|         if running_remotely() and self.is_main_task(): | ||||
|             attr_class.update_from_dict(self.get_parameters()) | ||||
|             parameters = self.get_parameters() | ||||
|             if not name: | ||||
|                 attr_class.update_from_dict(parameters) | ||||
|             else: | ||||
|                 attr_class.update_from_dict( | ||||
|                     dict((k[len(name)+1:], v) for k, v in parameters.items() if k.startswith('{}/'.format(name)))) | ||||
|         else: | ||||
|             self.set_parameters(attr_class.to_dict()) | ||||
|             self.set_parameters(attr_class.to_dict(), __parameters_prefix=name) | ||||
|         return attr_class | ||||
| 
 | ||||
|     def _validate(self, check_output_dest_credentials=False): | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai