mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
89 lines
3.3 KiB
Python
89 lines
3.3 KiB
Python
""" absl-py FLAGS binding utility functions """
|
|
from trains.backend_interface.task.args import _Arguments
|
|
from ..config import running_remotely
|
|
|
|
|
|
class PatchAbsl(object):
    """Bind absl-py ``FLAGS`` definitions to the current Task's parameters.

    While running locally, flags defined through absl are recorded on the task
    under the ``_Arguments._prefix_tf_defines`` prefix. When running remotely,
    values stored on the task are written back into the flags instead.
    """

    # The un-patched ``absl.flags._defines.DEFINE_flag``; stays None until patching succeeds.
    _original_DEFINE_flag = None
    # Task whose parameters the flags are synced with; None disables syncing.
    _task = None

    @classmethod
    def update_current_task(cls, current_task):
        """Attach ``current_task`` and install the absl hook (installed at most once)."""
        cls._task = current_task
        cls._patch_absl()

    @classmethod
    def _patch_absl(cls):
        """Wrap ``absl.flags._defines.DEFINE_flag``; no-op if already patched or absl is unavailable."""
        if cls._original_DEFINE_flag:
            return
        # noinspection PyBroadException
        try:
            from absl.flags import _defines
            cls._original_DEFINE_flag = _defines.DEFINE_flag
            _defines.DEFINE_flag = cls._patched_define_flag
            # flags defined before the hook was installed are synced here
            from absl import flags
            cls._update_current_flags(flags.FLAGS)
        except Exception:
            # there is no absl
            pass

    @staticmethod
    def _get_flag_and_param_name(args, kwargs):
        """Extract ``(flag, param_name)`` from a ``DEFINE_flag`` call, or ``(None, None)``.

        Positional arguments are read as ``(flag, <ignored>, module_name, ...)``;
        ``flag`` / ``module_name`` keyword arguments are used as a fallback.
        ``param_name`` is ``<module_name><sep><flag.name>`` when a module name was
        supplied, otherwise just ``flag.name``.
        """
        # noinspection PyBroadException
        try:
            flag = args[0] if args else kwargs.get('flag')
            module_name = args[2] if len(args) >= 3 else kwargs.get('module_name')
            if not flag:
                return None, None
            prefix = (module_name + _Arguments._prefix_sep) if module_name else ''
            return flag, prefix + flag.name
        except Exception:
            return None, None

    @staticmethod
    def _patched_define_flag(*args, **kwargs):
        """Replacement for ``DEFINE_flag`` that syncs the flag with the current task.

        Locally, the flag's value is recorded on the task. Remotely, the value
        stored on the task (if any) replaces ``flag.value`` before registration.
        Task bookkeeping is best-effort: the original ``DEFINE_flag`` is always
        called exactly once and its result returned (None if absl was never patched).
        """
        original_define = PatchAbsl._original_DEFINE_flag
        if not original_define:
            return None
        if not PatchAbsl._task:
            return original_define(*args, **kwargs)

        flag, param_name = PatchAbsl._get_flag_and_param_name(args, kwargs)
        if flag and param_name:
            # noinspection PyBroadException
            try:
                if running_remotely():
                    param_dict = PatchAbsl._task._arguments.copy_to_dict(
                        {param_name: flag.value}, prefix=_Arguments._prefix_tf_defines)
                    flag.value = param_dict.get(param_name, flag.value)
                else:
                    PatchAbsl._task.update_parameters(
                        {_Arguments._prefix_tf_defines + param_name: flag.value})
            except Exception:
                # never let task bookkeeping prevent the flag from being defined
                pass

        return original_define(*args, **kwargs)

    @classmethod
    def _update_current_flags(cls, FLAGS):
        """Sync flags already present in ``FLAGS`` with the current task.

        Remotely, each stored task parameter overrides the matching flag's value.
        Locally, every flag's current value is copied onto the task.
        """
        if not cls._task:
            return
        # noinspection PyBroadException
        try:
            if running_remotely():
                param_dict = cls._task._arguments.copy_to_dict({}, prefix=_Arguments._prefix_tf_defines)
                for key, value in param_dict.items():
                    # noinspection PyBroadException
                    try:
                        # keys are "<flag>" or "<module><sep><flag>" (see _get_flag_and_param_name);
                        # the flag name is the key itself or its last component
                        name = key if key in FLAGS else key.split(_Arguments._prefix_sep)[-1]
                        if name in FLAGS:
                            FLAGS[name].value = value
                    except Exception:
                        pass
            else:
                # record the current value of every defined flag on the task
                parameters = {name: FLAGS[name].value for name in FLAGS}
                cls._task._arguments.copy_from_dict(parameters, prefix=_Arguments._prefix_tf_defines)
        except Exception:
            pass
|