mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Add Python Fire support (#550)
This commit is contained in:
		
							parent
							
								
									9794135aed
								
							
						
					
					
						commit
						296cb7d899
					
				
							
								
								
									
										294
									
								
								clearml/binding/fire_bind.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										294
									
								
								clearml/binding/fire_bind.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,294 @@ | ||||
| try: | ||||
|     import fire | ||||
|     import fire.core | ||||
|     import fire.helptext | ||||
| except ImportError: | ||||
|     fire = None | ||||
| 
 | ||||
| import inspect | ||||
| from types import SimpleNamespace | ||||
| from .frameworks import _patched_call  # noqa | ||||
| from ..config import get_remote_task_id, running_remotely | ||||
| from ..utilities.dicts import cast_str_to_bool | ||||
| 
 | ||||
| 
 | ||||
| class PatchFire: | ||||
|     _args = {} | ||||
|     _command_type = "fire.Command" | ||||
|     _command_arg_type_template = "fire.Arg@%s" | ||||
|     _shared_arg_type = "fire.Arg.shared" | ||||
|     _section_name = "Args" | ||||
|     _args_sep = "/" | ||||
|     _commands_sep = "." | ||||
|     _main_task = None | ||||
|     __remote_task_params = None | ||||
|     __remote_task_params_dict = {} | ||||
|     __patched = False | ||||
|     __groups = [] | ||||
|     __commands = {} | ||||
|     __default_args = SimpleNamespace( | ||||
|         completion=None, help=False, interactive=False, separator="-", trace=False, verbose=False | ||||
|     ) | ||||
|     __current_command = None | ||||
|     __fetched_current_command = False | ||||
|     __command_args = {} | ||||
| 
 | ||||
|     @classmethod | ||||
|     def patch(cls, task=None): | ||||
|         if fire is None: | ||||
|             return | ||||
| 
 | ||||
|         if task: | ||||
|             cls._main_task = task | ||||
|             cls._update_task_args() | ||||
| 
 | ||||
|         if not cls.__patched: | ||||
|             cls.__patched = True | ||||
|             if running_remotely(): | ||||
|                 fire.core._Fire = _patched_call(fire.core._Fire, PatchFire.__Fire) | ||||
|             else: | ||||
|                 fire.core._CallAndUpdateTrace = _patched_call( | ||||
|                     fire.core._CallAndUpdateTrace, PatchFire.__CallAndUpdateTrace | ||||
|                 ) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _update_task_args(cls): | ||||
|         if running_remotely() or not cls._main_task: | ||||
|             return | ||||
|         args = {} | ||||
|         parameters_types = {} | ||||
|         if cls.__current_command is None: | ||||
|             args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()} | ||||
|             parameters_types = {cls._section_name + cls._args_sep + k: cls._shared_arg_type for k in cls._args.keys()} | ||||
|             for k in (PatchFire.__command_args.get(None) or []): | ||||
|                 k = cls._section_name + cls._args_sep + k | ||||
|                 if k not in args: | ||||
|                     args[k] = None | ||||
|         else: | ||||
|             args[cls._section_name + cls._args_sep + cls.__current_command] = True | ||||
|             parameters_types[cls._section_name + cls._args_sep + cls.__current_command] = cls._command_type | ||||
|             args = { | ||||
|                 **args, | ||||
|                 **{ | ||||
|                     cls._section_name + cls._args_sep + cls.__current_command + cls._args_sep + k: v | ||||
|                     for k, v in cls._args.items() | ||||
|                     if k in (PatchFire.__command_args.get(cls.__current_command) or []) | ||||
|                 }, | ||||
|                 **{ | ||||
|                     cls._section_name + cls._args_sep + k: v | ||||
|                     for k, v in cls._args.items() | ||||
|                     if k not in (PatchFire.__command_args.get(cls.__current_command) or []) | ||||
|                 }, | ||||
|             } | ||||
|             parameters_types = { | ||||
|                 **parameters_types, | ||||
|                 **{ | ||||
|                     cls._section_name | ||||
|                     + cls._args_sep | ||||
|                     + cls.__current_command | ||||
|                     + cls._args_sep | ||||
|                     + k: cls._command_arg_type_template % cls.__current_command | ||||
|                     for k in cls._args.keys() | ||||
|                     if k in (PatchFire.__command_args.get(cls.__current_command) or []) | ||||
|                 }, | ||||
|                 **{ | ||||
|                     cls._section_name + cls._args_sep + k: cls._shared_arg_type | ||||
|                     for k in cls._args.keys() | ||||
|                     if k not in (PatchFire.__command_args.get(cls.__current_command) or []) | ||||
|                 }, | ||||
|             } | ||||
|         for command in cls.__commands: | ||||
|             if command == cls.__current_command: | ||||
|                 continue | ||||
|             args[cls._section_name + cls._args_sep + command] = False | ||||
|             parameters_types[cls._section_name + cls._args_sep + command] = cls._command_type | ||||
|             unused_command_args = { | ||||
|                 cls._section_name + cls._args_sep + command + cls._args_sep + k: None | ||||
|                 for k in (cls.__command_args.get(command) or []) | ||||
|             } | ||||
|             unused_paramenters_types = { | ||||
|                 cls._section_name | ||||
|                 + cls._args_sep | ||||
|                 + command | ||||
|                 + cls._args_sep | ||||
|                 + k: cls._command_arg_type_template % command | ||||
|                 for k in (cls.__command_args.get(command) or []) | ||||
|             } | ||||
|             args = {**args, **unused_command_args} | ||||
|             parameters_types = {**parameters_types, **unused_paramenters_types} | ||||
| 
 | ||||
|         # noinspection PyProtectedMember | ||||
|         cls._main_task._set_parameters( | ||||
|             args, | ||||
|             __update=True, | ||||
|             __parameters_types=parameters_types, | ||||
|         ) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def __Fire(original_fn, component, args_, parsed_flag_args, context, name, *args, **kwargs):  # noqa | ||||
|         if running_remotely(): | ||||
|             command = PatchFire._load_task_params() | ||||
|             if command is not None: | ||||
|                 replaced_args = command.split(PatchFire._commands_sep) | ||||
|             else: | ||||
|                 replaced_args = [] | ||||
|             for param in PatchFire.__remote_task_params[PatchFire._section_name].values(): | ||||
|                 if command is not None and param.type == PatchFire._command_arg_type_template % command: | ||||
|                     replaced_args.append("--" + param.name[len(command + PatchFire._args_sep) :]) | ||||
|                     value = PatchFire.__remote_task_params_dict[param.name] | ||||
|                     if len(value) > 0: | ||||
|                         replaced_args.append(value) | ||||
|                 if param.type == PatchFire._shared_arg_type:  | ||||
|                     replaced_args.append("--" + param.name) | ||||
|                     value = PatchFire.__remote_task_params_dict[param.name] | ||||
|                     if len(value) > 0: | ||||
|                         replaced_args.append(value) | ||||
|             return original_fn(component, replaced_args, parsed_flag_args, context, name, *args, **kwargs) | ||||
|         return original_fn(component, args_, parsed_flag_args, context, name, *args, **kwargs) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def __CallAndUpdateTrace(  # noqa | ||||
|         original_fn, component, args_, component_trace, treatment, target, *args, **kwargs | ||||
|     ): | ||||
|         if running_remotely(): | ||||
|             return original_fn(component, args_, component_trace, treatment, target, *args, **kwargs) | ||||
|         if not PatchFire.__fetched_current_command: | ||||
|             PatchFire.__fetched_current_command = True | ||||
|             context, component_context = PatchFire.__get_context_and_component(component) | ||||
|             PatchFire.__groups, PatchFire.__commands = PatchFire.__get_all_groups_and_commands( | ||||
|                 component_context, context | ||||
|             ) | ||||
|             PatchFire.__current_command = PatchFire.__get_current_command( | ||||
|                 args_, PatchFire.__groups, PatchFire.__commands | ||||
|             ) | ||||
|             for command in PatchFire.__commands: | ||||
|                 PatchFire.__command_args[command] = PatchFire.__get_command_args( | ||||
|                     component_context, command.split(PatchFire._commands_sep), PatchFire.__default_args, context | ||||
|                 ) | ||||
|             PatchFire.__command_args[None] = PatchFire.__get_command_args( | ||||
|                 component_context, | ||||
|                 "", | ||||
|                 PatchFire.__default_args, | ||||
|                 context, | ||||
|             ) | ||||
|         for k, v in PatchFire.__commands.items(): | ||||
|             if v == component: | ||||
|                 PatchFire.__current_command = k | ||||
|                 break | ||||
|             # Comparing methods in Python is equivalent to comparing the __func__ of the methods | ||||
|             # and the objects they are bound to. We do not care about the object in this case, | ||||
|             # so we just compare the __func__ | ||||
|             if inspect.ismethod(component) and inspect.ismethod(v) and v.__func__ == component.__func__: | ||||
|                 PatchFire.__current_command = k | ||||
|                 break | ||||
|         fn = component.__call__ if treatment == "callable" else component | ||||
|         metadata = fire.decorators.GetMetadata(component) | ||||
|         fn_spec = fire.inspectutils.GetFullArgSpec(component) | ||||
|         parse = fire.core._MakeParseFn(fn, metadata)  # noqa | ||||
|         (parsed_args, parsed_kwargs), _, _, _ = parse(args_) | ||||
|         PatchFire._args = {**PatchFire._args, **{k: v for k, v in zip(fn_spec.args, parsed_args)}, **parsed_kwargs} | ||||
|         PatchFire._update_task_args() | ||||
|         return original_fn(component, args_, component_trace, treatment, target, *args, **kwargs) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def __get_context_and_component(component): | ||||
|         context = {} | ||||
|         component_context = component | ||||
|         # Walk through the stack to find the arguments with fire.Fire() has been called. | ||||
|         # Can't do it by patching the function because we want to patch _CallAndUpdateTrace, | ||||
|         # which is called by fire.Fire() | ||||
|         frame_infos = inspect.stack() | ||||
|         for frame_info_ind, frame_info in enumerate(frame_infos): | ||||
|             if frame_info.function == "Fire": | ||||
|                 component_context = inspect.getargvalues(frame_info.frame).locals["component"] | ||||
|                 if inspect.getargvalues(frame_info.frame).locals["component"] is None: | ||||
|                     # This is similar to how fire finds this context | ||||
|                     fire_context_frame = frame_infos[frame_info_ind + 1].frame | ||||
|                     context.update(fire_context_frame.f_globals) | ||||
|                     context.update(fire_context_frame.f_locals) | ||||
|                     # Ignore modules, as they yield too many commands. | ||||
|                     # Also ignore clearml.task. | ||||
|                     context = { | ||||
|                         k: v | ||||
|                         for k, v in context.items() | ||||
|                         if not inspect.ismodule(v) and (not inspect.isclass(v) or v.__module__ != "clearml.task") | ||||
|                     } | ||||
|                 break | ||||
|         return context, component_context | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def __get_all_groups_and_commands(component, context): | ||||
|         groups = [] | ||||
|         commands = {} | ||||
|         component_trace_result = fire.core._Fire(component, [], PatchFire.__default_args, context).GetResult()  # noqa | ||||
|         group_args = [[]] | ||||
|         while len(group_args) > 0: | ||||
|             query_group = group_args[-1] | ||||
|             groups.append(PatchFire._commands_sep.join(query_group)) | ||||
|             group_args = group_args[:-1] | ||||
|             current_groups, current_commands = PatchFire.__get_groups_and_commands_for_args( | ||||
|                 component_trace_result, query_group, PatchFire.__default_args, context | ||||
|             ) | ||||
|             for command in current_commands: | ||||
|                 prefix = ( | ||||
|                     PatchFire._commands_sep.join(query_group) + PatchFire._commands_sep if len(query_group) > 0 else "" | ||||
|                 ) | ||||
|                 commands[prefix + command[0]] = command[1] | ||||
|             for group in current_groups: | ||||
|                 group_args.append(query_group + [group[0]]) | ||||
|         return groups, commands | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def __get_groups_and_commands_for_args(component, args_, parsed_flag_args, context, name=None): | ||||
|         component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult()  # noqa | ||||
|         groups, commands, _, _ = fire.helptext._GetActionsGroupedByKind(component_trace, verbose=False)  # noqa | ||||
|         groups = [(name, member) for name, member in groups.GetItems()] | ||||
|         commands = [(name, member) for name, member in commands.GetItems()] | ||||
|         return groups, commands | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def __get_current_command(args_, groups, commands): | ||||
|         current_command = "" | ||||
|         for arg in args_: | ||||
|             prefix = (current_command + PatchFire._commands_sep) if len(current_command) > 0 else "" | ||||
|             potential_current_command = prefix + arg | ||||
|             if potential_current_command not in groups: | ||||
|                 if potential_current_command in commands: | ||||
|                     return potential_current_command | ||||
|                 else: | ||||
|                     return None | ||||
|             current_command = potential_current_command | ||||
|         return None | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def __get_command_args(component, args_, parsed_flag_args, context, name=None): | ||||
|         component_trace = fire.core._Fire(component, args_, parsed_flag_args, context, name=name).GetResult()  # noqa | ||||
|         fn_spec = fire.inspectutils.GetFullArgSpec(component_trace) | ||||
|         return fn_spec.args | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _load_task_params(): | ||||
|         if not PatchFire.__remote_task_params: | ||||
|             from clearml import Task | ||||
| 
 | ||||
|             t = Task.get_task(task_id=get_remote_task_id()) | ||||
|             # noinspection PyProtectedMember | ||||
|             PatchFire.__remote_task_params = t._get_task_property("hyperparams") or {} | ||||
|             params_dict = t.get_parameters(backwards_compatibility=False) | ||||
|             skip = len(PatchFire._section_name) + 1 | ||||
|             PatchFire.__remote_task_params_dict = { | ||||
|                 k[skip:]: v | ||||
|                 for k, v in params_dict.items() | ||||
|                 if k.startswith(PatchFire._section_name + PatchFire._args_sep) | ||||
|             } | ||||
| 
 | ||||
|         command = [ | ||||
|             p.name | ||||
|             for p in PatchFire.__remote_task_params[PatchFire._section_name].values() | ||||
|             if p.type == PatchFire._command_type and cast_str_to_bool(p.value, strip=True) | ||||
|         ] | ||||
|         return command[0] if command else None | ||||
| 
 | ||||
| 
 | ||||
| # patch fire before anything | ||||
| PatchFire.patch() | ||||
| @ -50,6 +50,7 @@ from .binding.joblib_bind import PatchedJoblib | ||||
| from .binding.matplotlib_bind import PatchedMatplotlib | ||||
| from .binding.hydra_bind import PatchHydra | ||||
| from .binding.click_bind import PatchClick | ||||
| from .binding.fire_bind import PatchFire | ||||
| from .binding.jsonargs_bind import PatchJsonArgParse | ||||
| from .config import ( | ||||
|     config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI, | ||||
| @ -606,8 +607,9 @@ class Task(_Task): | ||||
| 
 | ||||
|                 # Patch ArgParser to be aware of the current task | ||||
|                 argparser_update_currenttask(Task.__main_task) | ||||
|                 # Patch Click | ||||
|                 # Patch Click and Fire | ||||
|                 PatchClick.patch(Task.__main_task) | ||||
|                 PatchFire.patch(Task.__main_task) | ||||
| 
 | ||||
|                 # set excluded arguments | ||||
|                 if isinstance(auto_connect_arg_parser, dict): | ||||
|  | ||||
							
								
								
									
										21
									
								
								examples/frameworks/fire/fire_class_cmd.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								examples/frameworks/fire/fire_class_cmd.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,21 @@ | ||||
| # ClearML - Example of Python Fire integration, processing commands derived from a class | ||||
| # | ||||
| from clearml import Task | ||||
| 
 | ||||
| import fire | ||||
| 
 | ||||
| 
 | ||||
| class BrokenCalculator(object): | ||||
|     def __init__(self, offset=1): | ||||
|         self._offset = offset | ||||
| 
 | ||||
|     def add(self, x, y): | ||||
|         return x + y + self._offset | ||||
| 
 | ||||
|     def multiply(self, x, y): | ||||
|         return x * y + self._offset | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     Task.init(project_name="examples", task_name="fire class command") | ||||
|     fire.Fire(BrokenCalculator) | ||||
							
								
								
									
										23
									
								
								examples/frameworks/fire/fire_dict_cmd.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								examples/frameworks/fire/fire_dict_cmd.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,23 @@ | ||||
| # ClearML - Example of Python Fire integration, with commands derived from a dictionary  | ||||
| # | ||||
| from clearml import Task | ||||
| 
 | ||||
| import fire | ||||
| 
 | ||||
| 
 | ||||
| def add(x, y): | ||||
|     return x + y | ||||
| 
 | ||||
| 
 | ||||
| def multiply(x, y): | ||||
|     return x * y | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     Task.init(project_name="examples", task_name="fire dict command") | ||||
|     fire.Fire( | ||||
|         { | ||||
|             "add": add, | ||||
|             "multiply": multiply, | ||||
|         } | ||||
|     ) | ||||
							
								
								
									
										44
									
								
								examples/frameworks/fire/fire_grouping_cmd.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								examples/frameworks/fire/fire_grouping_cmd.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,44 @@ | ||||
| # ClearML - Example of Python Fire integration, with commands grouped inside classes | ||||
| # | ||||
| from clearml import Task | ||||
| 
 | ||||
| import fire | ||||
| 
 | ||||
| 
 | ||||
| class Other(object): | ||||
|     def status(self): | ||||
|         return "Other" | ||||
| 
 | ||||
| 
 | ||||
| class IngestionStage(object): | ||||
|     def __init__(self): | ||||
|         self.other = Other() | ||||
| 
 | ||||
|     def run(self): | ||||
|         return "Ingesting! Nom nom nom..." | ||||
| 
 | ||||
|     def hello(self, hello_str): | ||||
|         return hello_str | ||||
| 
 | ||||
| 
 | ||||
| class DigestionStage(object): | ||||
|     def run(self, volume=1): | ||||
|         return " ".join(["Burp!"] * volume) | ||||
| 
 | ||||
|     def status(self): | ||||
|         return "Satiated." | ||||
| 
 | ||||
| 
 | ||||
| class Pipeline(object): | ||||
|     def __init__(self): | ||||
|         self.ingestion = IngestionStage() | ||||
|         self.digestion = DigestionStage() | ||||
| 
 | ||||
|     def run(self): | ||||
|         self.ingestion.run() | ||||
|         self.digestion.run() | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     Task.init(project_name="examples", task_name="fire grouping command") | ||||
|     fire.Fire(Pipeline) | ||||
							
								
								
									
										22
									
								
								examples/frameworks/fire/fire_multi_cmd.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								examples/frameworks/fire/fire_multi_cmd.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,22 @@ | ||||
| # ClearML - Example of Python Fire integration, processing multiple commands, when fire is initialized with no component | ||||
| # | ||||
| from clearml import Task | ||||
| 
 | ||||
| import fire | ||||
| 
 | ||||
| 
 | ||||
| def hello(count, name="clearml", prefix="prefix_", suffix="_suffix", **kwargs): | ||||
|     for _ in range(count): | ||||
|         print("Hello %s%s%s!" % (prefix, name, suffix)) | ||||
| 
 | ||||
| 
 | ||||
| def serve(addr, port, should_serve=False): | ||||
|     if not should_serve: | ||||
|         print("Not serving") | ||||
|     else: | ||||
|         print("Serving on %s:%s" % (addr, port)) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     Task.init(project_name="examples", task_name="fire multi command") | ||||
|     fire.Fire() | ||||
							
								
								
									
										19
									
								
								examples/frameworks/fire/fire_object_cmd.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								examples/frameworks/fire/fire_object_cmd.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,19 @@ | ||||
| # ClearML - Example of Python Fire integration, with commands derived from an object | ||||
| # | ||||
| from clearml import Task | ||||
| 
 | ||||
| import fire | ||||
| 
 | ||||
| 
 | ||||
| class Calculator(object): | ||||
|     def add(self, x, y): | ||||
|         return x + y | ||||
| 
 | ||||
|     def multiply(self, x, y): | ||||
|         return x * y | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     Task.init(project_name="examples", task_name="fire object command") | ||||
|     calculator = Calculator() | ||||
|     fire.Fire(calculator) | ||||
							
								
								
									
										15
									
								
								examples/frameworks/fire/fire_single_cmd.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								examples/frameworks/fire/fire_single_cmd.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,15 @@ | ||||
| # ClearML - Example of Python Fire integration, with a single command passed to Fire | ||||
| # | ||||
| from clearml import Task | ||||
| 
 | ||||
| import fire | ||||
| 
 | ||||
| 
 | ||||
| def hello(count, name="clearml", prefix="prefix_", suffix="_suffix", **kwargs): | ||||
|     for _ in range(count): | ||||
|         print("Hello %s%s%s!" % (prefix, name, suffix)) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     Task.init(project_name="examples", task_name="fire single command") | ||||
|     fire.Fire(hello) | ||||
							
								
								
									
										2
									
								
								examples/frameworks/fire/requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								examples/frameworks/fire/requirements.txt
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,2 @@ | ||||
| clearml | ||||
| fire | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user