mirror of
https://github.com/clearml/clearml
synced 2025-04-27 01:39:17 +00:00
Add Python Fire support (#550)
This commit is contained in:
parent
9794135aed
commit
296cb7d899
294
clearml/binding/fire_bind.py
Normal file
294
clearml/binding/fire_bind.py
Normal file
@ -0,0 +1,294 @@
|
|||||||
|
try:
|
||||||
|
import fire
|
||||||
|
import fire.core
|
||||||
|
import fire.helptext
|
||||||
|
except ImportError:
|
||||||
|
fire = None
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from .frameworks import _patched_call # noqa
|
||||||
|
from ..config import get_remote_task_id, running_remotely
|
||||||
|
from ..utilities.dicts import cast_str_to_bool
|
||||||
|
|
||||||
|
|
||||||
|
class PatchFire:
|
||||||
|
_args = {}
|
||||||
|
_command_type = "fire.Command"
|
||||||
|
_command_arg_type_template = "fire.Arg@%s"
|
||||||
|
_shared_arg_type = "fire.Arg.shared"
|
||||||
|
_section_name = "Args"
|
||||||
|
_args_sep = "/"
|
||||||
|
_commands_sep = "."
|
||||||
|
_main_task = None
|
||||||
|
__remote_task_params = None
|
||||||
|
__remote_task_params_dict = {}
|
||||||
|
__patched = False
|
||||||
|
__groups = []
|
||||||
|
__commands = {}
|
||||||
|
__default_args = SimpleNamespace(
|
||||||
|
completion=None, help=False, interactive=False, separator="-", trace=False, verbose=False
|
||||||
|
)
|
||||||
|
__current_command = None
|
||||||
|
__fetched_current_command = False
|
||||||
|
__command_args = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def patch(cls, task=None):
|
||||||
|
if fire is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if task:
|
||||||
|
cls._main_task = task
|
||||||
|
cls._update_task_args()
|
||||||
|
|
||||||
|
if not cls.__patched:
|
||||||
|
cls.__patched = True
|
||||||
|
if running_remotely():
|
||||||
|
fire.core._Fire = _patched_call(fire.core._Fire, PatchFire.__Fire)
|
||||||
|
else:
|
||||||
|
fire.core._CallAndUpdateTrace = _patched_call(
|
||||||
|
fire.core._CallAndUpdateTrace, PatchFire.__CallAndUpdateTrace
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _update_task_args(cls):
|
||||||
|
if running_remotely() or not cls._main_task:
|
||||||
|
return
|
||||||
|
args = {}
|
||||||
|
parameters_types = {}
|
||||||
|
if cls.__current_command is None:
|
||||||
|
args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()}
|
||||||
|
parameters_types = {cls._section_name + cls._args_sep + k: cls._shared_arg_type for k in cls._args.keys()}
|
||||||
|
for k in (PatchFire.__command_args.get(None) or []):
|
||||||
|
k = cls._section_name + cls._args_sep + k
|
||||||
|
if k not in args:
|
||||||
|
args[k] = None
|
||||||
|
else:
|
||||||
|
args[cls._section_name + cls._args_sep + cls.__current_command] = True
|
||||||
|
parameters_types[cls._section_name + cls._args_sep + cls.__current_command] = cls._command_type
|
||||||
|
args = {
|
||||||
|
**args,
|
||||||
|
**{
|
||||||
|
cls._section_name + cls._args_sep + cls.__current_command + cls._args_sep + k: v
|
||||||
|
for k, v in cls._args.items()
|
||||||
|
if k in (PatchFire.__command_args.get(cls.__current_command) or [])
|
||||||
|
},
|
||||||
|
**{
|
||||||
|
cls._section_name + cls._args_sep + k: v
|
||||||
|
for k, v in cls._args.items()
|
||||||
|
if k not in (PatchFire.__command_args.get(cls.__current_command) or [])
|
||||||
|
},
|
||||||
|
}
|
||||||
|
parameters_types = {
|
||||||
|
**parameters_types,
|
||||||
|
**{
|
||||||
|
cls._section_name
|
||||||
|
+ cls._args_sep
|
||||||
|
+ cls.__current_command
|
||||||
|
+ cls._args_sep
|
||||||
|
+ k: cls._command_arg_type_template % cls.__current_command
|
||||||
|
for k in cls._args.keys()
|
||||||
|
if k in (PatchFire.__command_args.get(cls.__current_command) or [])
|
||||||
|
},
|
||||||
|
**{
|
||||||
|
cls._section_name + cls._args_sep + k: cls._shared_arg_type
|
||||||
|
for k in cls._args.keys()
|
||||||
|
if k not in (PatchFire.__command_args.get(cls.__current_command) or [])
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for command in cls.__commands:
|
||||||
|
if command == cls.__current_command:
|
||||||
|
continue
|
||||||
|
args[cls._section_name + cls._args_sep + command] = False
|
||||||
|
parameters_types[cls._section_name + cls._args_sep + command] = cls._command_type
|
||||||
|
unused_command_args = {
|
||||||
|
cls._section_name + cls._args_sep + command + cls._args_sep + k: None
|
||||||
|
for k in (cls.__command_args.get(command) or [])
|
||||||
|
}
|
||||||
|
unused_paramenters_types = {
|
||||||
|
cls._section_name
|
||||||
|
+ cls._args_sep
|
||||||
|
+ command
|
||||||
|
+ cls._args_sep
|
||||||
|
+ k: cls._command_arg_type_template % command
|
||||||
|
for k in (cls.__command_args.get(command) or [])
|
||||||
|
}
|
||||||
|
args = {**args, **unused_command_args}
|
||||||
|
parameters_types = {**parameters_types, **unused_paramenters_types}
|
||||||
|
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
cls._main_task._set_parameters(
|
||||||
|
args,
|
||||||
|
__update=True,
|
||||||
|
__parameters_types=parameters_types,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __Fire(original_fn, component, args_, parsed_flag_args, context, name, *args, **kwargs): # noqa
|
||||||
|
if running_remotely():
|
||||||
|
command = PatchFire._load_task_params()
|
||||||
|
if command is not None:
|
||||||
|
replaced_args = command.split(PatchFire._commands_sep)
|
||||||
|
else:
|
||||||
|
replaced_args = []
|
||||||
|
for param in PatchFire.__remote_task_params[PatchFire._section_name].values():
|
||||||
|
if command is not None and param.type == PatchFire._command_arg_type_template % command:
|
||||||
|
replaced_args.append("--" + param.name[len(command + PatchFire._args_sep) :])
|
||||||
|
value = PatchFire.__remote_task_params_dict[param.name]
|
||||||
|
if len(value) > 0:
|
||||||
|
replaced_args.append(value)
|
||||||
|
if param.type == PatchFire._shared_arg_type:
|
||||||
|
replaced_args.append("--" + param.name)
|
||||||
|
value = PatchFire.__remote_task_params_dict[param.name]
|
||||||
|
if len(value) > 0:
|
||||||
|
replaced_args.append(value)
|
||||||
|
return original_fn(component, replaced_args, parsed_flag_args, context, name, *args, **kwargs)
|
||||||
|
return original_fn(component, args_, parsed_flag_args, context, name, *args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __CallAndUpdateTrace( # noqa
|
||||||
|
original_fn, component, args_, component_trace, treatment, target, *args, **kwargs
|
||||||
|
):
|
||||||
|
if running_remotely():
|
||||||
|
return original_fn(component, args_, component_trace, treatment, target, *args, **kwargs)
|
||||||
|
if not PatchFire.__fetched_current_command:
|
||||||
|
PatchFire.__fetched_current_command = True
|
||||||
|
context, component_context = PatchFire.__get_context_and_component(component)
|
||||||
|
PatchFire.__groups, PatchFire.__commands = PatchFire.__get_all_groups_and_commands(
|
||||||
|
component_context, context
|
||||||
|
)
|
||||||
|
PatchFire.__current_command = PatchFire.__get_current_command(
|
||||||
|
args_, PatchFire.__groups, PatchFire.__commands
|
||||||
|
)
|
||||||
|
for command in PatchFire.__commands:
|
||||||
|
PatchFire.__command_args[command] = PatchFire.__get_command_args(
|
||||||
|
component_context, command.split(PatchFire._commands_sep), PatchFire.__default_args, context
|
||||||
|
)
|
||||||
|
PatchFire.__command_args[None] = PatchFire.__get_command_args(
|
||||||
|
component_context,
|
||||||
|
"",
|
||||||
|
PatchFire.__default_args,
|
||||||
|
context,
|
||||||
|
)
|
||||||
|
for k, v in PatchFire.__commands.items():
|
||||||
|
if v == component:
|
||||||
|
PatchFire.__current_command = k
|
||||||
|
break
|
||||||
|
# Comparing methods in Python is equivalent to comparing the __func__ of the methods
|
||||||
|
# and the objects they are bound to. We do not care about the object in this case,
|
||||||
|
# so we just compare the __func__
|
||||||
|
if inspect.ismethod(component) and inspect.ismethod(v) and v.__func__ == component.__func__:
|
||||||
|
PatchFire.__current_command = k
|
||||||
|
break
|
||||||
|
fn = component.__call__ if treatment == "callable" else component
|
||||||
|
metadata = fire.decorators.GetMetadata(component)
|
||||||
|
fn_spec = fire.inspectutils.GetFullArgSpec(component)
|
||||||
|
parse = fire.core._MakeParseFn(fn, metadata) # noqa
|
||||||
|
(parsed_args, parsed_kwargs), _, _, _ = parse(args_)
|
||||||
|
PatchFire._args = {**PatchFire._args, **{k: v for k, v in zip(fn_spec.args, parsed_args)}, **parsed_kwargs}
|
||||||
|
PatchFire._update_task_args()
|
||||||
|
return original_fn(component, args_, component_trace, treatment, target, *args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __get_context_and_component(component):
|
||||||
|
context = {}
|
||||||
|
component_context = component
|
||||||
|
# Walk through the stack to find the arguments with fire.Fire() has been called.
|
||||||
|
# Can't do it by patching the function because we want to patch _CallAndUpdateTrace,
|
||||||
|
# which is called by fire.Fire()
|
||||||
|
frame_infos = inspect.stack()
|
||||||
|
for frame_info_ind, frame_info in enumerate(frame_infos):
|
||||||
|
if frame_info.function == "Fire":
|
||||||
|
component_context = inspect.getargvalues(frame_info.frame).locals["component"]
|
||||||
|
if inspect.getargvalues(frame_info.frame).locals["component"] is None:
|
||||||
|
# This is similar to how fire finds this context
|
||||||
|
fire_context_frame = frame_infos[frame_info_ind + 1].frame
|
||||||
|
context.update(fire_context_frame.f_globals)
|
||||||
|
context.update(fire_context_frame.f_locals)
|
||||||
|
# Ignore modules, as they yield too many commands.
|
||||||
|
# Also ignore clearml.task.
|
||||||
|
context = {
|
||||||
|
k: v
|
||||||
|
for k, v in context.items()
|
||||||
|
if not inspect.ismodule(v) and (not inspect.isclass(v) or v.__module__ != "clearml.task")
|
||||||
|
}
|
||||||
|
break
|
||||||
|
return context, component_context
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __get_all_groups_and_commands(component, context):
|
||||||
|
groups = []
|
||||||
|
commands = {}
|
||||||
|
component_trace_result = fire.core._Fire(component, [], PatchFire.__default_args, context).GetResult() # noqa
|
||||||
|
group_args = [[]]
|
||||||
|
while len(group_args) > 0:
|
||||||
|
query_group = group_args[-1]
|
||||||
|
groups.append(PatchFire._commands_sep.join(query_group))
|
||||||
|
group_args = group_args[:-1]
|
||||||
|
current_groups, current_commands = PatchFire.__get_groups_and_commands_for_args(
|
||||||
|
component_trace_result, query_group, PatchFire.__default_args, context
|
||||||
|
)
|
||||||
|
for command in current_commands:
|
||||||
|
prefix = (
|
||||||
|
PatchFire._commands_sep.join(query_group) + PatchFire._commands_sep if len(query_group) > 0 else ""
|
||||||
|
)
|
||||||
|
commands[prefix + command[0]] = command[1]
|
||||||
|
for group in current_groups:
|
||||||
|
group_args.append(query_group + [group[0]])
|
||||||
|
return groups, commands
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __get_groups_and_commands_for_args(component, args_, parsed_flag_args, context, name=None):
|
||||||
|
component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa
|
||||||
|
groups, commands, _, _ = fire.helptext._GetActionsGroupedByKind(component_trace, verbose=False) # noqa
|
||||||
|
groups = [(name, member) for name, member in groups.GetItems()]
|
||||||
|
commands = [(name, member) for name, member in commands.GetItems()]
|
||||||
|
return groups, commands
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __get_current_command(args_, groups, commands):
|
||||||
|
current_command = ""
|
||||||
|
for arg in args_:
|
||||||
|
prefix = (current_command + PatchFire._commands_sep) if len(current_command) > 0 else ""
|
||||||
|
potential_current_command = prefix + arg
|
||||||
|
if potential_current_command not in groups:
|
||||||
|
if potential_current_command in commands:
|
||||||
|
return potential_current_command
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
current_command = potential_current_command
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __get_command_args(component, args_, parsed_flag_args, context, name=None):
|
||||||
|
component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult() # noqa
|
||||||
|
fn_spec = fire.inspectutils.GetFullArgSpec(component_trace)
|
||||||
|
return fn_spec.args
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_task_params():
|
||||||
|
if not PatchFire.__remote_task_params:
|
||||||
|
from clearml import Task
|
||||||
|
|
||||||
|
t = Task.get_task(task_id=get_remote_task_id())
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
PatchFire.__remote_task_params = t._get_task_property("hyperparams") or {}
|
||||||
|
params_dict = t.get_parameters(backwards_compatibility=False)
|
||||||
|
skip = len(PatchFire._section_name) + 1
|
||||||
|
PatchFire.__remote_task_params_dict = {
|
||||||
|
k[skip:]: v
|
||||||
|
for k, v in params_dict.items()
|
||||||
|
if k.startswith(PatchFire._section_name + PatchFire._args_sep)
|
||||||
|
}
|
||||||
|
|
||||||
|
command = [
|
||||||
|
p.name
|
||||||
|
for p in PatchFire.__remote_task_params[PatchFire._section_name].values()
|
||||||
|
if p.type == PatchFire._command_type and cast_str_to_bool(p.value, strip=True)
|
||||||
|
]
|
||||||
|
return command[0] if command else None
|
||||||
|
|
||||||
|
|
||||||
|
# patch fire before anything
|
||||||
|
PatchFire.patch()
|
@ -50,6 +50,7 @@ from .binding.joblib_bind import PatchedJoblib
|
|||||||
from .binding.matplotlib_bind import PatchedMatplotlib
|
from .binding.matplotlib_bind import PatchedMatplotlib
|
||||||
from .binding.hydra_bind import PatchHydra
|
from .binding.hydra_bind import PatchHydra
|
||||||
from .binding.click_bind import PatchClick
|
from .binding.click_bind import PatchClick
|
||||||
|
from .binding.fire_bind import PatchFire
|
||||||
from .binding.jsonargs_bind import PatchJsonArgParse
|
from .binding.jsonargs_bind import PatchJsonArgParse
|
||||||
from .config import (
|
from .config import (
|
||||||
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI,
|
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI,
|
||||||
@ -606,8 +607,9 @@ class Task(_Task):
|
|||||||
|
|
||||||
# Patch ArgParser to be aware of the current task
|
# Patch ArgParser to be aware of the current task
|
||||||
argparser_update_currenttask(Task.__main_task)
|
argparser_update_currenttask(Task.__main_task)
|
||||||
# Patch Click
|
# Patch Click and Fire
|
||||||
PatchClick.patch(Task.__main_task)
|
PatchClick.patch(Task.__main_task)
|
||||||
|
PatchFire.patch(Task.__main_task)
|
||||||
|
|
||||||
# set excluded arguments
|
# set excluded arguments
|
||||||
if isinstance(auto_connect_arg_parser, dict):
|
if isinstance(auto_connect_arg_parser, dict):
|
||||||
|
21
examples/frameworks/fire/fire_class_cmd.py
Normal file
21
examples/frameworks/fire/fire_class_cmd.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# ClearML - Example of Python Fire integration, processing commands derived from a class
|
||||||
|
#
|
||||||
|
from clearml import Task
|
||||||
|
|
||||||
|
import fire
|
||||||
|
|
||||||
|
|
||||||
|
class BrokenCalculator(object):
|
||||||
|
def __init__(self, offset=1):
|
||||||
|
self._offset = offset
|
||||||
|
|
||||||
|
def add(self, x, y):
|
||||||
|
return x + y + self._offset
|
||||||
|
|
||||||
|
def multiply(self, x, y):
|
||||||
|
return x * y + self._offset
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Task.init(project_name="examples", task_name="fire class command")
|
||||||
|
fire.Fire(BrokenCalculator)
|
23
examples/frameworks/fire/fire_dict_cmd.py
Normal file
23
examples/frameworks/fire/fire_dict_cmd.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# ClearML - Example of Python Fire integration, with commands derived from a dictionary
|
||||||
|
#
|
||||||
|
from clearml import Task
|
||||||
|
|
||||||
|
import fire
|
||||||
|
|
||||||
|
|
||||||
|
def add(x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
|
||||||
|
def multiply(x, y):
|
||||||
|
return x * y
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Task.init(project_name="examples", task_name="fire dict command")
|
||||||
|
fire.Fire(
|
||||||
|
{
|
||||||
|
"add": add,
|
||||||
|
"multiply": multiply,
|
||||||
|
}
|
||||||
|
)
|
44
examples/frameworks/fire/fire_grouping_cmd.py
Normal file
44
examples/frameworks/fire/fire_grouping_cmd.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
# ClearML - Example of Python Fire integration, with commands grouped inside classes
|
||||||
|
#
|
||||||
|
from clearml import Task
|
||||||
|
|
||||||
|
import fire
|
||||||
|
|
||||||
|
|
||||||
|
class Other(object):
|
||||||
|
def status(self):
|
||||||
|
return "Other"
|
||||||
|
|
||||||
|
|
||||||
|
class IngestionStage(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.other = Other()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
return "Ingesting! Nom nom nom..."
|
||||||
|
|
||||||
|
def hello(self, hello_str):
|
||||||
|
return hello_str
|
||||||
|
|
||||||
|
|
||||||
|
class DigestionStage(object):
|
||||||
|
def run(self, volume=1):
|
||||||
|
return " ".join(["Burp!"] * volume)
|
||||||
|
|
||||||
|
def status(self):
|
||||||
|
return "Satiated."
|
||||||
|
|
||||||
|
|
||||||
|
class Pipeline(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.ingestion = IngestionStage()
|
||||||
|
self.digestion = DigestionStage()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.ingestion.run()
|
||||||
|
self.digestion.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Task.init(project_name="examples", task_name="fire grouping command")
|
||||||
|
fire.Fire(Pipeline)
|
22
examples/frameworks/fire/fire_multi_cmd.py
Normal file
22
examples/frameworks/fire/fire_multi_cmd.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# ClearML - Example of Python Fire integration, processing multiple commands, when fire is initialized with no component
|
||||||
|
#
|
||||||
|
from clearml import Task
|
||||||
|
|
||||||
|
import fire
|
||||||
|
|
||||||
|
|
||||||
|
def hello(count, name="clearml", prefix="prefix_", suffix="_suffix", **kwargs):
|
||||||
|
for _ in range(count):
|
||||||
|
print("Hello %s%s%s!" % (prefix, name, suffix))
|
||||||
|
|
||||||
|
|
||||||
|
def serve(addr, port, should_serve=False):
|
||||||
|
if not should_serve:
|
||||||
|
print("Not serving")
|
||||||
|
else:
|
||||||
|
print("Serving on %s:%s" % (addr, port))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Task.init(project_name="examples", task_name="fire multi command")
|
||||||
|
fire.Fire()
|
19
examples/frameworks/fire/fire_object_cmd.py
Normal file
19
examples/frameworks/fire/fire_object_cmd.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# ClearML - Example of Python Fire integration, with commands derived from an object
|
||||||
|
#
|
||||||
|
from clearml import Task
|
||||||
|
|
||||||
|
import fire
|
||||||
|
|
||||||
|
|
||||||
|
class Calculator(object):
|
||||||
|
def add(self, x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
def multiply(self, x, y):
|
||||||
|
return x * y
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Task.init(project_name="examples", task_name="fire object command")
|
||||||
|
calculator = Calculator()
|
||||||
|
fire.Fire(calculator)
|
15
examples/frameworks/fire/fire_single_cmd.py
Normal file
15
examples/frameworks/fire/fire_single_cmd.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# ClearML - Example of Python Fire integration, with a single command passed to Fire
|
||||||
|
#
|
||||||
|
from clearml import Task
|
||||||
|
|
||||||
|
import fire
|
||||||
|
|
||||||
|
|
||||||
|
def hello(count, name="clearml", prefix="prefix_", suffix="_suffix", **kwargs):
|
||||||
|
for _ in range(count):
|
||||||
|
print("Hello %s%s%s!" % (prefix, name, suffix))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Task.init(project_name="examples", task_name="fire single command")
|
||||||
|
fire.Fire(hello)
|
2
examples/frameworks/fire/requirements.txt
Normal file
2
examples/frameworks/fire/requirements.txt
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
clearml
|
||||||
|
fire
|
Loading…
Reference in New Issue
Block a user