mirror of
https://github.com/clearml/clearml
synced 2025-04-25 08:51:16 +00:00
Add Python Fire support (#550)
This commit is contained in:
parent
9794135aed
commit
296cb7d899
294
clearml/binding/fire_bind.py
Normal file
294
clearml/binding/fire_bind.py
Normal file
@ -0,0 +1,294 @@
|
||||
try:
|
||||
import fire
|
||||
import fire.core
|
||||
import fire.helptext
|
||||
except ImportError:
|
||||
fire = None
|
||||
|
||||
import inspect
|
||||
from types import SimpleNamespace
|
||||
from .frameworks import _patched_call # noqa
|
||||
from ..config import get_remote_task_id, running_remotely
|
||||
from ..utilities.dicts import cast_str_to_bool
|
||||
|
||||
|
||||
class PatchFire:
|
||||
_args = {}
|
||||
_command_type = "fire.Command"
|
||||
_command_arg_type_template = "fire.Arg@%s"
|
||||
_shared_arg_type = "fire.Arg.shared"
|
||||
_section_name = "Args"
|
||||
_args_sep = "/"
|
||||
_commands_sep = "."
|
||||
_main_task = None
|
||||
__remote_task_params = None
|
||||
__remote_task_params_dict = {}
|
||||
__patched = False
|
||||
__groups = []
|
||||
__commands = {}
|
||||
__default_args = SimpleNamespace(
|
||||
completion=None, help=False, interactive=False, separator="-", trace=False, verbose=False
|
||||
)
|
||||
__current_command = None
|
||||
__fetched_current_command = False
|
||||
__command_args = {}
|
||||
|
||||
@classmethod
|
||||
def patch(cls, task=None):
|
||||
if fire is None:
|
||||
return
|
||||
|
||||
if task:
|
||||
cls._main_task = task
|
||||
cls._update_task_args()
|
||||
|
||||
if not cls.__patched:
|
||||
cls.__patched = True
|
||||
if running_remotely():
|
||||
fire.core._Fire = _patched_call(fire.core._Fire, PatchFire.__Fire)
|
||||
else:
|
||||
fire.core._CallAndUpdateTrace = _patched_call(
|
||||
fire.core._CallAndUpdateTrace, PatchFire.__CallAndUpdateTrace
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _update_task_args(cls):
|
||||
if running_remotely() or not cls._main_task:
|
||||
return
|
||||
args = {}
|
||||
parameters_types = {}
|
||||
if cls.__current_command is None:
|
||||
args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()}
|
||||
parameters_types = {cls._section_name + cls._args_sep + k: cls._shared_arg_type for k in cls._args.keys()}
|
||||
for k in (PatchFire.__command_args.get(None) or []):
|
||||
k = cls._section_name + cls._args_sep + k
|
||||
if k not in args:
|
||||
args[k] = None
|
||||
else:
|
||||
args[cls._section_name + cls._args_sep + cls.__current_command] = True
|
||||
parameters_types[cls._section_name + cls._args_sep + cls.__current_command] = cls._command_type
|
||||
args = {
|
||||
**args,
|
||||
**{
|
||||
cls._section_name + cls._args_sep + cls.__current_command + cls._args_sep + k: v
|
||||
for k, v in cls._args.items()
|
||||
if k in (PatchFire.__command_args.get(cls.__current_command) or [])
|
||||
},
|
||||
**{
|
||||
cls._section_name + cls._args_sep + k: v
|
||||
for k, v in cls._args.items()
|
||||
if k not in (PatchFire.__command_args.get(cls.__current_command) or [])
|
||||
},
|
||||
}
|
||||
parameters_types = {
|
||||
**parameters_types,
|
||||
**{
|
||||
cls._section_name
|
||||
+ cls._args_sep
|
||||
+ cls.__current_command
|
||||
+ cls._args_sep
|
||||
+ k: cls._command_arg_type_template % cls.__current_command
|
||||
for k in cls._args.keys()
|
||||
if k in (PatchFire.__command_args.get(cls.__current_command) or [])
|
||||
},
|
||||
**{
|
||||
cls._section_name + cls._args_sep + k: cls._shared_arg_type
|
||||
for k in cls._args.keys()
|
||||
if k not in (PatchFire.__command_args.get(cls.__current_command) or [])
|
||||
},
|
||||
}
|
||||
for command in cls.__commands:
|
||||
if command == cls.__current_command:
|
||||
continue
|
||||
args[cls._section_name + cls._args_sep + command] = False
|
||||
parameters_types[cls._section_name + cls._args_sep + command] = cls._command_type
|
||||
unused_command_args = {
|
||||
cls._section_name + cls._args_sep + command + cls._args_sep + k: None
|
||||
for k in (cls.__command_args.get(command) or [])
|
||||
}
|
||||
unused_paramenters_types = {
|
||||
cls._section_name
|
||||
+ cls._args_sep
|
||||
+ command
|
||||
+ cls._args_sep
|
||||
+ k: cls._command_arg_type_template % command
|
||||
for k in (cls.__command_args.get(command) or [])
|
||||
}
|
||||
args = {**args, **unused_command_args}
|
||||
parameters_types = {**parameters_types, **unused_paramenters_types}
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
cls._main_task._set_parameters(
|
||||
args,
|
||||
__update=True,
|
||||
__parameters_types=parameters_types,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def __Fire(original_fn, component, args_, parsed_flag_args, context, name, *args, **kwargs): # noqa
|
||||
if running_remotely():
|
||||
command = PatchFire._load_task_params()
|
||||
if command is not None:
|
||||
replaced_args = command.split(PatchFire._commands_sep)
|
||||
else:
|
||||
replaced_args = []
|
||||
for param in PatchFire.__remote_task_params[PatchFire._section_name].values():
|
||||
if command is not None and param.type == PatchFire._command_arg_type_template % command:
|
||||
replaced_args.append("--" + param.name[len(command + PatchFire._args_sep) :])
|
||||
value = PatchFire.__remote_task_params_dict[param.name]
|
||||
if len(value) > 0:
|
||||
replaced_args.append(value)
|
||||
if param.type == PatchFire._shared_arg_type:
|
||||
replaced_args.append("--" + param.name)
|
||||
value = PatchFire.__remote_task_params_dict[param.name]
|
||||
if len(value) > 0:
|
||||
replaced_args.append(value)
|
||||
return original_fn(component, replaced_args, parsed_flag_args, context, name, *args, **kwargs)
|
||||
return original_fn(component, args_, parsed_flag_args, context, name, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def __CallAndUpdateTrace( # noqa
|
||||
original_fn, component, args_, component_trace, treatment, target, *args, **kwargs
|
||||
):
|
||||
if running_remotely():
|
||||
return original_fn(component, args_, component_trace, treatment, target, *args, **kwargs)
|
||||
if not PatchFire.__fetched_current_command:
|
||||
PatchFire.__fetched_current_command = True
|
||||
context, component_context = PatchFire.__get_context_and_component(component)
|
||||
PatchFire.__groups, PatchFire.__commands = PatchFire.__get_all_groups_and_commands(
|
||||
component_context, context
|
||||
)
|
||||
PatchFire.__current_command = PatchFire.__get_current_command(
|
||||
args_, PatchFire.__groups, PatchFire.__commands
|
||||
)
|
||||
for command in PatchFire.__commands:
|
||||
PatchFire.__command_args[command] = PatchFire.__get_command_args(
|
||||
component_context, command.split(PatchFire._commands_sep), PatchFire.__default_args, context
|
||||
)
|
||||
PatchFire.__command_args[None] = PatchFire.__get_command_args(
|
||||
component_context,
|
||||
"",
|
||||
PatchFire.__default_args,
|
||||
context,
|
||||
)
|
||||
for k, v in PatchFire.__commands.items():
|
||||
if v == component:
|
||||
PatchFire.__current_command = k
|
||||
break
|
||||
# Comparing methods in Python is equivalent to comparing the __func__ of the methods
|
||||
# and the objects they are bound to. We do not care about the object in this case,
|
||||
# so we just compare the __func__
|
||||
if inspect.ismethod(component) and inspect.ismethod(v) and v.__func__ == component.__func__:
|
||||
PatchFire.__current_command = k
|
||||
break
|
||||
fn = component.__call__ if treatment == "callable" else component
|
||||
metadata = fire.decorators.GetMetadata(component)
|
||||
fn_spec = fire.inspectutils.GetFullArgSpec(component)
|
||||
parse = fire.core._MakeParseFn(fn, metadata) # noqa
|
||||
(parsed_args, parsed_kwargs), _, _, _ = parse(args_)
|
||||
PatchFire._args = {**PatchFire._args, **{k: v for k, v in zip(fn_spec.args, parsed_args)}, **parsed_kwargs}
|
||||
PatchFire._update_task_args()
|
||||
return original_fn(component, args_, component_trace, treatment, target, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def __get_context_and_component(component):
|
||||
context = {}
|
||||
component_context = component
|
||||
# Walk through the stack to find the arguments with fire.Fire() has been called.
|
||||
# Can't do it by patching the function because we want to patch _CallAndUpdateTrace,
|
||||
# which is called by fire.Fire()
|
||||
frame_infos = inspect.stack()
|
||||
for frame_info_ind, frame_info in enumerate(frame_infos):
|
||||
if frame_info.function == "Fire":
|
||||
component_context = inspect.getargvalues(frame_info.frame).locals["component"]
|
||||
if inspect.getargvalues(frame_info.frame).locals["component"] is None:
|
||||
# This is similar to how fire finds this context
|
||||
fire_context_frame = frame_infos[frame_info_ind + 1].frame
|
||||
context.update(fire_context_frame.f_globals)
|
||||
context.update(fire_context_frame.f_locals)
|
||||
# Ignore modules, as they yield too many commands.
|
||||
# Also ignore clearml.task.
|
||||
context = {
|
||||
k: v
|
||||
for k, v in context.items()
|
||||
if not inspect.ismodule(v) and (not inspect.isclass(v) or v.__module__ != "clearml.task")
|
||||
}
|
||||
break
|
||||
return context, component_context
|
||||
|
||||
@staticmethod
|
||||
def __get_all_groups_and_commands(component, context):
|
||||
groups = []
|
||||
commands = {}
|
||||
component_trace_result = fire.core._Fire(component, [], PatchFire.__default_args, context).GetResult() # noqa
|
||||
group_args = [[]]
|
||||
while len(group_args) > 0:
|
||||
query_group = group_args[-1]
|
||||
groups.append(PatchFire._commands_sep.join(query_group))
|
||||
group_args = group_args[:-1]
|
||||
current_groups, current_commands = PatchFire.__get_groups_and_commands_for_args(
|
||||
component_trace_result, query_group, PatchFire.__default_args, context
|
||||
)
|
||||
for command in current_commands:
|
||||
prefix = (
|
||||
PatchFire._commands_sep.join(query_group) + PatchFire._commands_sep if len(query_group) > 0 else ""
|
||||
)
|
||||
commands[prefix + command[0]] = command[1]
|
||||
for group in current_groups:
|
||||
group_args.append(query_group + [group[0]])
|
||||
return groups, commands
|
||||
|
||||
@staticmethod
|
||||
def __get_groups_and_commands_for_args(component, args_, parsed_flag_args, context, name=None):
|
||||
component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa
|
||||
groups, commands, _, _ = fire.helptext._GetActionsGroupedByKind(component_trace, verbose=False) # noqa
|
||||
groups = [(name, member) for name, member in groups.GetItems()]
|
||||
commands = [(name, member) for name, member in commands.GetItems()]
|
||||
return groups, commands
|
||||
|
||||
@staticmethod
|
||||
def __get_current_command(args_, groups, commands):
|
||||
current_command = ""
|
||||
for arg in args_:
|
||||
prefix = (current_command + PatchFire._commands_sep) if len(current_command) > 0 else ""
|
||||
potential_current_command = prefix + arg
|
||||
if potential_current_command not in groups:
|
||||
if potential_current_command in commands:
|
||||
return potential_current_command
|
||||
else:
|
||||
return None
|
||||
current_command = potential_current_command
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def __get_command_args(component, args_, parsed_flag_args, context, name=None):
|
||||
component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa
|
||||
fn_spec = fire.inspectutils.GetFullArgSpec(component_trace)
|
||||
return fn_spec.args
|
||||
|
||||
@staticmethod
|
||||
def _load_task_params():
|
||||
if not PatchFire.__remote_task_params:
|
||||
from clearml import Task
|
||||
|
||||
t = Task.get_task(task_id=get_remote_task_id())
|
||||
# noinspection PyProtectedMember
|
||||
PatchFire.__remote_task_params = t._get_task_property("hyperparams") or {}
|
||||
params_dict = t.get_parameters(backwards_compatibility=False)
|
||||
skip = len(PatchFire._section_name) + 1
|
||||
PatchFire.__remote_task_params_dict = {
|
||||
k[skip:]: v
|
||||
for k, v in params_dict.items()
|
||||
if k.startswith(PatchFire._section_name + PatchFire._args_sep)
|
||||
}
|
||||
|
||||
command = [
|
||||
p.name
|
||||
for p in PatchFire.__remote_task_params[PatchFire._section_name].values()
|
||||
if p.type == PatchFire._command_type and cast_str_to_bool(p.value, strip=True)
|
||||
]
|
||||
return command[0] if command else None
|
||||
|
||||
|
||||
# patch fire before anything
|
||||
PatchFire.patch()
|
@ -50,6 +50,7 @@ from .binding.joblib_bind import PatchedJoblib
|
||||
from .binding.matplotlib_bind import PatchedMatplotlib
|
||||
from .binding.hydra_bind import PatchHydra
|
||||
from .binding.click_bind import PatchClick
|
||||
from .binding.fire_bind import PatchFire
|
||||
from .binding.jsonargs_bind import PatchJsonArgParse
|
||||
from .config import (
|
||||
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI,
|
||||
@ -606,8 +607,9 @@ class Task(_Task):
|
||||
|
||||
# Patch ArgParser to be aware of the current task
|
||||
argparser_update_currenttask(Task.__main_task)
|
||||
# Patch Click
|
||||
# Patch Click and Fire
|
||||
PatchClick.patch(Task.__main_task)
|
||||
PatchFire.patch(Task.__main_task)
|
||||
|
||||
# set excluded arguments
|
||||
if isinstance(auto_connect_arg_parser, dict):
|
||||
|
21
examples/frameworks/fire/fire_class_cmd.py
Normal file
21
examples/frameworks/fire/fire_class_cmd.py
Normal file
@ -0,0 +1,21 @@
|
||||
# ClearML - Example of Python Fire integration, processing commands derived from a class
|
||||
#
|
||||
from clearml import Task
|
||||
|
||||
import fire
|
||||
|
||||
|
||||
class BrokenCalculator(object):
|
||||
def __init__(self, offset=1):
|
||||
self._offset = offset
|
||||
|
||||
def add(self, x, y):
|
||||
return x + y + self._offset
|
||||
|
||||
def multiply(self, x, y):
|
||||
return x * y + self._offset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Task.init(project_name="examples", task_name="fire class command")
|
||||
fire.Fire(BrokenCalculator)
|
23
examples/frameworks/fire/fire_dict_cmd.py
Normal file
23
examples/frameworks/fire/fire_dict_cmd.py
Normal file
@ -0,0 +1,23 @@
|
||||
# ClearML - Example of Python Fire integration, with commands derived from a dictionary
|
||||
#
|
||||
from clearml import Task
|
||||
|
||||
import fire
|
||||
|
||||
|
||||
def add(x, y):
|
||||
return x + y
|
||||
|
||||
|
||||
def multiply(x, y):
|
||||
return x * y
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Task.init(project_name="examples", task_name="fire dict command")
|
||||
fire.Fire(
|
||||
{
|
||||
"add": add,
|
||||
"multiply": multiply,
|
||||
}
|
||||
)
|
44
examples/frameworks/fire/fire_grouping_cmd.py
Normal file
44
examples/frameworks/fire/fire_grouping_cmd.py
Normal file
@ -0,0 +1,44 @@
|
||||
# ClearML - Example of Python Fire integration, with commands grouped inside classes
|
||||
#
|
||||
from clearml import Task
|
||||
|
||||
import fire
|
||||
|
||||
|
||||
class Other(object):
|
||||
def status(self):
|
||||
return "Other"
|
||||
|
||||
|
||||
class IngestionStage(object):
|
||||
def __init__(self):
|
||||
self.other = Other()
|
||||
|
||||
def run(self):
|
||||
return "Ingesting! Nom nom nom..."
|
||||
|
||||
def hello(self, hello_str):
|
||||
return hello_str
|
||||
|
||||
|
||||
class DigestionStage(object):
|
||||
def run(self, volume=1):
|
||||
return " ".join(["Burp!"] * volume)
|
||||
|
||||
def status(self):
|
||||
return "Satiated."
|
||||
|
||||
|
||||
class Pipeline(object):
|
||||
def __init__(self):
|
||||
self.ingestion = IngestionStage()
|
||||
self.digestion = DigestionStage()
|
||||
|
||||
def run(self):
|
||||
self.ingestion.run()
|
||||
self.digestion.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Task.init(project_name="examples", task_name="fire grouping command")
|
||||
fire.Fire(Pipeline)
|
22
examples/frameworks/fire/fire_multi_cmd.py
Normal file
22
examples/frameworks/fire/fire_multi_cmd.py
Normal file
@ -0,0 +1,22 @@
|
||||
# ClearML - Example of Python Fire integration, processing multiple commands, when fire is initialized with no component
|
||||
#
|
||||
from clearml import Task
|
||||
|
||||
import fire
|
||||
|
||||
|
||||
def hello(count, name="clearml", prefix="prefix_", suffix="_suffix", **kwargs):
|
||||
for _ in range(count):
|
||||
print("Hello %s%s%s!" % (prefix, name, suffix))
|
||||
|
||||
|
||||
def serve(addr, port, should_serve=False):
|
||||
if not should_serve:
|
||||
print("Not serving")
|
||||
else:
|
||||
print("Serving on %s:%s" % (addr, port))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Task.init(project_name="examples", task_name="fire multi command")
|
||||
fire.Fire()
|
19
examples/frameworks/fire/fire_object_cmd.py
Normal file
19
examples/frameworks/fire/fire_object_cmd.py
Normal file
@ -0,0 +1,19 @@
|
||||
# ClearML - Example of Python Fire integration, with commands derived from an object
|
||||
#
|
||||
from clearml import Task
|
||||
|
||||
import fire
|
||||
|
||||
|
||||
class Calculator(object):
|
||||
def add(self, x, y):
|
||||
return x + y
|
||||
|
||||
def multiply(self, x, y):
|
||||
return x * y
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Task.init(project_name="examples", task_name="fire object command")
|
||||
calculator = Calculator()
|
||||
fire.Fire(calculator)
|
15
examples/frameworks/fire/fire_single_cmd.py
Normal file
15
examples/frameworks/fire/fire_single_cmd.py
Normal file
@ -0,0 +1,15 @@
|
||||
# ClearML - Example of Python Fire integration, with a single command passed to Fire
|
||||
#
|
||||
from clearml import Task
|
||||
|
||||
import fire
|
||||
|
||||
|
||||
def hello(count, name="clearml", prefix="prefix_", suffix="_suffix", **kwargs):
|
||||
for _ in range(count):
|
||||
print("Hello %s%s%s!" % (prefix, name, suffix))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Task.init(project_name="examples", task_name="fire single command")
|
||||
fire.Fire(hello)
|
2
examples/frameworks/fire/requirements.txt
Normal file
2
examples/frameworks/fire/requirements.txt
Normal file
@ -0,0 +1,2 @@
|
||||
clearml
|
||||
fire
|
Loading…
Reference in New Issue
Block a user