Fix Fire integration is not compatible with typing library (#610)

This commit is contained in:
allegroai 2022-04-09 14:22:47 +03:00
parent 5edad33b86
commit 3bde51ebe7
2 changed files with 46 additions and 5 deletions

View File

@ -60,7 +60,7 @@ class PatchFire:
if cls.__current_command is None:
args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()}
parameters_types = {cls._section_name + cls._args_sep + k: cls._shared_arg_type for k in cls._args.keys()}
for k in (PatchFire.__command_args.get(None) or []):
for k in PatchFire.__command_args.get(None) or []:
k = cls._section_name + cls._args_sep + k
if k not in args:
args[k] = None
@ -220,7 +220,7 @@ class PatchFire:
def __get_all_groups_and_commands(component, context):
groups = []
commands = {}
component_trace_result = fire.core._Fire(component, [], PatchFire.__default_args, context).GetResult() # noqa
component_trace_result = PatchFire.__safe_Fire(component, [], PatchFire.__default_args, context)
group_args = [[]]
while len(group_args) > 0:
query_group = group_args[-1]
@ -240,7 +240,7 @@ class PatchFire:
@staticmethod
def __get_groups_and_commands_for_args(component, args_, parsed_flag_args, context, name=None):
component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa
component_trace = PatchFire.__safe_Fire(component, args_, parsed_flag_args, context, name=name)
groups, commands, _, _ = fire.helptext._GetActionsGroupedByKind(component_trace, verbose=False) # noqa
groups = [(name, member) for name, member in groups.GetItems()]
commands = [(name, member) for name, member in commands.GetItems()]
@ -262,10 +262,29 @@ class PatchFire:
@staticmethod
def __get_command_args(component, args_, parsed_flag_args, context, name=None):
component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa
component_trace = PatchFire.__safe_Fire(component, args_, parsed_flag_args, context, name=None)
fn_spec = fire.inspectutils.GetFullArgSpec(component_trace)
return fn_spec.args
@staticmethod
def __safe_Fire(component, args_, parsed_flag_args, context, name=None):
orig = None
# noinspection PyBroadException
try:
def __CallAndUpdateTrace_rogue_call_guard(*args, **kwargs):
raise fire.core.FireError()
orig = fire.core._CallAndUpdateTrace # noqa
fire.core._CallAndUpdateTrace = __CallAndUpdateTrace_rogue_call_guard # noqa
result = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa
except Exception:
result = None
finally:
if orig:
fire.core._CallAndUpdateTrace = orig # noqa
return result
@staticmethod
def _load_task_params():
if not PatchFire.__remote_task_params:

View File

@ -0,0 +1,22 @@
from typing import Tuple, List
from clearml import Task
import fire
def with_ret() -> Tuple:
print("With ret called")
return 1, 2
def with_args(arg1: int, arg2: List):
print("With args called", arg1, arg2)
def with_args_and_ret(arg1: int, arg2: List) -> Tuple:
print("With args and ret called", arg1, arg2)
return 1, 2
if __name__ == "__main__":
Task.init(project_name="examples", task_name="Fire typing command")
fire.Fire()