Fix decorated pipeline steps

This commit is contained in:
allegroai 2024-01-06 12:35:06 +02:00
parent 801c7b4cd4
commit d993fa75e4
2 changed files with 43 additions and 20 deletions

View File

@ -3919,10 +3919,12 @@ class PipelineDecorator(PipelineController):
:return: function wrapper
"""
def decorator_wrap(func):
_name = name or str(func.__name__)
# noinspection PyProtectedMember
unwrapped_func = CreateFromFunction._deep_extract_wrapped(func)
_name = name or str(unwrapped_func.__name__)
function_return = return_values if isinstance(return_values, (tuple, list)) else [return_values]
inspect_func = inspect.getfullargspec(func)
inspect_func = inspect.getfullargspec(unwrapped_func)
# add default argument values
if inspect_func.args:
default_values = list(inspect_func.defaults or [])

View File

@ -877,7 +877,7 @@ if __name__ == '__main__':
return ""
@staticmethod
def __extract_wrapped(decorated):
def _extract_wrapped(decorated):
if not decorated.__closure__:
return None
closure = (c.cell_contents for c in decorated.__closure__)
@ -885,6 +885,15 @@ if __name__ == '__main__':
return None
return next((c for c in closure if inspect.isfunction(c)), None)
@staticmethod
def _deep_extract_wrapped(decorated):
while True:
# noinspection PyProtectedMember
func = CreateFromFunction._extract_wrapped(decorated)
if not func:
return decorated
decorated = func
@staticmethod
def __sanitize(func_source, sanitize_function=None):
if sanitize_function:
@ -918,34 +927,46 @@ if __name__ == '__main__':
name = getattr(original_module, "__name__", original_module)
getLogger().warning('Could not fetch functions from {}: {}'.format(name, e))
func_members_dict = {}
decorated_func = CreateFromFunction.__extract_wrapped(func)
if not decorated_func:
return CreateFromFunction.__sanitize(inspect.getsource(func), sanitize_function=sanitize_function)
decorated_func = CreateFromFunction._deep_extract_wrapped(func)
decorated_func_source = CreateFromFunction.__sanitize(
inspect.getsource(decorated_func),
sanitize_function=sanitize_function
)
for decorator in decorated_func_source.split("\n"):
if decorator.startswith("@"):
decorator_func_name = decorator[1:decorator.find("(") if "(" in decorator else len(decorator)]
decorator_func = func_members_dict.get(decorator_func_name)
if decorator_func_name not in func_members or not decorator_func:
try:
import ast
parsed_decorated = ast.parse(decorated_func_source)
for body_elem in parsed_decorated.body:
if not isinstance(body_elem, ast.FunctionDef):
continue
decorated_func_source = CreateFromFunction.__get_source_with_decorators(
decorator_func,
original_module=original_module,
sanitize_function=sanitize_function
) + "\n\n" + decorated_func_source
for decorator in body_elem.decorator_list:
name = None
if isinstance(decorator, ast.Name):
name = decorator.id
elif isinstance(decorator, ast.Call):
name = decorator.func.id
if not name:
continue
decorator_func = func_members_dict.get(name)
if name not in func_members or not decorator_func:
continue
decorated_func_source = CreateFromFunction.__get_source_with_decorators(
decorator_func,
original_module=original_module,
sanitize_function=sanitize_function
) + "\n\n" + decorated_func_source
break
except Exception as e:
getLogger().warning('Could not fetch full definition of function {}: {}'.format(func.__name__, e))
return decorated_func_source
@staticmethod
def __extract_function_information(function, sanitize_function=None, skip_global_imports=False):
# type: (Callable, Optional[Callable], bool) -> (str, str)
decorated = CreateFromFunction.__extract_wrapped(function)
function_name = str(decorated.__name__) if decorated else str(function.__name__)
function = CreateFromFunction._deep_extract_wrapped(function)
function_source = CreateFromFunction.__get_source_with_decorators(function, sanitize_function=sanitize_function)
if not skip_global_imports:
imports = CreateFromFunction.__extract_imports(decorated if decorated else function)
imports = CreateFromFunction.__extract_imports(function)
else:
imports = ""
return imports + "\n" + function_source, function_name
return imports + "\n" + function_source, function.__name__