mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Fix Fire integration is not compatible with typing library (#610)
This commit is contained in:
parent
5edad33b86
commit
3bde51ebe7
@ -60,7 +60,7 @@ class PatchFire:
|
|||||||
if cls.__current_command is None:
|
if cls.__current_command is None:
|
||||||
args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()}
|
args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()}
|
||||||
parameters_types = {cls._section_name + cls._args_sep + k: cls._shared_arg_type for k in cls._args.keys()}
|
parameters_types = {cls._section_name + cls._args_sep + k: cls._shared_arg_type for k in cls._args.keys()}
|
||||||
for k in (PatchFire.__command_args.get(None) or []):
|
for k in PatchFire.__command_args.get(None) or []:
|
||||||
k = cls._section_name + cls._args_sep + k
|
k = cls._section_name + cls._args_sep + k
|
||||||
if k not in args:
|
if k not in args:
|
||||||
args[k] = None
|
args[k] = None
|
||||||
@ -220,7 +220,7 @@ class PatchFire:
|
|||||||
def __get_all_groups_and_commands(component, context):
|
def __get_all_groups_and_commands(component, context):
|
||||||
groups = []
|
groups = []
|
||||||
commands = {}
|
commands = {}
|
||||||
component_trace_result = fire.core._Fire(component, [], PatchFire.__default_args, context).GetResult() # noqa
|
component_trace_result = PatchFire.__safe_Fire(component, [], PatchFire.__default_args, context)
|
||||||
group_args = [[]]
|
group_args = [[]]
|
||||||
while len(group_args) > 0:
|
while len(group_args) > 0:
|
||||||
query_group = group_args[-1]
|
query_group = group_args[-1]
|
||||||
@ -240,7 +240,7 @@ class PatchFire:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __get_groups_and_commands_for_args(component, args_, parsed_flag_args, context, name=None):
|
def __get_groups_and_commands_for_args(component, args_, parsed_flag_args, context, name=None):
|
||||||
component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa
|
component_trace = PatchFire.__safe_Fire(component, args_, parsed_flag_args, context, name=name)
|
||||||
groups, commands, _, _ = fire.helptext._GetActionsGroupedByKind(component_trace, verbose=False) # noqa
|
groups, commands, _, _ = fire.helptext._GetActionsGroupedByKind(component_trace, verbose=False) # noqa
|
||||||
groups = [(name, member) for name, member in groups.GetItems()]
|
groups = [(name, member) for name, member in groups.GetItems()]
|
||||||
commands = [(name, member) for name, member in commands.GetItems()]
|
commands = [(name, member) for name, member in commands.GetItems()]
|
||||||
@ -262,10 +262,29 @@ class PatchFire:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __get_command_args(component, args_, parsed_flag_args, context, name=None):
|
def __get_command_args(component, args_, parsed_flag_args, context, name=None):
|
||||||
component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa
|
component_trace = PatchFire.__safe_Fire(component, args_, parsed_flag_args, context, name=None)
|
||||||
fn_spec = fire.inspectutils.GetFullArgSpec(component_trace)
|
fn_spec = fire.inspectutils.GetFullArgSpec(component_trace)
|
||||||
return fn_spec.args
|
return fn_spec.args
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __safe_Fire(component, args_, parsed_flag_args, context, name=None):
|
||||||
|
orig = None
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
|
||||||
|
def __CallAndUpdateTrace_rogue_call_guard(*args, **kwargs):
|
||||||
|
raise fire.core.FireError()
|
||||||
|
|
||||||
|
orig = fire.core._CallAndUpdateTrace # noqa
|
||||||
|
fire.core._CallAndUpdateTrace = __CallAndUpdateTrace_rogue_call_guard # noqa
|
||||||
|
result = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa
|
||||||
|
except Exception:
|
||||||
|
result = None
|
||||||
|
finally:
|
||||||
|
if orig:
|
||||||
|
fire.core._CallAndUpdateTrace = orig # noqa
|
||||||
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_task_params():
|
def _load_task_params():
|
||||||
if not PatchFire.__remote_task_params:
|
if not PatchFire.__remote_task_params:
|
||||||
|
22
examples/frameworks/fire/fire_typing.py
Normal file
22
examples/frameworks/fire/fire_typing.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from typing import Tuple, List
|
||||||
|
from clearml import Task
|
||||||
|
import fire
|
||||||
|
|
||||||
|
|
||||||
|
def with_ret() -> Tuple:
|
||||||
|
print("With ret called")
|
||||||
|
return 1, 2
|
||||||
|
|
||||||
|
|
||||||
|
def with_args(arg1: int, arg2: List):
|
||||||
|
print("With args called", arg1, arg2)
|
||||||
|
|
||||||
|
|
||||||
|
def with_args_and_ret(arg1: int, arg2: List) -> Tuple:
|
||||||
|
print("With args and ret called", arg1, arg2)
|
||||||
|
return 1, 2
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Task.init(project_name="examples", task_name="Fire typing command")
|
||||||
|
fire.Fire()
|
Loading…
Reference in New Issue
Block a user