""" absl-py FLAGS binding utility functions """
from trains.backend_interface.task.args import _Arguments
from ..config import running_remotely


class PatchAbsl(object):
    """
    Mirror absl-py flag definitions into the current trains Task.

    Once a task is registered via ``update_current_task``, absl's internal
    ``absl.flags._defines.DEFINE_flag`` is replaced by ``_patched_define_flag``:

    * locally, every newly defined flag's value is reported to the task;
    * when running remotely, the value stored on the task (if any) overrides
      the flag's value before the flag is registered with absl.

    All task interaction is best-effort: failures are swallowed so the user's
    flag definitions keep working.
    """
    # absl's unpatched `_defines.DEFINE_flag`; non-None means the patch is installed
    _original_DEFINE_flag = None
    # task whose parameters mirror the absl flags (None until update_current_task is called)
    _task = None

    @classmethod
    def update_current_task(cls, current_task):
        """
        Register ``current_task`` as the task to sync absl flags with, and
        install the absl patch if it is not installed yet.

        Flags that were already defined are synced only on the first
        (patching) call; later calls just swap the task used for subsequent
        flag definitions.
        """
        cls._task = current_task
        cls._patch_absl()

    @classmethod
    def _patch_absl(cls):
        """
        Replace ``absl.flags._defines.DEFINE_flag`` with the patched version
        (only once), then sync flags that were defined before patching.

        Silently does nothing if absl cannot be imported.
        """
        if cls._original_DEFINE_flag:
            # already patched
            return
        # noinspection PyBroadException
        try:
            from absl.flags import _defines
            cls._original_DEFINE_flag = _defines.DEFINE_flag
            _defines.DEFINE_flag = cls._patched_define_flag
            # if absl was already set, let's update our task params
            from absl import flags
            cls._update_current_flags(flags.FLAGS)
        except Exception:
            # there is no absl
            pass

    @staticmethod
    def _patched_define_flag(*args, **kwargs):
        """
        Drop-in replacement for absl's ``DEFINE_flag`` that syncs the flag with the task.

        The flag object is taken from the first positional argument and the
        module name from the third, following absl's
        ``DEFINE_flag(flag, flag_values, module_name, ...)`` layout. When they
        are passed by keyword instead, the ``flag`` / ``module_name`` keyword
        arguments are used. The task parameter name is
        ``[module_name<sep>]flag.name``, stored under the TF-defines prefix.

        :return: whatever the original ``DEFINE_flag`` returns, or None if the
            patch was never installed.
        """
        original_define = PatchAbsl._original_DEFINE_flag
        if not PatchAbsl._task or not original_define:
            return original_define(*args, **kwargs) if original_define else None

        # noinspection PyBroadException
        try:
            flag = args[0] if len(args) >= 1 else kwargs.get('flag')
            module_name = args[2] if len(args) >= 3 else kwargs.get('module_name')
            param_name = None
            if flag:
                param_name = ((module_name + _Arguments._prefix_sep) if module_name else '') + flag.name
        except Exception:
            flag = None
            param_name = None

        if flag and param_name:
            # Best-effort: a failure while talking to the task must never
            # prevent the flag itself from being defined below.
            # noinspection PyBroadException
            try:
                if running_remotely():
                    # let the value stored on the task override the flag's
                    # value before absl registers the flag
                    param_dict = PatchAbsl._task._arguments.copy_to_dict(
                        {param_name: flag.value}, prefix=_Arguments._prefix_tf_defines)
                    flag.value = param_dict.get(param_name, flag.value)
                else:
                    # report the flag's current value to the task
                    PatchAbsl._task.update_parameters(
                        {_Arguments._prefix_tf_defines + param_name: flag.value})
            except Exception:
                pass

        return original_define(*args, **kwargs)

    @classmethod
    def _update_current_flags(cls, FLAGS):
        """
        Sync the flags already defined in ``FLAGS`` with the task.

        When running remotely, values stored on the task are written onto the
        matching flags. Locally, every flag's current value is copied into the
        task's arguments. This is best-effort: any failure is swallowed.

        :param FLAGS: the absl flag registry (``absl.flags.FLAGS`` when called
            from ``_patch_absl``).
        """
        if not cls._task:
            return
        # noinspection PyBroadException
        try:
            if running_remotely():
                param_dict = cls._task._arguments.copy_to_dict({}, prefix=_Arguments._prefix_tf_defines)
                for k, v in param_dict.items():
                    # noinspection PyBroadException
                    try:
                        # This class stores parameters as "<flag>" (local branch
                        # below) or "<module><sep><flag>" (_patched_define_flag).
                        # Either way the flag name is the LAST component; the
                        # first one may be a module name.
                        flag_name = k.split(_Arguments._prefix_sep)[-1]
                        if flag_name in FLAGS:
                            FLAGS[flag_name].value = v
                    except Exception:
                        pass
            else:
                # NOTE(review): originally commented "clear previous parameters";
                # presumably copy_from_dict replaces what was stored under this
                # prefix - confirm against _Arguments.copy_from_dict
                parameters = {k: FLAGS[k].value for k in FLAGS}
                cls._task._arguments.copy_from_dict(parameters, prefix=_Arguments._prefix_tf_defines)
        except Exception:
            pass