mirror of
https://github.com/clearml/clearml
synced 2025-04-03 04:21:03 +00:00
Add support for decorated pipeline steps (#1154)
This commit is contained in:
parent
cf39029cb2
commit
4bd6b20a62
@ -177,7 +177,8 @@ class PipelineController(object):
|
||||
always_create_from_code=True, # type: bool
|
||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||
output_uri=None # type: Optional[Union[str, bool]]
|
||||
output_uri=None, # type: Optional[Union[str, bool]]
|
||||
skip_global_imports=False # type: bool
|
||||
):
|
||||
# type: (...) -> None
|
||||
"""
|
||||
@ -267,6 +268,9 @@ class PipelineController(object):
|
||||
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||
The `output_uri` of this pipeline's steps will default to this value.
|
||||
:param skip_global_imports: If True, global imports will not be included in the steps' execution when creating
|
||||
the steps from functions, otherwise all global imports will be automatically imported in a safe manner at
|
||||
the beginning of each step’s execution. Default is False
|
||||
"""
|
||||
if auto_version_bump is not None:
|
||||
warnings.warn("PipelineController.auto_version_bump is deprecated. It will be ignored", DeprecationWarning)
|
||||
@ -303,6 +307,7 @@ class PipelineController(object):
|
||||
self._last_progress_update_time = 0
|
||||
self._artifact_serialization_function = artifact_serialization_function
|
||||
self._artifact_deserialization_function = artifact_deserialization_function
|
||||
self._skip_global_imports = skip_global_imports
|
||||
if not self._task:
|
||||
task_name = name or project or '{}'.format(datetime.now())
|
||||
if self._pipeline_as_sub_project:
|
||||
@ -1521,7 +1526,8 @@ class PipelineController(object):
|
||||
dry_run=True,
|
||||
task_template_header=self._task_template_header,
|
||||
artifact_serialization_function=self._artifact_serialization_function,
|
||||
artifact_deserialization_function=self._artifact_deserialization_function
|
||||
artifact_deserialization_function=self._artifact_deserialization_function,
|
||||
skip_global_imports=self._skip_global_imports
|
||||
)
|
||||
return task_definition
|
||||
|
||||
@ -3316,7 +3322,8 @@ class PipelineDecorator(PipelineController):
|
||||
repo_commit=None, # type: Optional[str]
|
||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||
output_uri=None # type: Optional[Union[str, bool]]
|
||||
output_uri=None, # type: Optional[Union[str, bool]]
|
||||
skip_global_imports=False # type: bool
|
||||
):
|
||||
# type: (...) -> ()
|
||||
"""
|
||||
@ -3399,6 +3406,9 @@ class PipelineDecorator(PipelineController):
|
||||
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||
The `output_uri` of this pipeline's steps will default to this value.
|
||||
:param skip_global_imports: If True, global imports will not be included in the steps' execution, otherwise all
|
||||
global imports will be automatically imported in a safe manner at the beginning of each step’s execution.
|
||||
Default is False
|
||||
"""
|
||||
super(PipelineDecorator, self).__init__(
|
||||
name=name,
|
||||
@ -3420,7 +3430,8 @@ class PipelineDecorator(PipelineController):
|
||||
always_create_from_code=False,
|
||||
artifact_serialization_function=artifact_serialization_function,
|
||||
artifact_deserialization_function=artifact_deserialization_function,
|
||||
output_uri=output_uri
|
||||
output_uri=output_uri,
|
||||
skip_global_imports=skip_global_imports
|
||||
)
|
||||
|
||||
# if we are in eager execution, make sure parent class knows it
|
||||
@ -3686,7 +3697,8 @@ class PipelineDecorator(PipelineController):
|
||||
task_template_header=self._task_template_header,
|
||||
_sanitize_function=sanitize,
|
||||
artifact_serialization_function=self._artifact_serialization_function,
|
||||
artifact_deserialization_function=self._artifact_deserialization_function
|
||||
artifact_deserialization_function=self._artifact_deserialization_function,
|
||||
skip_global_imports=self._skip_global_imports
|
||||
)
|
||||
return task_definition
|
||||
|
||||
@ -4173,7 +4185,8 @@ class PipelineDecorator(PipelineController):
|
||||
repo_commit=None, # type: Optional[str]
|
||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||
output_uri=None # type: Optional[Union[str, bool]]
|
||||
output_uri=None, # type: Optional[Union[str, bool]]
|
||||
skip_global_imports=False # type: bool
|
||||
):
|
||||
# type: (...) -> Callable
|
||||
"""
|
||||
@ -4287,6 +4300,9 @@ class PipelineDecorator(PipelineController):
|
||||
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||
The `output_uri` of this pipeline's steps will default to this value.
|
||||
:param skip_global_imports: If True, global imports will not be included in the steps' execution, otherwise all
|
||||
global imports will be automatically imported in a safe manner at the beginning of each step’s execution.
|
||||
Default is False
|
||||
"""
|
||||
def decorator_wrap(func):
|
||||
|
||||
@ -4333,7 +4349,8 @@ class PipelineDecorator(PipelineController):
|
||||
repo_commit=repo_commit,
|
||||
artifact_serialization_function=artifact_serialization_function,
|
||||
artifact_deserialization_function=artifact_deserialization_function,
|
||||
output_uri=output_uri
|
||||
output_uri=output_uri,
|
||||
skip_global_imports=skip_global_imports
|
||||
)
|
||||
ret_val = func(**pipeline_kwargs)
|
||||
LazyEvalWrapper.trigger_all_remote_references()
|
||||
@ -4385,7 +4402,8 @@ class PipelineDecorator(PipelineController):
|
||||
repo_commit=repo_commit,
|
||||
artifact_serialization_function=artifact_serialization_function,
|
||||
artifact_deserialization_function=artifact_deserialization_function,
|
||||
output_uri=output_uri
|
||||
output_uri=output_uri,
|
||||
skip_global_imports=skip_global_imports
|
||||
)
|
||||
|
||||
a_pipeline._args_map = args_map or {}
|
||||
|
@ -24,7 +24,7 @@ class CreateAndPopulate(object):
|
||||
":" \
|
||||
"(?P<path>{regular}.*)?" \
|
||||
"$" \
|
||||
.format(
|
||||
.format(
|
||||
regular=r"[^/@:#]"
|
||||
)
|
||||
|
||||
@ -199,9 +199,9 @@ class CreateAndPopulate(object):
|
||||
|
||||
# if there is nothing to populate, return
|
||||
if not any([
|
||||
self.folder, self.commit, self.branch, self.repo, self.script, self.cwd,
|
||||
self.packages, self.requirements_file, self.base_task_id] + (list(self.docker.values()))
|
||||
):
|
||||
self.folder, self.commit, self.branch, self.repo, self.script, self.cwd,
|
||||
self.packages, self.requirements_file, self.base_task_id] + (list(self.docker.values()))
|
||||
):
|
||||
return task
|
||||
|
||||
# clear the script section
|
||||
@ -219,7 +219,7 @@ class CreateAndPopulate(object):
|
||||
if self.cwd:
|
||||
self.cwd = self.cwd
|
||||
cwd = self.cwd if Path(self.cwd).is_dir() else (
|
||||
Path(repo_info.script['repo_root']) / self.cwd).as_posix()
|
||||
Path(repo_info.script['repo_root']) / self.cwd).as_posix()
|
||||
if not Path(cwd).is_dir():
|
||||
raise ValueError("Working directory \'{}\' could not be found".format(cwd))
|
||||
cwd = Path(cwd).relative_to(repo_info.script['repo_root']).as_posix()
|
||||
@ -577,6 +577,7 @@ if __name__ == '__main__':
|
||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||
_sanitize_function=None, # type: Optional[Callable[[str], str]]
|
||||
_sanitize_helper_functions=None, # type: Optional[Callable[[str], str]]
|
||||
skip_global_imports=False # type: bool
|
||||
):
|
||||
# type: (...) -> Optional[Union[Dict, Task]]
|
||||
"""
|
||||
@ -659,6 +660,9 @@ if __name__ == '__main__':
|
||||
return dill.loads(bytes_)
|
||||
:param _sanitize_function: Sanitization function for the function string.
|
||||
:param _sanitize_helper_functions: Sanitization function for the helper function string.
|
||||
:param skip_global_imports: If True, the global imports will not be fetched from the function's file, otherwise
|
||||
all global imports will be automatically imported in a safe manner at the beginning of the function's
|
||||
execution. Default is False
|
||||
:return: Newly created Task object
|
||||
"""
|
||||
# not set -> equals True
|
||||
@ -671,7 +675,7 @@ if __name__ == '__main__':
|
||||
assert (not auto_connect_arg_parser or isinstance(auto_connect_arg_parser, (bool, dict)))
|
||||
|
||||
function_source, function_name = CreateFromFunction.__extract_function_information(
|
||||
a_function, sanitize_function=_sanitize_function
|
||||
a_function, sanitize_function=_sanitize_function, skip_global_imports=skip_global_imports
|
||||
)
|
||||
# add helper functions on top.
|
||||
for f in (helper_functions or []):
|
||||
@ -846,11 +850,102 @@ if __name__ == '__main__':
|
||||
return function_source
|
||||
|
||||
@staticmethod
|
||||
def __extract_function_information(function, sanitize_function=None):
|
||||
# type: (Callable, Optional[Callable]) -> (str, str)
|
||||
function_name = str(function.__name__)
|
||||
function_source = inspect.getsource(function)
|
||||
def __extract_imports(func):
    """
    Collect the global import statements of the module that defines ``func`` and wrap each of them in a
    ``try``/``except`` guard, so that a failing import is printed instead of aborting the generated code.

    :param func: The function whose defining module is scanned for top-level imports.
    :return: The guarded import statements joined by newlines. An empty string is returned (and a warning
        is logged) if the module source could not be fetched or parsed.
    """
    def add_import_guard(import_):
        # `import_` always ends with a newline: indent after every newline except the last one,
        # which terminates the `try` body and must be followed directly by `except`
        return ("try:\n "
                + import_.replace("\n", "\n ", import_.count("\n") - 1)
                + "except Exception as e:\n print('Import error: ' + str(e))\n"
                )

    # noinspection PyBroadException
    try:
        import ast
        func_module = inspect.getmodule(func)
        # the returned lines keep their line terminators, concatenating them yields the module source
        source_lines = inspect.getsourcelines(func_module)[0]
        parsed_source = ast.parse("".join(source_lines))
        imports = []
        for parsed_source_entry in parsed_source.body:
            if not isinstance(parsed_source_entry, (ast.Import, ast.ImportFrom)):
                continue
            if parsed_source_entry.col_offset != 0:
                continue
            # NOTE: `end_lineno` is only populated on Python 3.8+; on older interpreters the lookup raises
            # and the broad handler below falls back to "no imports"
            statement = "".join(source_lines[parsed_source_entry.lineno - 1: parsed_source_entry.end_lineno])
            # an import on the very last line of a file may lack the trailing newline the guard relies on
            if not statement.endswith("\n"):
                statement += "\n"
            imports.append(add_import_guard(statement))
        return "\n".join(imports)
    except Exception as e:
        getLogger().warning('Could not fetch function imports: {}'.format(e))
        return ""
|
||||
|
||||
@staticmethod
|
||||
def __extract_wrapped(decorated):
    """
    Return the first function captured in the closure of ``decorated``.

    This is used to reach the function wrapped by a decorator, which the wrapper keeps as a closure
    variable.

    :param decorated: The (possibly decorated) callable to inspect.
    :return: The first closure variable that is a function, or None if ``decorated`` has no closure
        or none of its closure variables is a function.
    """
    # builtins, partial objects and other callables do not expose `__closure__`
    closure = getattr(decorated, "__closure__", None)
    if not closure:
        return None
    for cell in closure:
        try:
            contents = cell.cell_contents
        except ValueError:
            # the cell is still empty (the captured variable was not assigned yet)
            continue
        if inspect.isfunction(contents):
            return contents
    return None
|
||||
|
||||
@staticmethod
|
||||
def __sanitize(func_source, sanitize_function=None):
|
||||
if sanitize_function:
|
||||
function_source = sanitize_function(function_source)
|
||||
function_source = CreateFromFunction.__sanitize_remove_type_hints(function_source)
|
||||
return function_source, function_name
|
||||
func_source = sanitize_function(func_source)
|
||||
return CreateFromFunction.__sanitize_remove_type_hints(func_source)
|
||||
|
||||
@staticmethod
|
||||
def __get_func_members(module):
    """
    List the names of the functions declared at the top level of a module's source code.

    :param module: The module whose source is parsed.
    :return: A list with the names of the top-level ``def`` statements, in source order. If the source
        cannot be fetched or parsed a warning is logged and the names collected so far are returned.
    """
    names = []
    try:
        import ast

        tree = ast.parse(inspect.getsource(module))
        names.extend(node.name for node in tree.body if isinstance(node, ast.FunctionDef))
    except Exception as e:
        name = getattr(module, "__name__", module)
        getLogger().warning('Could not fetch function declared in {}: {}'.format(name, e))
    return names
|
||||
|
||||
@staticmethod
|
||||
def __get_source_with_decorators(func, original_module=None, sanitize_function=None):
    """
    Build the source code of ``func``, preceded by the source of the decorators applied to it that are
    functions declared in the same module.

    :param func: The (possibly decorated) function to fetch the source for.
    :param original_module: The module in which decorators are looked up. If None, the module of ``func``
        is used.
    :param sanitize_function: Optional sanitization function applied to every fetched source string.
    :return: The sanitized source of the wrapped function, with the (recursively resolved) source of each
        known decorator prepended. If ``func`` does not wrap another function, its own sanitized source.
    """
    module = inspect.getmodule(func) if original_module is None else original_module
    declared_names = CreateFromFunction.__get_func_members(module)
    try:
        module_functions = dict(inspect.getmembers(module, inspect.isfunction))
    except Exception as e:
        name = getattr(module, "__name__", module)
        getLogger().warning('Could not fetch functions from {}: {}'.format(name, e))
        module_functions = {}

    wrapped = CreateFromFunction.__extract_wrapped(func)
    if not wrapped:
        # nothing is wrapped, the source of the function itself is all there is
        return CreateFromFunction.__sanitize(inspect.getsource(func), sanitize_function=sanitize_function)

    source = CreateFromFunction.__sanitize(inspect.getsource(wrapped), sanitize_function=sanitize_function)
    # iterate over the lines of the wrapped function's own source; `source` is extended while looping
    for line in source.split("\n"):
        if not line.startswith("@"):
            continue
        # decorator name: everything after the "@" up to the first "(" (or the end of the line)
        decorator_name = line[1:].partition("(")[0]
        decorator_func = module_functions.get(decorator_name)
        # only decorators that are functions declared in the module source can be inlined
        if decorator_name not in declared_names or not decorator_func:
            continue
        decorator_source = CreateFromFunction.__get_source_with_decorators(
            decorator_func,
            original_module=module,
            sanitize_function=sanitize_function
        )
        source = decorator_source + "\n\n" + source
    return source
|
||||
|
||||
@staticmethod
|
||||
def __extract_function_information(function, sanitize_function=None, skip_global_imports=False):
    # type: (Callable, Optional[Callable], bool) -> (str, str)
    """
    Produce the source code and the name used to create a task from ``function``.

    :param function: The (possibly decorated) function to extract the information from.
    :param sanitize_function: Optional sanitization function for the function source string.
    :param skip_global_imports: If True, the global imports of the function's module are not prepended
        to the returned source.
    :return: A tuple of (source code, function name). The name is taken from the wrapped function when
        ``function`` wraps one, otherwise from ``function`` itself.
    """
    wrapped = CreateFromFunction.__extract_wrapped(function)
    # the wrapped function (if any) determines both the reported name and the module scanned for imports
    target = wrapped if wrapped else function
    function_name = str(target.__name__)
    function_source = CreateFromFunction.__get_source_with_decorators(function, sanitize_function=sanitize_function)
    imports = "" if skip_global_imports else CreateFromFunction.__extract_imports(target)
    return imports + "\n" + function_source, function_name
|
||||
|
24
examples/pipeline/decorated_pipeline_step_decorators.py
Normal file
24
examples/pipeline/decorated_pipeline_step_decorators.py
Normal file
@ -0,0 +1,24 @@
|
||||
from clearml import PipelineDecorator


def our_decorator(func):
    # Plain user decorator: the wrapped step returns its original result plus one
    def incremented(*args, **kwargs):
        original_result = func(*args, **kwargs)
        return original_result + 1

    return incremented


# The pipeline component decorator is applied on top of the user decorator
@PipelineDecorator.component()
@our_decorator
def step():
    return 1


@PipelineDecorator.pipeline(name="test_decorated", project="test_decorated")
def pipeline():
    # `step` returns 1 and `our_decorator` adds 1 to it
    step_output = step()
    assert step_output == 2


if __name__ == "__main__":
    # Run the pipeline logic and its steps on the local machine
    PipelineDecorator.run_locally()
    pipeline()
|
27
examples/pipeline/decorated_pipeline_step_functions.py
Normal file
27
examples/pipeline/decorated_pipeline_step_functions.py
Normal file
@ -0,0 +1,27 @@
|
||||
from clearml import PipelineController


def our_decorator(func):
    # Plain user decorator: the wrapped step returns its original result plus one
    def incremented(*args, **kwargs):
        original_result = func(*args, **kwargs)
        return original_result + 1

    return incremented


@our_decorator
def step():
    return 1


def evaluate(step_return):
    # `step` returns 1 and `our_decorator` adds 1 to it
    assert step_return == 2


if __name__ == "__main__":
    controller = PipelineController(name="test_decorated", project="test_decorated")
    # First step: the decorated function, its return value is stored as "step_return"
    controller.add_function_step(name="step", function=step, function_return=["step_return"])
    # Second step: receives the first step's return value through the pipeline reference syntax
    evaluate_kwargs = {"step_return": "${step.step_return}"}
    controller.add_function_step(
        name="evaluate",
        function=evaluate,
        function_kwargs=evaluate_kwargs
    )
    controller.start_locally(run_pipeline_steps_locally=True)
|
Loading…
Reference in New Issue
Block a user