Fix Fire integration is not compatible with typing library (#610)

This commit is contained in:
allegroai 2022-04-09 14:22:47 +03:00
parent 5edad33b86
commit 3bde51ebe7
2 changed files with 46 additions and 5 deletions

View File

@ -60,7 +60,7 @@ class PatchFire:
if cls.__current_command is None: if cls.__current_command is None:
args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()} args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()}
parameters_types = {cls._section_name + cls._args_sep + k: cls._shared_arg_type for k in cls._args.keys()} parameters_types = {cls._section_name + cls._args_sep + k: cls._shared_arg_type for k in cls._args.keys()}
for k in (PatchFire.__command_args.get(None) or []): for k in PatchFire.__command_args.get(None) or []:
k = cls._section_name + cls._args_sep + k k = cls._section_name + cls._args_sep + k
if k not in args: if k not in args:
args[k] = None args[k] = None
@ -134,7 +134,7 @@ class PatchFire:
replaced_args = [] replaced_args = []
for param in PatchFire.__remote_task_params[PatchFire._section_name].values(): for param in PatchFire.__remote_task_params[PatchFire._section_name].values():
if command is not None and param.type == PatchFire._command_arg_type_template % command: if command is not None and param.type == PatchFire._command_arg_type_template % command:
replaced_args.append("--" + param.name[len(command + PatchFire._args_sep):]) replaced_args.append("--" + param.name[len(command + PatchFire._args_sep) :])
value = PatchFire.__remote_task_params_dict[param.name] value = PatchFire.__remote_task_params_dict[param.name]
if len(value) > 0: if len(value) > 0:
replaced_args.append(value) replaced_args.append(value)
@ -220,7 +220,7 @@ class PatchFire:
def __get_all_groups_and_commands(component, context): def __get_all_groups_and_commands(component, context):
groups = [] groups = []
commands = {} commands = {}
component_trace_result = fire.core._Fire(component, [], PatchFire.__default_args, context).GetResult() # noqa component_trace_result = PatchFire.__safe_Fire(component, [], PatchFire.__default_args, context)
group_args = [[]] group_args = [[]]
while len(group_args) > 0: while len(group_args) > 0:
query_group = group_args[-1] query_group = group_args[-1]
@ -240,7 +240,7 @@ class PatchFire:
@staticmethod @staticmethod
def __get_groups_and_commands_for_args(component, args_, parsed_flag_args, context, name=None): def __get_groups_and_commands_for_args(component, args_, parsed_flag_args, context, name=None):
component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa component_trace = PatchFire.__safe_Fire(component, args_, parsed_flag_args, context, name=name)
groups, commands, _, _ = fire.helptext._GetActionsGroupedByKind(component_trace, verbose=False) # noqa groups, commands, _, _ = fire.helptext._GetActionsGroupedByKind(component_trace, verbose=False) # noqa
groups = [(name, member) for name, member in groups.GetItems()] groups = [(name, member) for name, member in groups.GetItems()]
commands = [(name, member) for name, member in commands.GetItems()] commands = [(name, member) for name, member in commands.GetItems()]
@ -262,10 +262,29 @@ class PatchFire:
@staticmethod @staticmethod
def __get_command_args(component, args_, parsed_flag_args, context, name=None): def __get_command_args(component, args_, parsed_flag_args, context, name=None):
component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa component_trace = PatchFire.__safe_Fire(component, args_, parsed_flag_args, context, name=None)
fn_spec = fire.inspectutils.GetFullArgSpec(component_trace) fn_spec = fire.inspectutils.GetFullArgSpec(component_trace)
return fn_spec.args return fn_spec.args
@staticmethod
def __safe_Fire(component, args_, parsed_flag_args, context, name=None):
orig = None
# noinspection PyBroadException
try:
def __CallAndUpdateTrace_rogue_call_guard(*args, **kwargs):
raise fire.core.FireError()
orig = fire.core._CallAndUpdateTrace # noqa
fire.core._CallAndUpdateTrace = __CallAndUpdateTrace_rogue_call_guard # noqa
result = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa
except Exception:
result = None
finally:
if orig:
fire.core._CallAndUpdateTrace = orig # noqa
return result
@staticmethod @staticmethod
def _load_task_params(): def _load_task_params():
if not PatchFire.__remote_task_params: if not PatchFire.__remote_task_params:

View File

@ -0,0 +1,22 @@
from typing import Tuple, List
from clearml import Task
import fire
def with_ret() -> Tuple:
print("With ret called")
return 1, 2
def with_args(arg1: int, arg2: List):
print("With args called", arg1, arg2)
def with_args_and_ret(arg1: int, arg2: List) -> Tuple:
print("With args and ret called", arg1, arg2)
return 1, 2
if __name__ == "__main__":
Task.init(project_name="examples", task_name="Fire typing command")
fire.Fire()