mirror of
https://github.com/clearml/clearml
synced 2025-03-03 02:32:11 +00:00
Fix Fire integration is not compatible with typing library (#610)
This commit is contained in:
parent
5edad33b86
commit
3bde51ebe7
@ -60,7 +60,7 @@ class PatchFire:
|
||||
if cls.__current_command is None:
|
||||
args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()}
|
||||
parameters_types = {cls._section_name + cls._args_sep + k: cls._shared_arg_type for k in cls._args.keys()}
|
||||
for k in (PatchFire.__command_args.get(None) or []):
|
||||
for k in PatchFire.__command_args.get(None) or []:
|
||||
k = cls._section_name + cls._args_sep + k
|
||||
if k not in args:
|
||||
args[k] = None
|
||||
@ -134,7 +134,7 @@ class PatchFire:
|
||||
replaced_args = []
|
||||
for param in PatchFire.__remote_task_params[PatchFire._section_name].values():
|
||||
if command is not None and param.type == PatchFire._command_arg_type_template % command:
|
||||
replaced_args.append("--" + param.name[len(command + PatchFire._args_sep):])
|
||||
replaced_args.append("--" + param.name[len(command + PatchFire._args_sep) :])
|
||||
value = PatchFire.__remote_task_params_dict[param.name]
|
||||
if len(value) > 0:
|
||||
replaced_args.append(value)
|
||||
@ -220,7 +220,7 @@ class PatchFire:
|
||||
def __get_all_groups_and_commands(component, context):
|
||||
groups = []
|
||||
commands = {}
|
||||
component_trace_result = fire.core._Fire(component, [], PatchFire.__default_args, context).GetResult() # noqa
|
||||
component_trace_result = PatchFire.__safe_Fire(component, [], PatchFire.__default_args, context)
|
||||
group_args = [[]]
|
||||
while len(group_args) > 0:
|
||||
query_group = group_args[-1]
|
||||
@ -240,7 +240,7 @@ class PatchFire:
|
||||
|
||||
@staticmethod
|
||||
def __get_groups_and_commands_for_args(component, args_, parsed_flag_args, context, name=None):
|
||||
component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa
|
||||
component_trace = PatchFire.__safe_Fire(component, args_, parsed_flag_args, context, name=name)
|
||||
groups, commands, _, _ = fire.helptext._GetActionsGroupedByKind(component_trace, verbose=False) # noqa
|
||||
groups = [(name, member) for name, member in groups.GetItems()]
|
||||
commands = [(name, member) for name, member in commands.GetItems()]
|
||||
@ -262,10 +262,29 @@ class PatchFire:
|
||||
|
||||
@staticmethod
|
||||
def __get_command_args(component, args_, parsed_flag_args, context, name=None):
|
||||
component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa
|
||||
component_trace = PatchFire.__safe_Fire(component, args_, parsed_flag_args, context, name=None)
|
||||
fn_spec = fire.inspectutils.GetFullArgSpec(component_trace)
|
||||
return fn_spec.args
|
||||
|
||||
@staticmethod
|
||||
def __safe_Fire(component, args_, parsed_flag_args, context, name=None):
|
||||
orig = None
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
|
||||
def __CallAndUpdateTrace_rogue_call_guard(*args, **kwargs):
|
||||
raise fire.core.FireError()
|
||||
|
||||
orig = fire.core._CallAndUpdateTrace # noqa
|
||||
fire.core._CallAndUpdateTrace = __CallAndUpdateTrace_rogue_call_guard # noqa
|
||||
result = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa
|
||||
except Exception:
|
||||
result = None
|
||||
finally:
|
||||
if orig:
|
||||
fire.core._CallAndUpdateTrace = orig # noqa
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _load_task_params():
|
||||
if not PatchFire.__remote_task_params:
|
||||
|
22
examples/frameworks/fire/fire_typing.py
Normal file
22
examples/frameworks/fire/fire_typing.py
Normal file
@ -0,0 +1,22 @@
|
||||
from typing import Tuple, List
|
||||
from clearml import Task
|
||||
import fire
|
||||
|
||||
|
||||
def with_ret() -> Tuple:
|
||||
print("With ret called")
|
||||
return 1, 2
|
||||
|
||||
|
||||
def with_args(arg1: int, arg2: List):
|
||||
print("With args called", arg1, arg2)
|
||||
|
||||
|
||||
def with_args_and_ret(arg1: int, arg2: List) -> Tuple:
|
||||
print("With args and ret called", arg1, arg2)
|
||||
return 1, 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Task.init(project_name="examples", task_name="Fire typing command")
|
||||
fire.Fire()
|
Loading…
Reference in New Issue
Block a user