mirror of
https://github.com/clearml/clearml
synced 2025-04-16 21:42:10 +00:00
Fix decorated pipeline steps
This commit is contained in:
parent
801c7b4cd4
commit
d993fa75e4
@ -3919,10 +3919,12 @@ class PipelineDecorator(PipelineController):
|
||||
:return: function wrapper
|
||||
"""
|
||||
def decorator_wrap(func):
|
||||
_name = name or str(func.__name__)
|
||||
# noinspection PyProtectedMember
|
||||
unwrapped_func = CreateFromFunction._deep_extract_wrapped(func)
|
||||
_name = name or str(unwrapped_func.__name__)
|
||||
function_return = return_values if isinstance(return_values, (tuple, list)) else [return_values]
|
||||
|
||||
inspect_func = inspect.getfullargspec(func)
|
||||
inspect_func = inspect.getfullargspec(unwrapped_func)
|
||||
# add default argument values
|
||||
if inspect_func.args:
|
||||
default_values = list(inspect_func.defaults or [])
|
||||
|
@ -877,7 +877,7 @@ if __name__ == '__main__':
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def __extract_wrapped(decorated):
|
||||
def _extract_wrapped(decorated):
|
||||
if not decorated.__closure__:
|
||||
return None
|
||||
closure = (c.cell_contents for c in decorated.__closure__)
|
||||
@ -885,6 +885,15 @@ if __name__ == '__main__':
|
||||
return None
|
||||
return next((c for c in closure if inspect.isfunction(c)), None)
|
||||
|
||||
@staticmethod
|
||||
def _deep_extract_wrapped(decorated):
|
||||
while True:
|
||||
# noinspection PyProtectedMember
|
||||
func = CreateFromFunction._extract_wrapped(decorated)
|
||||
if not func:
|
||||
return decorated
|
||||
decorated = func
|
||||
|
||||
@staticmethod
|
||||
def __sanitize(func_source, sanitize_function=None):
|
||||
if sanitize_function:
|
||||
@ -918,34 +927,46 @@ if __name__ == '__main__':
|
||||
name = getattr(original_module, "__name__", original_module)
|
||||
getLogger().warning('Could not fetch functions from {}: {}'.format(name, e))
|
||||
func_members_dict = {}
|
||||
decorated_func = CreateFromFunction.__extract_wrapped(func)
|
||||
if not decorated_func:
|
||||
return CreateFromFunction.__sanitize(inspect.getsource(func), sanitize_function=sanitize_function)
|
||||
decorated_func = CreateFromFunction._deep_extract_wrapped(func)
|
||||
decorated_func_source = CreateFromFunction.__sanitize(
|
||||
inspect.getsource(decorated_func),
|
||||
sanitize_function=sanitize_function
|
||||
)
|
||||
for decorator in decorated_func_source.split("\n"):
|
||||
if decorator.startswith("@"):
|
||||
decorator_func_name = decorator[1:decorator.find("(") if "(" in decorator else len(decorator)]
|
||||
decorator_func = func_members_dict.get(decorator_func_name)
|
||||
if decorator_func_name not in func_members or not decorator_func:
|
||||
try:
|
||||
import ast
|
||||
|
||||
parsed_decorated = ast.parse(decorated_func_source)
|
||||
for body_elem in parsed_decorated.body:
|
||||
if not isinstance(body_elem, ast.FunctionDef):
|
||||
continue
|
||||
decorated_func_source = CreateFromFunction.__get_source_with_decorators(
|
||||
decorator_func,
|
||||
original_module=original_module,
|
||||
sanitize_function=sanitize_function
|
||||
) + "\n\n" + decorated_func_source
|
||||
for decorator in body_elem.decorator_list:
|
||||
name = None
|
||||
if isinstance(decorator, ast.Name):
|
||||
name = decorator.id
|
||||
elif isinstance(decorator, ast.Call):
|
||||
name = decorator.func.id
|
||||
if not name:
|
||||
continue
|
||||
decorator_func = func_members_dict.get(name)
|
||||
if name not in func_members or not decorator_func:
|
||||
continue
|
||||
decorated_func_source = CreateFromFunction.__get_source_with_decorators(
|
||||
decorator_func,
|
||||
original_module=original_module,
|
||||
sanitize_function=sanitize_function
|
||||
) + "\n\n" + decorated_func_source
|
||||
break
|
||||
except Exception as e:
|
||||
getLogger().warning('Could not fetch full definition of function {}: {}'.format(func.__name__, e))
|
||||
return decorated_func_source
|
||||
|
||||
@staticmethod
|
||||
def __extract_function_information(function, sanitize_function=None, skip_global_imports=False):
|
||||
# type: (Callable, Optional[Callable], bool) -> (str, str)
|
||||
decorated = CreateFromFunction.__extract_wrapped(function)
|
||||
function_name = str(decorated.__name__) if decorated else str(function.__name__)
|
||||
function = CreateFromFunction._deep_extract_wrapped(function)
|
||||
function_source = CreateFromFunction.__get_source_with_decorators(function, sanitize_function=sanitize_function)
|
||||
if not skip_global_imports:
|
||||
imports = CreateFromFunction.__extract_imports(decorated if decorated else function)
|
||||
imports = CreateFromFunction.__extract_imports(function)
|
||||
else:
|
||||
imports = ""
|
||||
return imports + "\n" + function_source, function_name
|
||||
return imports + "\n" + function_source, function.__name__
|
||||
|
Loading…
Reference in New Issue
Block a user