# clearml/trains/binding/absl_bind.py (115 lines, 4.1 KiB, Python)

""" absl-py FLAGS binding utility functions """
from ..backend_interface.task.args import _Arguments
from ..config import running_remotely
class PatchAbsl(object):
    """Bind absl-py FLAGS to the current task's parameters.

    Patching replaces ``absl.flags._defines.DEFINE_flag`` and ``FlagValues.__call__``
    with the wrappers below. Every flag definition and every FLAGS parse is then either
    reported to the task's arguments, or (when running remotely) updated from them,
    under the ``_Arguments._prefix_tf_defines`` prefix.

    Exceptions raised while reading or writing task parameters are swallowed, so they
    do not break the user's flag definitions or parsing.
    """

    # Original absl callables, captured the first time patching succeeds.
    _original_DEFINE_flag = None
    _original_FLAGS_parse_call = None
    # Task whose parameters are synchronised with FLAGS; when falsy the hooks are pass-throughs.
    _task = None

    @classmethod
    def update_current_task(cls, current_task):
        """Attach ``current_task`` and make sure absl is patched.

        :param current_task: Task object whose parameters mirror the absl FLAGS
            (a falsy value turns the hooks into plain pass-throughs).
        """
        cls._task = current_task
        cls._patch_absl()

    @classmethod
    def _patch_absl(cls):
        """Install the absl hooks once; no-op when absl cannot be imported.

        Once ``DEFINE_flag`` has been patched, flags that already exist in
        ``absl.flags.FLAGS`` are synchronised with the task.
        """
        if cls._original_DEFINE_flag:
            return

        # noinspection PyBroadException
        try:
            from absl.flags import _defines
            cls._original_DEFINE_flag = _defines.DEFINE_flag
            # Accessed through the class, the staticmethod yields a plain function.
            _defines.DEFINE_flag = cls._patched_define_flag
        except Exception:
            # there is no absl
            pass

        # noinspection PyBroadException
        try:
            # Separate guard: if DEFINE_flag could not be patched, this method runs
            # again on the next call. Re-reading FlagValues.__call__ at that point
            # would capture our own wrapper as the "original" (infinite recursion).
            if cls._original_FLAGS_parse_call is None:
                from absl.flags._flagvalues import FlagValues
                cls._original_FLAGS_parse_call = FlagValues.__call__
                FlagValues.__call__ = cls._patched_FLAGS_parse_call
        except Exception:
            # there is no absl
            pass

        if cls._original_DEFINE_flag:
            # noinspection PyBroadException
            try:
                # absl may have been imported (and flags defined) before we patched it,
                # so sync the already-defined flags with the task now.
                from absl import flags
                cls._update_current_flags(flags.FLAGS)
            except Exception:
                # there is no absl
                pass

    @staticmethod
    def _resolve_flag_param(args, kwargs):
        """Extract the flag object and its task parameter name from ``DEFINE_flag`` arguments.

        Positional argument 0 is the flag and positional argument 2 the module name,
        matching absl's ``DEFINE_flag(flag, flag_values, module_name, required)``.
        Both fall back to the ``flag`` / ``module_name`` keyword arguments when not
        passed positionally.

        The parameter name is ``flag.name``. When a module name is given, it is
        prefixed with ``module_name + _Arguments._prefix_sep``.

        :return: ``(flag, param_name)``. ``param_name`` is None when there is no flag;
            both are None if the name cannot be built.
        """
        # noinspection PyBroadException
        try:
            flag = args[0] if args else kwargs.get('flag')
            module_name = args[2] if len(args) >= 3 else kwargs.get('module_name')
            if not flag:
                return flag, None
            prefix = (module_name + _Arguments._prefix_sep) if module_name else ''
            return flag, prefix + flag.name
        except Exception:
            return None, None

    @staticmethod
    def _patched_define_flag(*args, **kwargs):
        """Replacement for ``absl.flags._defines.DEFINE_flag``.

        Returns None if the original ``DEFINE_flag`` was never captured. Without an
        attached task, it simply forwards to the original. Otherwise, before
        forwarding:

        * running remotely: the flag's value is replaced by what
          ``_task._arguments.copy_to_dict`` returns for it (falling back to the
          current value);
        * running locally: the flag's current value is passed to
          ``_task.update_parameters``.

        :return: whatever the original ``DEFINE_flag`` returns.
        """
        original = PatchAbsl._original_DEFINE_flag
        if not original:
            return None
        if not PatchAbsl._task:
            return original(*args, **kwargs)

        flag, param_name = PatchAbsl._resolve_flag_param(args, kwargs)

        if running_remotely():
            # noinspection PyBroadException
            try:
                if param_name and flag:
                    param_dict = PatchAbsl._task._arguments.copy_to_dict(
                        {param_name: flag.value}, prefix=_Arguments._prefix_tf_defines)
                    # Set the value on the flag object before it is handed to absl.
                    flag.value = param_dict.get(param_name, flag.value)
            except Exception:
                pass
        elif flag and param_name:
            # noinspection PyBroadException
            try:
                PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: flag.value})
            except Exception:
                # Best-effort, like the remote branch: failing to record the parameter
                # must not prevent absl from defining the flag.
                pass

        return original(*args, **kwargs)

    @staticmethod
    def _patched_FLAGS_parse_call(self, *args, **kwargs):
        """Replacement for ``FlagValues.__call__`` (the FLAGS parse entry point).

        It is assigned as a plain function on ``FlagValues``, so ``self`` is the
        ``FlagValues`` instance being called. It runs the original call, then
        synchronises the flags with the task (best-effort).

        :return: whatever the original ``FlagValues.__call__`` returns.
        """
        ret = PatchAbsl._original_FLAGS_parse_call(self, *args, **kwargs)
        # noinspection PyBroadException
        try:
            PatchAbsl._update_current_flags(self)
        except Exception:
            pass
        return ret

    @classmethod
    def _update_current_flags(cls, FLAGS):
        """Synchronise every flag in ``FLAGS`` with the attached task.

        Does nothing when no task is attached. Behaviour depends on where it runs:

        * remotely: flag values are set from the mapping returned by
          ``_task._arguments.copy_to_dict``;
        * locally: the current flag values are passed to
          ``_task._arguments.copy_from_dict``.

        Any error is swallowed.

        :param FLAGS: absl ``FlagValues``-like object. Iterating it yields flag names,
            and ``FLAGS[name]`` returns an object with a ``value`` attribute.
        """
        if not cls._task:
            return
        # noinspection PyBroadException
        try:
            current = {name: FLAGS[name].value for name in FLAGS}
            if running_remotely():
                overrides = cls._task._arguments.copy_to_dict(current, prefix=_Arguments._prefix_tf_defines)
                for key, value in overrides.items():
                    # noinspection PyBroadException
                    try:
                        # Only the part before the first _prefix_sep is matched against flag names.
                        # NOTE(review): _patched_define_flag builds "module<sep>flag" names, for which
                        # this would pick the module name - confirm such keys never reach this point.
                        name = key.split(_Arguments._prefix_sep)[0]
                        if name in FLAGS:
                            FLAGS[name].value = value
                    except Exception:
                        pass
            else:
                # Original intent: "clear previous parameters" - presumably copy_from_dict
                # replaces whatever is stored under this prefix; TODO confirm.
                cls._task._arguments.copy_from_dict(current, prefix=_Arguments._prefix_tf_defines)
        except Exception:
            pass