mirror of
https://github.com/clearml/clearml
synced 2025-04-20 22:36:58 +00:00
Fix decorated pipeline steps
This commit is contained in:
parent
801c7b4cd4
commit
d993fa75e4
@ -3919,10 +3919,12 @@ class PipelineDecorator(PipelineController):
|
|||||||
:return: function wrapper
|
:return: function wrapper
|
||||||
"""
|
"""
|
||||||
def decorator_wrap(func):
|
def decorator_wrap(func):
|
||||||
_name = name or str(func.__name__)
|
# noinspection PyProtectedMember
|
||||||
|
unwrapped_func = CreateFromFunction._deep_extract_wrapped(func)
|
||||||
|
_name = name or str(unwrapped_func.__name__)
|
||||||
function_return = return_values if isinstance(return_values, (tuple, list)) else [return_values]
|
function_return = return_values if isinstance(return_values, (tuple, list)) else [return_values]
|
||||||
|
|
||||||
inspect_func = inspect.getfullargspec(func)
|
inspect_func = inspect.getfullargspec(unwrapped_func)
|
||||||
# add default argument values
|
# add default argument values
|
||||||
if inspect_func.args:
|
if inspect_func.args:
|
||||||
default_values = list(inspect_func.defaults or [])
|
default_values = list(inspect_func.defaults or [])
|
||||||
|
@ -877,7 +877,7 @@ if __name__ == '__main__':
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __extract_wrapped(decorated):
|
def _extract_wrapped(decorated):
|
||||||
if not decorated.__closure__:
|
if not decorated.__closure__:
|
||||||
return None
|
return None
|
||||||
closure = (c.cell_contents for c in decorated.__closure__)
|
closure = (c.cell_contents for c in decorated.__closure__)
|
||||||
@ -885,6 +885,15 @@ if __name__ == '__main__':
|
|||||||
return None
|
return None
|
||||||
return next((c for c in closure if inspect.isfunction(c)), None)
|
return next((c for c in closure if inspect.isfunction(c)), None)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _deep_extract_wrapped(decorated):
|
||||||
|
while True:
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
func = CreateFromFunction._extract_wrapped(decorated)
|
||||||
|
if not func:
|
||||||
|
return decorated
|
||||||
|
decorated = func
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __sanitize(func_source, sanitize_function=None):
|
def __sanitize(func_source, sanitize_function=None):
|
||||||
if sanitize_function:
|
if sanitize_function:
|
||||||
@ -918,34 +927,46 @@ if __name__ == '__main__':
|
|||||||
name = getattr(original_module, "__name__", original_module)
|
name = getattr(original_module, "__name__", original_module)
|
||||||
getLogger().warning('Could not fetch functions from {}: {}'.format(name, e))
|
getLogger().warning('Could not fetch functions from {}: {}'.format(name, e))
|
||||||
func_members_dict = {}
|
func_members_dict = {}
|
||||||
decorated_func = CreateFromFunction.__extract_wrapped(func)
|
decorated_func = CreateFromFunction._deep_extract_wrapped(func)
|
||||||
if not decorated_func:
|
|
||||||
return CreateFromFunction.__sanitize(inspect.getsource(func), sanitize_function=sanitize_function)
|
|
||||||
decorated_func_source = CreateFromFunction.__sanitize(
|
decorated_func_source = CreateFromFunction.__sanitize(
|
||||||
inspect.getsource(decorated_func),
|
inspect.getsource(decorated_func),
|
||||||
sanitize_function=sanitize_function
|
sanitize_function=sanitize_function
|
||||||
)
|
)
|
||||||
for decorator in decorated_func_source.split("\n"):
|
try:
|
||||||
if decorator.startswith("@"):
|
import ast
|
||||||
decorator_func_name = decorator[1:decorator.find("(") if "(" in decorator else len(decorator)]
|
|
||||||
decorator_func = func_members_dict.get(decorator_func_name)
|
parsed_decorated = ast.parse(decorated_func_source)
|
||||||
if decorator_func_name not in func_members or not decorator_func:
|
for body_elem in parsed_decorated.body:
|
||||||
|
if not isinstance(body_elem, ast.FunctionDef):
|
||||||
continue
|
continue
|
||||||
decorated_func_source = CreateFromFunction.__get_source_with_decorators(
|
for decorator in body_elem.decorator_list:
|
||||||
decorator_func,
|
name = None
|
||||||
original_module=original_module,
|
if isinstance(decorator, ast.Name):
|
||||||
sanitize_function=sanitize_function
|
name = decorator.id
|
||||||
) + "\n\n" + decorated_func_source
|
elif isinstance(decorator, ast.Call):
|
||||||
|
name = decorator.func.id
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
decorator_func = func_members_dict.get(name)
|
||||||
|
if name not in func_members or not decorator_func:
|
||||||
|
continue
|
||||||
|
decorated_func_source = CreateFromFunction.__get_source_with_decorators(
|
||||||
|
decorator_func,
|
||||||
|
original_module=original_module,
|
||||||
|
sanitize_function=sanitize_function
|
||||||
|
) + "\n\n" + decorated_func_source
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
getLogger().warning('Could not fetch full definition of function {}: {}'.format(func.__name__, e))
|
||||||
return decorated_func_source
|
return decorated_func_source
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __extract_function_information(function, sanitize_function=None, skip_global_imports=False):
|
def __extract_function_information(function, sanitize_function=None, skip_global_imports=False):
|
||||||
# type: (Callable, Optional[Callable], bool) -> (str, str)
|
# type: (Callable, Optional[Callable], bool) -> (str, str)
|
||||||
decorated = CreateFromFunction.__extract_wrapped(function)
|
function = CreateFromFunction._deep_extract_wrapped(function)
|
||||||
function_name = str(decorated.__name__) if decorated else str(function.__name__)
|
|
||||||
function_source = CreateFromFunction.__get_source_with_decorators(function, sanitize_function=sanitize_function)
|
function_source = CreateFromFunction.__get_source_with_decorators(function, sanitize_function=sanitize_function)
|
||||||
if not skip_global_imports:
|
if not skip_global_imports:
|
||||||
imports = CreateFromFunction.__extract_imports(decorated if decorated else function)
|
imports = CreateFromFunction.__extract_imports(function)
|
||||||
else:
|
else:
|
||||||
imports = ""
|
imports = ""
|
||||||
return imports + "\n" + function_source, function_name
|
return imports + "\n" + function_source, function.__name__
|
||||||
|
Loading…
Reference in New Issue
Block a user