Add support for decorated pipeline steps (#1154)

This commit is contained in:
allegroai 2024-01-03 22:18:39 +02:00
parent cf39029cb2
commit 4bd6b20a62
4 changed files with 185 additions and 21 deletions

View File

@ -177,7 +177,8 @@ class PipelineController(object):
always_create_from_code=True, # type: bool
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
output_uri=None # type: Optional[Union[str, bool]]
output_uri=None, # type: Optional[Union[str, bool]]
skip_global_imports=False # type: bool
):
# type: (...) -> None
"""
@ -267,6 +268,9 @@ class PipelineController(object):
:param output_uri: The storage / output url for this pipeline. This is the default location for output
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
The `output_uri` of this pipeline's steps will default to this value.
:param skip_global_imports: If True, global imports will not be included in the steps' execution when creating
the steps from a functions, otherwise all global imports will be automatically imported in a safe manner at
the beginning of each steps execution. Default is False
"""
if auto_version_bump is not None:
warnings.warn("PipelineController.auto_version_bump is deprecated. It will be ignored", DeprecationWarning)
@ -303,6 +307,7 @@ class PipelineController(object):
self._last_progress_update_time = 0
self._artifact_serialization_function = artifact_serialization_function
self._artifact_deserialization_function = artifact_deserialization_function
self._skip_global_imports = skip_global_imports
if not self._task:
task_name = name or project or '{}'.format(datetime.now())
if self._pipeline_as_sub_project:
@ -1521,7 +1526,8 @@ class PipelineController(object):
dry_run=True,
task_template_header=self._task_template_header,
artifact_serialization_function=self._artifact_serialization_function,
artifact_deserialization_function=self._artifact_deserialization_function
artifact_deserialization_function=self._artifact_deserialization_function,
skip_global_imports=self._skip_global_imports
)
return task_definition
@ -3316,7 +3322,8 @@ class PipelineDecorator(PipelineController):
repo_commit=None, # type: Optional[str]
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
output_uri=None # type: Optional[Union[str, bool]]
output_uri=None, # type: Optional[Union[str, bool]]
skip_global_imports=False # type: bool
):
# type: (...) -> ()
"""
@ -3399,6 +3406,9 @@ class PipelineDecorator(PipelineController):
:param output_uri: The storage / output url for this pipeline. This is the default location for output
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
The `output_uri` of this pipeline's steps will default to this value.
:param skip_global_imports: If True, global imports will not be included in the steps' execution, otherwise all
global imports will be automatically imported in a safe manner at the beginning of each steps execution.
Default is False
"""
super(PipelineDecorator, self).__init__(
name=name,
@ -3420,7 +3430,8 @@ class PipelineDecorator(PipelineController):
always_create_from_code=False,
artifact_serialization_function=artifact_serialization_function,
artifact_deserialization_function=artifact_deserialization_function,
output_uri=output_uri
output_uri=output_uri,
skip_global_imports=skip_global_imports
)
# if we are in eager execution, make sure parent class knows it
@ -3686,7 +3697,8 @@ class PipelineDecorator(PipelineController):
task_template_header=self._task_template_header,
_sanitize_function=sanitize,
artifact_serialization_function=self._artifact_serialization_function,
artifact_deserialization_function=self._artifact_deserialization_function
artifact_deserialization_function=self._artifact_deserialization_function,
skip_global_imports=self._skip_global_imports
)
return task_definition
@ -4173,7 +4185,8 @@ class PipelineDecorator(PipelineController):
repo_commit=None, # type: Optional[str]
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
output_uri=None # type: Optional[Union[str, bool]]
output_uri=None, # type: Optional[Union[str, bool]]
skip_global_imports=False # type: bool
):
# type: (...) -> Callable
"""
@ -4287,6 +4300,9 @@ class PipelineDecorator(PipelineController):
:param output_uri: The storage / output url for this pipeline. This is the default location for output
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
The `output_uri` of this pipeline's steps will default to this value.
:param skip_global_imports: If True, global imports will not be included in the steps' execution, otherwise all
global imports will be automatically imported in a safe manner at the beginning of each steps execution.
Default is False
"""
def decorator_wrap(func):
@ -4333,7 +4349,8 @@ class PipelineDecorator(PipelineController):
repo_commit=repo_commit,
artifact_serialization_function=artifact_serialization_function,
artifact_deserialization_function=artifact_deserialization_function,
output_uri=output_uri
output_uri=output_uri,
skip_global_imports=skip_global_imports
)
ret_val = func(**pipeline_kwargs)
LazyEvalWrapper.trigger_all_remote_references()
@ -4385,7 +4402,8 @@ class PipelineDecorator(PipelineController):
repo_commit=repo_commit,
artifact_serialization_function=artifact_serialization_function,
artifact_deserialization_function=artifact_deserialization_function,
output_uri=output_uri
output_uri=output_uri,
skip_global_imports=skip_global_imports
)
a_pipeline._args_map = args_map or {}

View File

@ -24,7 +24,7 @@ class CreateAndPopulate(object):
":" \
"(?P<path>{regular}.*)?" \
"$" \
.format(
.format(
regular=r"[^/@:#]"
)
@ -199,9 +199,9 @@ class CreateAndPopulate(object):
# if there is nothing to populate, return
if not any([
self.folder, self.commit, self.branch, self.repo, self.script, self.cwd,
self.packages, self.requirements_file, self.base_task_id] + (list(self.docker.values()))
):
self.folder, self.commit, self.branch, self.repo, self.script, self.cwd,
self.packages, self.requirements_file, self.base_task_id] + (list(self.docker.values()))
):
return task
# clear the script section
@ -219,7 +219,7 @@ class CreateAndPopulate(object):
if self.cwd:
self.cwd = self.cwd
cwd = self.cwd if Path(self.cwd).is_dir() else (
Path(repo_info.script['repo_root']) / self.cwd).as_posix()
Path(repo_info.script['repo_root']) / self.cwd).as_posix()
if not Path(cwd).is_dir():
raise ValueError("Working directory \'{}\' could not be found".format(cwd))
cwd = Path(cwd).relative_to(repo_info.script['repo_root']).as_posix()
@ -577,6 +577,7 @@ if __name__ == '__main__':
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
_sanitize_function=None, # type: Optional[Callable[[str], str]]
_sanitize_helper_functions=None, # type: Optional[Callable[[str], str]]
skip_global_imports=False # type: bool
):
# type: (...) -> Optional[Dict, Task]
"""
@ -659,6 +660,9 @@ if __name__ == '__main__':
return dill.loads(bytes_)
:param _sanitize_function: Sanitization function for the function string.
:param _sanitize_helper_functions: Sanitization function for the helper function string.
:param skip_global_imports: If True, the global imports will not be fetched from the function's file, otherwise
all global imports will be automatically imported in a safe manner at the beginning of the function's
execution. Default is False
:return: Newly created Task object
"""
# not set -> equals True
@ -671,7 +675,7 @@ if __name__ == '__main__':
assert (not auto_connect_arg_parser or isinstance(auto_connect_arg_parser, (bool, dict)))
function_source, function_name = CreateFromFunction.__extract_function_information(
a_function, sanitize_function=_sanitize_function
a_function, sanitize_function=_sanitize_function, skip_global_imports=skip_global_imports
)
# add helper functions on top.
for f in (helper_functions or []):
@ -846,11 +850,102 @@ if __name__ == '__main__':
return function_source
@staticmethod
def __extract_function_information(function, sanitize_function=None):
# type: (Callable, Optional[Callable]) -> (str, str)
function_name = str(function.__name__)
function_source = inspect.getsource(function)
def __extract_imports(func):
def add_import_guard(import_):
return ("try:\n "
+ import_.replace("\n", "\n ", import_.count("\n") - 1)
+ "except Exception as e:\n print('Import error: ' + str(e))\n"
)
# noinspection PyBroadException
try:
import ast
func_module = inspect.getmodule(func)
source = inspect.getsource(func_module)
source_lines = inspect.getsourcelines(func_module)[0]
parsed_source = ast.parse(source)
imports = []
for parsed_source_entry in parsed_source.body:
if isinstance(parsed_source_entry,
(ast.Import, ast.ImportFrom)) and parsed_source_entry.col_offset == 0:
imports.append(
"\n".join(source_lines[parsed_source_entry.lineno - 1: parsed_source_entry.end_lineno]))
imports = [add_import_guard(import_) for import_ in imports]
return "\n".join(imports)
except Exception as e:
getLogger().warning('Could not fetch function imports: {}'.format(e))
return ""
@staticmethod
def __extract_wrapped(decorated):
if not decorated.__closure__:
return None
closure = (c.cell_contents for c in decorated.__closure__)
if not closure:
return None
return next((c for c in closure if inspect.isfunction(c)), None)
@staticmethod
def __sanitize(func_source, sanitize_function=None):
if sanitize_function:
function_source = sanitize_function(function_source)
function_source = CreateFromFunction.__sanitize_remove_type_hints(function_source)
return function_source, function_name
func_source = sanitize_function(func_source)
return CreateFromFunction.__sanitize_remove_type_hints(func_source)
@staticmethod
def __get_func_members(module):
result = []
try:
import ast
source = inspect.getsource(module)
parsed = ast.parse(source)
for f in parsed.body:
if isinstance(f, ast.FunctionDef):
result.append(f.name)
except Exception as e:
name = getattr(module, "__name__", module)
getLogger().warning('Could not fetch function declared in {}: {}'.format(name, e))
return result
@staticmethod
def __get_source_with_decorators(func, original_module=None, sanitize_function=None):
if original_module is None:
original_module = inspect.getmodule(func)
func_members = CreateFromFunction.__get_func_members(original_module)
try:
func_members_dict = dict(inspect.getmembers(original_module, inspect.isfunction))
except Exception as e:
name = getattr(original_module, "__name__", original_module)
getLogger().warning('Could not fetch functions from {}: {}'.format(name, e))
func_members_dict = {}
decorated_func = CreateFromFunction.__extract_wrapped(func)
if not decorated_func:
return CreateFromFunction.__sanitize(inspect.getsource(func), sanitize_function=sanitize_function)
decorated_func_source = CreateFromFunction.__sanitize(
inspect.getsource(decorated_func),
sanitize_function=sanitize_function
)
for decorator in decorated_func_source.split("\n"):
if decorator.startswith("@"):
decorator_func_name = decorator[1:decorator.find("(") if "(" in decorator else len(decorator)]
decorator_func = func_members_dict.get(decorator_func_name)
if decorator_func_name not in func_members or not decorator_func:
continue
decorated_func_source = CreateFromFunction.__get_source_with_decorators(
decorator_func,
original_module=original_module,
sanitize_function=sanitize_function
) + "\n\n" + decorated_func_source
return decorated_func_source
@staticmethod
def __extract_function_information(function, sanitize_function=None, skip_global_imports=False):
# type: (Callable, Optional[Callable], bool) -> (str, str)
decorated = CreateFromFunction.__extract_wrapped(function)
function_name = str(decorated.__name__) if decorated else str(function.__name__)
function_source = CreateFromFunction.__get_source_with_decorators(function, sanitize_function=sanitize_function)
if not skip_global_imports:
imports = CreateFromFunction.__extract_imports(decorated if decorated else function)
else:
imports = ""
return imports + "\n" + function_source, function_name

View File

@ -0,0 +1,24 @@
from clearml import PipelineDecorator
def our_decorator(func):
def function_wrapper(*args, **kwargs):
return func(*args, **kwargs) + 1
return function_wrapper
@PipelineDecorator.component()
@our_decorator
def step():
return 1
@PipelineDecorator.pipeline(name="test_decorated", project="test_decorated")
def pipeline():
result = step()
assert result == 2
if __name__ == "__main__":
PipelineDecorator.run_locally()
pipeline()

View File

@ -0,0 +1,27 @@
from clearml import PipelineController
def our_decorator(func):
def function_wrapper(*args, **kwargs):
return func(*args, **kwargs) + 1
return function_wrapper
@our_decorator
def step():
return 1
def evaluate(step_return):
assert step_return == 2
if __name__ == "__main__":
pipeline = PipelineController(name="test_decorated", project="test_decorated")
pipeline.add_function_step(name="step", function=step, function_return=["step_return"])
pipeline.add_function_step(
name="evaluate",
function=evaluate,
function_kwargs=dict(step_return='${step.step_return}')
)
pipeline.start_locally(run_pipeline_steps_locally=True)