From 4bd6b20a62d5e97b89d4807211d55a38b623d944 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 3 Jan 2024 22:18:39 +0200 Subject: [PATCH] Add support for decorated pipeline steps (#1154) --- clearml/automation/controller.py | 34 +++-- clearml/backend_interface/task/populate.py | 121 ++++++++++++++++-- .../decorated_pipeline_step_decorators.py | 24 ++++ .../decorated_pipeline_step_functions.py | 27 ++++ 4 files changed, 185 insertions(+), 21 deletions(-) create mode 100644 examples/pipeline/decorated_pipeline_step_decorators.py create mode 100644 examples/pipeline/decorated_pipeline_step_functions.py diff --git a/clearml/automation/controller.py b/clearml/automation/controller.py index cd6368ef..f35cf0d7 100644 --- a/clearml/automation/controller.py +++ b/clearml/automation/controller.py @@ -177,7 +177,8 @@ class PipelineController(object): always_create_from_code=True, # type: bool artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]] artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]] - output_uri=None # type: Optional[Union[str, bool]] + output_uri=None, # type: Optional[Union[str, bool]] + skip_global_imports=False # type: bool ): # type: (...) -> None """ @@ -267,6 +268,9 @@ class PipelineController(object): :param output_uri: The storage / output url for this pipeline. This is the default location for output models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter). The `output_uri` of this pipeline's steps will default to this value. + :param skip_global_imports: If True, global imports will not be included in the steps' execution when creating + the steps from a functions, otherwise all global imports will be automatically imported in a safe manner at + the beginning of each step’s execution. Default is False """ if auto_version_bump is not None: warnings.warn("PipelineController.auto_version_bump is deprecated. It will be ignored", DeprecationWarning) @@ -303,6 +307,7 @@ class PipelineController(object): self._last_progress_update_time = 0 self._artifact_serialization_function = artifact_serialization_function self._artifact_deserialization_function = artifact_deserialization_function + self._skip_global_imports = skip_global_imports if not self._task: task_name = name or project or '{}'.format(datetime.now()) if self._pipeline_as_sub_project: @@ -1521,7 +1526,8 @@ class PipelineController(object): dry_run=True, task_template_header=self._task_template_header, artifact_serialization_function=self._artifact_serialization_function, - artifact_deserialization_function=self._artifact_deserialization_function + artifact_deserialization_function=self._artifact_deserialization_function, + skip_global_imports=self._skip_global_imports ) return task_definition @@ -3316,7 +3322,8 @@ class PipelineDecorator(PipelineController): repo_commit=None, # type: Optional[str] artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]] artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]] - output_uri=None # type: Optional[Union[str, bool]] + output_uri=None, # type: Optional[Union[str, bool]] + skip_global_imports=False # type: bool ): # type: (...) -> () """ @@ -3399,6 +3406,9 @@ class PipelineDecorator(PipelineController): :param output_uri: The storage / output url for this pipeline. This is the default location for output models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter). The `output_uri` of this pipeline's steps will default to this value. + :param skip_global_imports: If True, global imports will not be included in the steps' execution, otherwise all + global imports will be automatically imported in a safe manner at the beginning of each step’s execution. + Default is False """ super(PipelineDecorator, self).__init__( name=name, @@ -3420,7 +3430,8 @@ class PipelineDecorator(PipelineController): always_create_from_code=False, artifact_serialization_function=artifact_serialization_function, artifact_deserialization_function=artifact_deserialization_function, - output_uri=output_uri + output_uri=output_uri, + skip_global_imports=skip_global_imports ) # if we are in eager execution, make sure parent class knows it @@ -3686,7 +3697,8 @@ class PipelineDecorator(PipelineController): task_template_header=self._task_template_header, _sanitize_function=sanitize, artifact_serialization_function=self._artifact_serialization_function, - artifact_deserialization_function=self._artifact_deserialization_function + artifact_deserialization_function=self._artifact_deserialization_function, + skip_global_imports=self._skip_global_imports ) return task_definition @@ -4173,7 +4185,8 @@ class PipelineDecorator(PipelineController): repo_commit=None, # type: Optional[str] artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]] artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]] - output_uri=None # type: Optional[Union[str, bool]] + output_uri=None, # type: Optional[Union[str, bool]] + skip_global_imports=False # type: bool ): # type: (...) -> Callable """ @@ -4287,6 +4300,9 @@ class PipelineDecorator(PipelineController): :param output_uri: The storage / output url for this pipeline. This is the default location for output models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter). The `output_uri` of this pipeline's steps will default to this value. + :param skip_global_imports: If True, global imports will not be included in the steps' execution, otherwise all + global imports will be automatically imported in a safe manner at the beginning of each step’s execution. + Default is False """ def decorator_wrap(func): @@ -4333,7 +4349,8 @@ class PipelineDecorator(PipelineController): repo_commit=repo_commit, artifact_serialization_function=artifact_serialization_function, artifact_deserialization_function=artifact_deserialization_function, - output_uri=output_uri + output_uri=output_uri, + skip_global_imports=skip_global_imports ) ret_val = func(**pipeline_kwargs) LazyEvalWrapper.trigger_all_remote_references() @@ -4385,7 +4402,8 @@ class PipelineDecorator(PipelineController): repo_commit=repo_commit, artifact_serialization_function=artifact_serialization_function, artifact_deserialization_function=artifact_deserialization_function, - output_uri=output_uri + output_uri=output_uri, + skip_global_imports=skip_global_imports ) a_pipeline._args_map = args_map or {} diff --git a/clearml/backend_interface/task/populate.py b/clearml/backend_interface/task/populate.py index 0f1fe53c..9b747efb 100644 --- a/clearml/backend_interface/task/populate.py +++ b/clearml/backend_interface/task/populate.py @@ -24,7 +24,7 @@ class CreateAndPopulate(object): ":" \ "(?P{regular}.*)?" \ "$" \ - .format( + .format( regular=r"[^/@:#]" ) @@ -199,9 +199,9 @@ class CreateAndPopulate(object): # if there is nothing to populate, return if not any([ - self.folder, self.commit, self.branch, self.repo, self.script, self.cwd, - self.packages, self.requirements_file, self.base_task_id] + (list(self.docker.values())) - ): + self.folder, self.commit, self.branch, self.repo, self.script, self.cwd, + self.packages, self.requirements_file, self.base_task_id] + (list(self.docker.values())) + ): return task # clear the script section @@ -219,7 +219,7 @@ class CreateAndPopulate(object): if self.cwd: self.cwd = self.cwd cwd = self.cwd if Path(self.cwd).is_dir() else ( - Path(repo_info.script['repo_root']) / self.cwd).as_posix() + Path(repo_info.script['repo_root']) / self.cwd).as_posix() if not Path(cwd).is_dir(): raise ValueError("Working directory \'{}\' could not be found".format(cwd)) cwd = Path(cwd).relative_to(repo_info.script['repo_root']).as_posix() @@ -577,6 +577,7 @@ if __name__ == '__main__': artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]] _sanitize_function=None, # type: Optional[Callable[[str], str]] _sanitize_helper_functions=None, # type: Optional[Callable[[str], str]] + skip_global_imports=False # type: bool ): # type: (...) -> Optional[Dict, Task] """ @@ -659,6 +660,9 @@ if __name__ == '__main__': return dill.loads(bytes_) :param _sanitize_function: Sanitization function for the function string. :param _sanitize_helper_functions: Sanitization function for the helper function string. + :param skip_global_imports: If True, the global imports will not be fetched from the function's file, otherwise + all global imports will be automatically imported in a safe manner at the beginning of the function's + execution. Default is False :return: Newly created Task object """ # not set -> equals True @@ -671,7 +675,7 @@ if __name__ == '__main__': assert (not auto_connect_arg_parser or isinstance(auto_connect_arg_parser, (bool, dict))) function_source, function_name = CreateFromFunction.__extract_function_information( - a_function, sanitize_function=_sanitize_function + a_function, sanitize_function=_sanitize_function, skip_global_imports=skip_global_imports ) # add helper functions on top. for f in (helper_functions or []): @@ -846,11 +850,102 @@ if __name__ == '__main__': return function_source @staticmethod - def __extract_function_information(function, sanitize_function=None): - # type: (Callable, Optional[Callable]) -> (str, str) - function_name = str(function.__name__) - function_source = inspect.getsource(function) + def __extract_imports(func): + def add_import_guard(import_): + return ("try:\n " + + import_.replace("\n", "\n ", import_.count("\n") - 1) + + "except Exception as e:\n print('Import error: ' + str(e))\n" + ) + + # noinspection PyBroadException + try: + import ast + func_module = inspect.getmodule(func) + source = inspect.getsource(func_module) + source_lines = inspect.getsourcelines(func_module)[0] + parsed_source = ast.parse(source) + imports = [] + for parsed_source_entry in parsed_source.body: + if isinstance(parsed_source_entry, + (ast.Import, ast.ImportFrom)) and parsed_source_entry.col_offset == 0: + imports.append( + "\n".join(source_lines[parsed_source_entry.lineno - 1: parsed_source_entry.end_lineno])) + imports = [add_import_guard(import_) for import_ in imports] + return "\n".join(imports) + except Exception as e: + getLogger().warning('Could not fetch function imports: {}'.format(e)) + return "" + + @staticmethod + def __extract_wrapped(decorated): + if not decorated.__closure__: + return None + closure = (c.cell_contents for c in decorated.__closure__) + if not closure: + return None + return next((c for c in closure if inspect.isfunction(c)), None) + + @staticmethod + def __sanitize(func_source, sanitize_function=None): if sanitize_function: - function_source = sanitize_function(function_source) - function_source = CreateFromFunction.__sanitize_remove_type_hints(function_source) - return function_source, function_name + func_source = sanitize_function(func_source) + return CreateFromFunction.__sanitize_remove_type_hints(func_source) + + @staticmethod + def __get_func_members(module): + result = [] + try: + import ast + + source = inspect.getsource(module) + parsed = ast.parse(source) + for f in parsed.body: + if isinstance(f, ast.FunctionDef): + result.append(f.name) + except Exception as e: + name = getattr(module, "__name__", module) + getLogger().warning('Could not fetch function declared in {}: {}'.format(name, e)) + return result + + @staticmethod + def __get_source_with_decorators(func, original_module=None, sanitize_function=None): + if original_module is None: + original_module = inspect.getmodule(func) + func_members = CreateFromFunction.__get_func_members(original_module) + try: + func_members_dict = dict(inspect.getmembers(original_module, inspect.isfunction)) + except Exception as e: + name = getattr(original_module, "__name__", original_module) + getLogger().warning('Could not fetch functions from {}: {}'.format(name, e)) + func_members_dict = {} + decorated_func = CreateFromFunction.__extract_wrapped(func) + if not decorated_func: + return CreateFromFunction.__sanitize(inspect.getsource(func), sanitize_function=sanitize_function) + decorated_func_source = CreateFromFunction.__sanitize( + inspect.getsource(decorated_func), + sanitize_function=sanitize_function + ) + for decorator in decorated_func_source.split("\n"): + if decorator.startswith("@"): + decorator_func_name = decorator[1:decorator.find("(") if "(" in decorator else len(decorator)] + decorator_func = func_members_dict.get(decorator_func_name) + if decorator_func_name not in func_members or not decorator_func: + continue + decorated_func_source = CreateFromFunction.__get_source_with_decorators( + decorator_func, + original_module=original_module, + sanitize_function=sanitize_function + ) + "\n\n" + decorated_func_source + return decorated_func_source + + @staticmethod + def __extract_function_information(function, sanitize_function=None, skip_global_imports=False): + # type: (Callable, Optional[Callable], bool) -> (str, str) + decorated = CreateFromFunction.__extract_wrapped(function) + function_name = str(decorated.__name__) if decorated else str(function.__name__) + function_source = CreateFromFunction.__get_source_with_decorators(function, sanitize_function=sanitize_function) + if not skip_global_imports: + imports = CreateFromFunction.__extract_imports(decorated if decorated else function) + else: + imports = "" + return imports + "\n" + function_source, function_name diff --git a/examples/pipeline/decorated_pipeline_step_decorators.py b/examples/pipeline/decorated_pipeline_step_decorators.py new file mode 100644 index 00000000..dfff0260 --- /dev/null +++ b/examples/pipeline/decorated_pipeline_step_decorators.py @@ -0,0 +1,24 @@ +from clearml import PipelineDecorator + + +def our_decorator(func): + def function_wrapper(*args, **kwargs): + return func(*args, **kwargs) + 1 + return function_wrapper + + +@PipelineDecorator.component() +@our_decorator +def step(): + return 1 + + +@PipelineDecorator.pipeline(name="test_decorated", project="test_decorated") +def pipeline(): + result = step() + assert result == 2 + + +if __name__ == "__main__": + PipelineDecorator.run_locally() + pipeline() diff --git a/examples/pipeline/decorated_pipeline_step_functions.py b/examples/pipeline/decorated_pipeline_step_functions.py new file mode 100644 index 00000000..163080eb --- /dev/null +++ b/examples/pipeline/decorated_pipeline_step_functions.py @@ -0,0 +1,27 @@ +from clearml import PipelineController + + +def our_decorator(func): + def function_wrapper(*args, **kwargs): + return func(*args, **kwargs) + 1 + return function_wrapper + + +@our_decorator +def step(): + return 1 + + +def evaluate(step_return): + assert step_return == 2 + + +if __name__ == "__main__": + pipeline = PipelineController(name="test_decorated", project="test_decorated") + pipeline.add_function_step(name="step", function=step, function_return=["step_return"]) + pipeline.add_function_step( + name="evaluate", + function=evaluate, + function_kwargs=dict(step_return='${step.step_return}') + ) + pipeline.start_locally(run_pipeline_steps_locally=True)