Support absl command line arguments

This commit is contained in:
allegroai 2019-06-17 01:00:06 +03:00
parent fc43ca5ed8
commit 005e521da1

View File

@ -5,6 +5,7 @@ from ..config import running_remotely
class PatchAbsl(object):
_original_DEFINE_flag = None
_original_FLAGS_parse_call = None
_task = None
@classmethod
@ -21,13 +22,27 @@ class PatchAbsl(object):
from absl.flags import _defines
cls._original_DEFINE_flag = _defines.DEFINE_flag
_defines.DEFINE_flag = cls._patched_define_flag
# if absl was already set, let's update our task params
from absl import flags
cls._update_current_flags(flags.FLAGS)
except Exception:
# there is no absl
pass
try:
from absl.flags._flagvalues import FlagValues
cls._original_FLAGS_parse_call = FlagValues.__call__
FlagValues.__call__ = cls._patched_FLAGS_parse_call
except Exception:
# there is no absl
pass
if cls._original_DEFINE_flag:
try:
# if absl was already set, let's update our task params
from absl import flags
cls._update_current_flags(flags.FLAGS)
except Exception:
# there is no absl
pass
@staticmethod
def _patched_define_flag(*args, **kwargs):
if not PatchAbsl._task or not PatchAbsl._original_DEFINE_flag:
@ -63,6 +78,16 @@ class PatchAbsl(object):
ret = PatchAbsl._original_DEFINE_flag(*args, **kwargs)
return ret
@staticmethod
def _patched_FLAGS_parse_call(self, *args, **kwargs):
ret = PatchAbsl._original_FLAGS_parse_call(self, *args, **kwargs)
# noinspection PyBroadException
try:
PatchAbsl._update_current_flags(self)
except Exception:
pass
return ret
@classmethod
def _update_current_flags(cls, FLAGS):
if not cls._task: