From d993fa75e4b4156eb7891ed33c31d2f0deccba0f Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 6 Jan 2024 12:35:06 +0200 Subject: [PATCH] Fix decorated pipeline steps --- clearml/automation/controller.py | 6 ++- clearml/backend_interface/task/populate.py | 57 +++++++++++++++------- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/clearml/automation/controller.py b/clearml/automation/controller.py index f35cf0d7..8ed5aa17 100644 --- a/clearml/automation/controller.py +++ b/clearml/automation/controller.py @@ -3919,10 +3919,12 @@ class PipelineDecorator(PipelineController): :return: function wrapper """ def decorator_wrap(func): - _name = name or str(func.__name__) + # noinspection PyProtectedMember + unwrapped_func = CreateFromFunction._deep_extract_wrapped(func) + _name = name or str(unwrapped_func.__name__) function_return = return_values if isinstance(return_values, (tuple, list)) else [return_values] - inspect_func = inspect.getfullargspec(func) + inspect_func = inspect.getfullargspec(unwrapped_func) # add default argument values if inspect_func.args: default_values = list(inspect_func.defaults or []) diff --git a/clearml/backend_interface/task/populate.py b/clearml/backend_interface/task/populate.py index 9b747efb..ded68842 100644 --- a/clearml/backend_interface/task/populate.py +++ b/clearml/backend_interface/task/populate.py @@ -877,7 +877,7 @@ if __name__ == '__main__': return "" @staticmethod - def __extract_wrapped(decorated): + def _extract_wrapped(decorated): if not decorated.__closure__: return None closure = (c.cell_contents for c in decorated.__closure__) @@ -885,6 +885,15 @@ if __name__ == '__main__': return None return next((c for c in closure if inspect.isfunction(c)), None) + @staticmethod + def _deep_extract_wrapped(decorated): + while True: + # noinspection PyProtectedMember + func = CreateFromFunction._extract_wrapped(decorated) + if not func: + return decorated + decorated = func + @staticmethod def __sanitize(func_source, sanitize_function=None): if sanitize_function: @@ -918,34 +927,46 @@ if __name__ == '__main__': name = getattr(original_module, "__name__", original_module) getLogger().warning('Could not fetch functions from {}: {}'.format(name, e)) func_members_dict = {} - decorated_func = CreateFromFunction.__extract_wrapped(func) - if not decorated_func: - return CreateFromFunction.__sanitize(inspect.getsource(func), sanitize_function=sanitize_function) + decorated_func = CreateFromFunction._deep_extract_wrapped(func) decorated_func_source = CreateFromFunction.__sanitize( inspect.getsource(decorated_func), sanitize_function=sanitize_function ) - for decorator in decorated_func_source.split("\n"): - if decorator.startswith("@"): - decorator_func_name = decorator[1:decorator.find("(") if "(" in decorator else len(decorator)] - decorator_func = func_members_dict.get(decorator_func_name) - if decorator_func_name not in func_members or not decorator_func: + try: + import ast + + parsed_decorated = ast.parse(decorated_func_source) + for body_elem in parsed_decorated.body: + if not isinstance(body_elem, ast.FunctionDef): continue - decorated_func_source = CreateFromFunction.__get_source_with_decorators( - decorator_func, - original_module=original_module, - sanitize_function=sanitize_function - ) + "\n\n" + decorated_func_source + for decorator in body_elem.decorator_list: + name = None + if isinstance(decorator, ast.Name): + name = decorator.id + elif isinstance(decorator, ast.Call): + name = decorator.func.id + if not name: + continue + decorator_func = func_members_dict.get(name) + if name not in func_members or not decorator_func: + continue + decorated_func_source = CreateFromFunction.__get_source_with_decorators( + decorator_func, + original_module=original_module, + sanitize_function=sanitize_function + ) + "\n\n" + decorated_func_source + break + except Exception as e: + getLogger().warning('Could not fetch full definition of function {}: {}'.format(func.__name__, e)) return decorated_func_source @staticmethod def __extract_function_information(function, sanitize_function=None, skip_global_imports=False): # type: (Callable, Optional[Callable], bool) -> (str, str) - decorated = CreateFromFunction.__extract_wrapped(function) - function_name = str(decorated.__name__) if decorated else str(function.__name__) + function = CreateFromFunction._deep_extract_wrapped(function) function_source = CreateFromFunction.__get_source_with_decorators(function, sanitize_function=sanitize_function) if not skip_global_imports: - imports = CreateFromFunction.__extract_imports(decorated if decorated else function) + imports = CreateFromFunction.__extract_imports(function) else: imports = "" - return imports + "\n" + function_source, function_name + return imports + "\n" + function_source, function.__name__