diff --git a/clearml/binding/fire_bind.py b/clearml/binding/fire_bind.py index d35aea8f..864dce57 100644 --- a/clearml/binding/fire_bind.py +++ b/clearml/binding/fire_bind.py @@ -60,7 +60,7 @@ class PatchFire: if cls.__current_command is None: args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()} parameters_types = {cls._section_name + cls._args_sep + k: cls._shared_arg_type for k in cls._args.keys()} - for k in (PatchFire.__command_args.get(None) or []): + for k in PatchFire.__command_args.get(None) or []: k = cls._section_name + cls._args_sep + k if k not in args: args[k] = None @@ -134,7 +134,7 @@ class PatchFire: replaced_args = [] for param in PatchFire.__remote_task_params[PatchFire._section_name].values(): if command is not None and param.type == PatchFire._command_arg_type_template % command: - replaced_args.append("--" + param.name[len(command + PatchFire._args_sep):]) + replaced_args.append("--" + param.name[len(command + PatchFire._args_sep) :]) value = PatchFire.__remote_task_params_dict[param.name] if len(value) > 0: replaced_args.append(value) @@ -220,7 +220,7 @@ class PatchFire: def __get_all_groups_and_commands(component, context): groups = [] commands = {} - component_trace_result = fire.core._Fire(component, [], PatchFire.__default_args, context).GetResult() # noqa + component_trace_result = PatchFire.__safe_Fire(component, [], PatchFire.__default_args, context) group_args = [[]] while len(group_args) > 0: query_group = group_args[-1] @@ -240,7 +240,7 @@ class PatchFire: @staticmethod def __get_groups_and_commands_for_args(component, args_, parsed_flag_args, context, name=None): - component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa + component_trace = PatchFire.__safe_Fire(component, args_, parsed_flag_args, context, name=name) groups, commands, _, _ = fire.helptext._GetActionsGroupedByKind(component_trace, verbose=False) # noqa groups = [(name, member) for name, member in groups.GetItems()] commands = [(name, member) for name, member in commands.GetItems()] @@ -262,10 +262,29 @@ class PatchFire: @staticmethod def __get_command_args(component, args_, parsed_flag_args, context, name=None): - component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa + component_trace = PatchFire.__safe_Fire(component, args_, parsed_flag_args, context, name=None) fn_spec = fire.inspectutils.GetFullArgSpec(component_trace) return fn_spec.args + @staticmethod + def __safe_Fire(component, args_, parsed_flag_args, context, name=None): + orig = None + # noinspection PyBroadException + try: + + def __CallAndUpdateTrace_rogue_call_guard(*args, **kwargs): + raise fire.core.FireError() + + orig = fire.core._CallAndUpdateTrace # noqa + fire.core._CallAndUpdateTrace = __CallAndUpdateTrace_rogue_call_guard # noqa + result = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa + except Exception: + result = None + finally: + if orig: + fire.core._CallAndUpdateTrace = orig # noqa + return result + @staticmethod def _load_task_params(): if not PatchFire.__remote_task_params: diff --git a/examples/frameworks/fire/fire_typing.py b/examples/frameworks/fire/fire_typing.py new file mode 100644 index 00000000..e53741bd --- /dev/null +++ b/examples/frameworks/fire/fire_typing.py @@ -0,0 +1,22 @@ +from typing import Tuple, List +from clearml import Task +import fire + + +def with_ret() -> Tuple: + print("With ret called") + return 1, 2 + + +def with_args(arg1: int, arg2: List): + print("With args called", arg1, arg2) + + +def with_args_and_ret(arg1: int, arg2: List) -> Tuple: + print("With args and ret called", arg1, arg2) + return 1, 2 + + +if __name__ == "__main__": + Task.init(project_name="examples", task_name="Fire typing command") + fire.Fire()