mirror of
https://github.com/clearml/clearml
synced 2025-04-07 14:14:28 +00:00
Add support for decorated pipeline steps (#1154)
This commit is contained in:
parent
cf39029cb2
commit
4bd6b20a62
@ -177,7 +177,8 @@ class PipelineController(object):
|
|||||||
always_create_from_code=True, # type: bool
|
always_create_from_code=True, # type: bool
|
||||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||||
output_uri=None # type: Optional[Union[str, bool]]
|
output_uri=None, # type: Optional[Union[str, bool]]
|
||||||
|
skip_global_imports=False # type: bool
|
||||||
):
|
):
|
||||||
# type: (...) -> None
|
# type: (...) -> None
|
||||||
"""
|
"""
|
||||||
@ -267,6 +268,9 @@ class PipelineController(object):
|
|||||||
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
||||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||||
The `output_uri` of this pipeline's steps will default to this value.
|
The `output_uri` of this pipeline's steps will default to this value.
|
||||||
|
:param skip_global_imports: If True, global imports will not be included in the steps' execution when creating
|
||||||
|
the steps from functions, otherwise all global imports will be automatically imported in a safe manner at
|
||||||
|
the beginning of each step’s execution. Default is False
|
||||||
"""
|
"""
|
||||||
if auto_version_bump is not None:
|
if auto_version_bump is not None:
|
||||||
warnings.warn("PipelineController.auto_version_bump is deprecated. It will be ignored", DeprecationWarning)
|
warnings.warn("PipelineController.auto_version_bump is deprecated. It will be ignored", DeprecationWarning)
|
||||||
@ -303,6 +307,7 @@ class PipelineController(object):
|
|||||||
self._last_progress_update_time = 0
|
self._last_progress_update_time = 0
|
||||||
self._artifact_serialization_function = artifact_serialization_function
|
self._artifact_serialization_function = artifact_serialization_function
|
||||||
self._artifact_deserialization_function = artifact_deserialization_function
|
self._artifact_deserialization_function = artifact_deserialization_function
|
||||||
|
self._skip_global_imports = skip_global_imports
|
||||||
if not self._task:
|
if not self._task:
|
||||||
task_name = name or project or '{}'.format(datetime.now())
|
task_name = name or project or '{}'.format(datetime.now())
|
||||||
if self._pipeline_as_sub_project:
|
if self._pipeline_as_sub_project:
|
||||||
@ -1521,7 +1526,8 @@ class PipelineController(object):
|
|||||||
dry_run=True,
|
dry_run=True,
|
||||||
task_template_header=self._task_template_header,
|
task_template_header=self._task_template_header,
|
||||||
artifact_serialization_function=self._artifact_serialization_function,
|
artifact_serialization_function=self._artifact_serialization_function,
|
||||||
artifact_deserialization_function=self._artifact_deserialization_function
|
artifact_deserialization_function=self._artifact_deserialization_function,
|
||||||
|
skip_global_imports=self._skip_global_imports
|
||||||
)
|
)
|
||||||
return task_definition
|
return task_definition
|
||||||
|
|
||||||
@ -3316,7 +3322,8 @@ class PipelineDecorator(PipelineController):
|
|||||||
repo_commit=None, # type: Optional[str]
|
repo_commit=None, # type: Optional[str]
|
||||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||||
output_uri=None # type: Optional[Union[str, bool]]
|
output_uri=None, # type: Optional[Union[str, bool]]
|
||||||
|
skip_global_imports=False # type: bool
|
||||||
):
|
):
|
||||||
# type: (...) -> ()
|
# type: (...) -> ()
|
||||||
"""
|
"""
|
||||||
@ -3399,6 +3406,9 @@ class PipelineDecorator(PipelineController):
|
|||||||
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
||||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||||
The `output_uri` of this pipeline's steps will default to this value.
|
The `output_uri` of this pipeline's steps will default to this value.
|
||||||
|
:param skip_global_imports: If True, global imports will not be included in the steps' execution, otherwise all
|
||||||
|
global imports will be automatically imported in a safe manner at the beginning of each step’s execution.
|
||||||
|
Default is False
|
||||||
"""
|
"""
|
||||||
super(PipelineDecorator, self).__init__(
|
super(PipelineDecorator, self).__init__(
|
||||||
name=name,
|
name=name,
|
||||||
@ -3420,7 +3430,8 @@ class PipelineDecorator(PipelineController):
|
|||||||
always_create_from_code=False,
|
always_create_from_code=False,
|
||||||
artifact_serialization_function=artifact_serialization_function,
|
artifact_serialization_function=artifact_serialization_function,
|
||||||
artifact_deserialization_function=artifact_deserialization_function,
|
artifact_deserialization_function=artifact_deserialization_function,
|
||||||
output_uri=output_uri
|
output_uri=output_uri,
|
||||||
|
skip_global_imports=skip_global_imports
|
||||||
)
|
)
|
||||||
|
|
||||||
# if we are in eager execution, make sure parent class knows it
|
# if we are in eager execution, make sure parent class knows it
|
||||||
@ -3686,7 +3697,8 @@ class PipelineDecorator(PipelineController):
|
|||||||
task_template_header=self._task_template_header,
|
task_template_header=self._task_template_header,
|
||||||
_sanitize_function=sanitize,
|
_sanitize_function=sanitize,
|
||||||
artifact_serialization_function=self._artifact_serialization_function,
|
artifact_serialization_function=self._artifact_serialization_function,
|
||||||
artifact_deserialization_function=self._artifact_deserialization_function
|
artifact_deserialization_function=self._artifact_deserialization_function,
|
||||||
|
skip_global_imports=self._skip_global_imports
|
||||||
)
|
)
|
||||||
return task_definition
|
return task_definition
|
||||||
|
|
||||||
@ -4173,7 +4185,8 @@ class PipelineDecorator(PipelineController):
|
|||||||
repo_commit=None, # type: Optional[str]
|
repo_commit=None, # type: Optional[str]
|
||||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||||
output_uri=None # type: Optional[Union[str, bool]]
|
output_uri=None, # type: Optional[Union[str, bool]]
|
||||||
|
skip_global_imports=False # type: bool
|
||||||
):
|
):
|
||||||
# type: (...) -> Callable
|
# type: (...) -> Callable
|
||||||
"""
|
"""
|
||||||
@ -4287,6 +4300,9 @@ class PipelineDecorator(PipelineController):
|
|||||||
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
||||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||||
The `output_uri` of this pipeline's steps will default to this value.
|
The `output_uri` of this pipeline's steps will default to this value.
|
||||||
|
:param skip_global_imports: If True, global imports will not be included in the steps' execution, otherwise all
|
||||||
|
global imports will be automatically imported in a safe manner at the beginning of each step’s execution.
|
||||||
|
Default is False
|
||||||
"""
|
"""
|
||||||
def decorator_wrap(func):
|
def decorator_wrap(func):
|
||||||
|
|
||||||
@ -4333,7 +4349,8 @@ class PipelineDecorator(PipelineController):
|
|||||||
repo_commit=repo_commit,
|
repo_commit=repo_commit,
|
||||||
artifact_serialization_function=artifact_serialization_function,
|
artifact_serialization_function=artifact_serialization_function,
|
||||||
artifact_deserialization_function=artifact_deserialization_function,
|
artifact_deserialization_function=artifact_deserialization_function,
|
||||||
output_uri=output_uri
|
output_uri=output_uri,
|
||||||
|
skip_global_imports=skip_global_imports
|
||||||
)
|
)
|
||||||
ret_val = func(**pipeline_kwargs)
|
ret_val = func(**pipeline_kwargs)
|
||||||
LazyEvalWrapper.trigger_all_remote_references()
|
LazyEvalWrapper.trigger_all_remote_references()
|
||||||
@ -4385,7 +4402,8 @@ class PipelineDecorator(PipelineController):
|
|||||||
repo_commit=repo_commit,
|
repo_commit=repo_commit,
|
||||||
artifact_serialization_function=artifact_serialization_function,
|
artifact_serialization_function=artifact_serialization_function,
|
||||||
artifact_deserialization_function=artifact_deserialization_function,
|
artifact_deserialization_function=artifact_deserialization_function,
|
||||||
output_uri=output_uri
|
output_uri=output_uri,
|
||||||
|
skip_global_imports=skip_global_imports
|
||||||
)
|
)
|
||||||
|
|
||||||
a_pipeline._args_map = args_map or {}
|
a_pipeline._args_map = args_map or {}
|
||||||
|
@ -24,7 +24,7 @@ class CreateAndPopulate(object):
|
|||||||
":" \
|
":" \
|
||||||
"(?P<path>{regular}.*)?" \
|
"(?P<path>{regular}.*)?" \
|
||||||
"$" \
|
"$" \
|
||||||
.format(
|
.format(
|
||||||
regular=r"[^/@:#]"
|
regular=r"[^/@:#]"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -199,9 +199,9 @@ class CreateAndPopulate(object):
|
|||||||
|
|
||||||
# if there is nothing to populate, return
|
# if there is nothing to populate, return
|
||||||
if not any([
|
if not any([
|
||||||
self.folder, self.commit, self.branch, self.repo, self.script, self.cwd,
|
self.folder, self.commit, self.branch, self.repo, self.script, self.cwd,
|
||||||
self.packages, self.requirements_file, self.base_task_id] + (list(self.docker.values()))
|
self.packages, self.requirements_file, self.base_task_id] + (list(self.docker.values()))
|
||||||
):
|
):
|
||||||
return task
|
return task
|
||||||
|
|
||||||
# clear the script section
|
# clear the script section
|
||||||
@ -219,7 +219,7 @@ class CreateAndPopulate(object):
|
|||||||
if self.cwd:
|
if self.cwd:
|
||||||
self.cwd = self.cwd
|
self.cwd = self.cwd
|
||||||
cwd = self.cwd if Path(self.cwd).is_dir() else (
|
cwd = self.cwd if Path(self.cwd).is_dir() else (
|
||||||
Path(repo_info.script['repo_root']) / self.cwd).as_posix()
|
Path(repo_info.script['repo_root']) / self.cwd).as_posix()
|
||||||
if not Path(cwd).is_dir():
|
if not Path(cwd).is_dir():
|
||||||
raise ValueError("Working directory \'{}\' could not be found".format(cwd))
|
raise ValueError("Working directory \'{}\' could not be found".format(cwd))
|
||||||
cwd = Path(cwd).relative_to(repo_info.script['repo_root']).as_posix()
|
cwd = Path(cwd).relative_to(repo_info.script['repo_root']).as_posix()
|
||||||
@ -577,6 +577,7 @@ if __name__ == '__main__':
|
|||||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||||
_sanitize_function=None, # type: Optional[Callable[[str], str]]
|
_sanitize_function=None, # type: Optional[Callable[[str], str]]
|
||||||
_sanitize_helper_functions=None, # type: Optional[Callable[[str], str]]
|
_sanitize_helper_functions=None, # type: Optional[Callable[[str], str]]
|
||||||
|
skip_global_imports=False # type: bool
|
||||||
):
|
):
|
||||||
# type: (...) -> Optional[Dict, Task]
|
# type: (...) -> Optional[Dict, Task]
|
||||||
"""
|
"""
|
||||||
@ -659,6 +660,9 @@ if __name__ == '__main__':
|
|||||||
return dill.loads(bytes_)
|
return dill.loads(bytes_)
|
||||||
:param _sanitize_function: Sanitization function for the function string.
|
:param _sanitize_function: Sanitization function for the function string.
|
||||||
:param _sanitize_helper_functions: Sanitization function for the helper function string.
|
:param _sanitize_helper_functions: Sanitization function for the helper function string.
|
||||||
|
:param skip_global_imports: If True, the global imports will not be fetched from the function's file, otherwise
|
||||||
|
all global imports will be automatically imported in a safe manner at the beginning of the function's
|
||||||
|
execution. Default is False
|
||||||
:return: Newly created Task object
|
:return: Newly created Task object
|
||||||
"""
|
"""
|
||||||
# not set -> equals True
|
# not set -> equals True
|
||||||
@ -671,7 +675,7 @@ if __name__ == '__main__':
|
|||||||
assert (not auto_connect_arg_parser or isinstance(auto_connect_arg_parser, (bool, dict)))
|
assert (not auto_connect_arg_parser or isinstance(auto_connect_arg_parser, (bool, dict)))
|
||||||
|
|
||||||
function_source, function_name = CreateFromFunction.__extract_function_information(
|
function_source, function_name = CreateFromFunction.__extract_function_information(
|
||||||
a_function, sanitize_function=_sanitize_function
|
a_function, sanitize_function=_sanitize_function, skip_global_imports=skip_global_imports
|
||||||
)
|
)
|
||||||
# add helper functions on top.
|
# add helper functions on top.
|
||||||
for f in (helper_functions or []):
|
for f in (helper_functions or []):
|
||||||
@ -846,11 +850,102 @@ if __name__ == '__main__':
|
|||||||
return function_source
|
return function_source
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __extract_function_information(function, sanitize_function=None):
|
def __extract_imports(func):
    """
    Collect the top-level import statements of the module that defines ``func`` and return them
    as source text. Every statement is wrapped in its own ``try``/``except`` guard, so an import
    that fails prints an 'Import error' message instead of aborting the generated code.

    :param func: Function whose defining module is scanned for import statements.
    :return: The guarded import statements joined by newlines, or an empty string when the module
        source could not be fetched or parsed (a warning is logged in that case).
    """
    def add_import_guard(import_):
        if not import_.endswith("\n"):
            # an import on the last line of a file that has no trailing newline: terminate it,
            # otherwise the `except` clause below would be glued to the import statement
            import_ += "\n"
        # indent every continuation line of the statement; the final newline is left untouched
        # because it ends the statement and the `except` clause must start at column zero
        return ("try:\n "
                + import_.replace("\n", "\n ", import_.count("\n") - 1)
                + "except Exception as e:\n print('Import error: ' + str(e))\n"
                )

    # noinspection PyBroadException
    try:
        import ast
        func_module = inspect.getmodule(func)
        # the returned lines keep their line terminators
        source_lines = inspect.getsourcelines(func_module)[0]
        parsed_source = ast.parse("".join(source_lines))
        imports = [
            # join without a separator: the lines already end with a newline, and an extra blank
            # line would break statements that use a backslash line continuation
            "".join(source_lines[entry.lineno - 1: entry.end_lineno])
            for entry in parsed_source.body
            # col_offset != 0 means the import shares its line with a preceding statement
            if isinstance(entry, (ast.Import, ast.ImportFrom)) and entry.col_offset == 0
        ]
        return "\n".join(add_import_guard(import_) for import_ in imports)
    except Exception as e:
        getLogger().warning('Could not fetch function imports: {}'.format(e))
        return ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __extract_wrapped(decorated):
    """
    Return the first function object captured in the closure of ``decorated``, i.e. the function
    a decorator-produced wrapper is wrapping.

    :param decorated: Callable to inspect.
    :return: The first closure value that is a function, or None when the callable has no
        closure (including callables without a ``__closure__`` attribute, such as builtins)
        or captures no function.
    """
    cells = getattr(decorated, "__closure__", None)
    if not cells:
        return None
    for cell in cells:
        try:
            candidate = cell.cell_contents
        except ValueError:
            # an empty cell: the captured variable has not been assigned yet
            continue
        if inspect.isfunction(candidate):
            return candidate
    return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __sanitize(func_source, sanitize_function=None):
    """
    Run the optional ``sanitize_function`` over ``func_source`` and then pass the result through
    ``CreateFromFunction.__sanitize_remove_type_hints``.

    :param func_source: Function source code to sanitize.
    :param sanitize_function: Optional callable that receives the source and returns the
        sanitized source. Skipped when falsy.
    :return: The value returned by ``__sanitize_remove_type_hints``.
    """
    sanitized = sanitize_function(func_source) if sanitize_function else func_source
    return CreateFromFunction.__sanitize_remove_type_hints(sanitized)
|
||||||
return function_source, function_name
|
|
||||||
|
@staticmethod
|
||||||
|
def __get_func_members(module):
    """
    List the names of the functions defined with ``def`` at the top level of ``module``.

    The names are read from the parsed module source rather than from the live module object.

    :param module: Module whose source is inspected.
    :return: List of function names in declaration order. Empty when the source could not be
        fetched or parsed (a warning is logged in that case).
    """
    try:
        import ast

        tree = ast.parse(inspect.getsource(module))
    except Exception as e:
        module_name = getattr(module, "__name__", module)
        getLogger().warning('Could not fetch function declared in {}: {}'.format(module_name, e))
        return []
    return [node.name for node in tree.body if isinstance(node, ast.FunctionDef)]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __get_source_with_decorators(func, original_module=None, sanitize_function=None):
    """
    Return the sanitized source code needed to re-create ``func``.

    When ``func`` does not wrap another function, its own sanitized source is returned.
    Otherwise the source of the wrapped function is returned, preceded by the source of every
    decorator named on a line starting with ``@`` that is a function declared at the top level
    of ``original_module``. Decorators are resolved recursively, so decorated decorators are
    included as well.

    :param func: Function (possibly a decorator-produced wrapper) to fetch the source of.
    :param original_module: Module in which decorator definitions are looked up.
        Defaults to the module ``func`` belongs to.
    :param sanitize_function: Optional callable forwarded to ``CreateFromFunction.__sanitize``.
    :return: Source code string.
    """
    decorated_func = CreateFromFunction.__extract_wrapped(func)
    if not decorated_func:
        # nothing is wrapped: no decorators to resolve, so the module does not need to be parsed
        return CreateFromFunction.__sanitize(inspect.getsource(func), sanitize_function=sanitize_function)

    if original_module is None:
        original_module = inspect.getmodule(func)
    # names declared with `def` in the module source
    func_members = CreateFromFunction.__get_func_members(original_module)
    try:
        # live function objects of the module, keyed by name
        func_members_dict = dict(inspect.getmembers(original_module, inspect.isfunction))
    except Exception as e:
        name = getattr(original_module, "__name__", original_module)
        getLogger().warning('Could not fetch functions from {}: {}'.format(name, e))
        func_members_dict = {}

    decorated_func_source = CreateFromFunction.__sanitize(
        inspect.getsource(decorated_func),
        sanitize_function=sanitize_function
    )
    for line in decorated_func_source.split("\n"):
        if not line.startswith("@"):
            continue
        # keep only the decorator name: drop call arguments, a trailing comment and any
        # surrounding whitespace (including a stray carriage return)
        decorator_func_name = line[1:].split("(", 1)[0].split("#", 1)[0].strip()
        decorator_func = func_members_dict.get(decorator_func_name)
        # the decorator must be both declared in the module source and available as a function
        if decorator_func_name not in func_members or not decorator_func:
            continue
        decorated_func_source = CreateFromFunction.__get_source_with_decorators(
            decorator_func,
            original_module=original_module,
            sanitize_function=sanitize_function
        ) + "\n\n" + decorated_func_source
    return decorated_func_source
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __extract_function_information(function, sanitize_function=None, skip_global_imports=False):
    # type: (Callable, Optional[Callable], bool) -> (str, str)
    """
    Build the source code and the name used to re-create ``function``.

    :param function: Function to extract; if it wraps another function, the wrapped function
        provides the name and the module whose imports are collected.
    :param sanitize_function: Optional callable forwarded to the source sanitization.
    :param skip_global_imports: If True, no import statements are prepended to the source.
    :return: Tuple of (imports + newline + function source, function name).
    """
    wrapped = CreateFromFunction.__extract_wrapped(function)
    # the wrapped function, when there is one, is the real target of the extraction
    target = wrapped if wrapped else function
    function_name = str(target.__name__)
    function_source = CreateFromFunction.__get_source_with_decorators(function, sanitize_function=sanitize_function)
    imports = "" if skip_global_imports else CreateFromFunction.__extract_imports(target)
    return imports + "\n" + function_source, function_name
|
||||||
|
24
examples/pipeline/decorated_pipeline_step_decorators.py
Normal file
24
examples/pipeline/decorated_pipeline_step_decorators.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from clearml import PipelineDecorator
|
||||||
|
|
||||||
|
|
||||||
|
def our_decorator(func):
    """Return a wrapper that calls ``func`` and adds 1 to its result."""
    def function_wrapper(*args, **kwargs):
        # forward every argument untouched, then bump the result
        outcome = func(*args, **kwargs)
        return outcome + 1

    return function_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@PipelineDecorator.component()
@our_decorator
def step():
    # returns 1; `our_decorator` (applied above) adds 1 to whatever this function returns
    return 1
|
||||||
|
|
||||||
|
|
||||||
|
@PipelineDecorator.pipeline(name="test_decorated", project="test_decorated")
def pipeline():
    # `step` returns 1 and its `our_decorator` wrapper adds 1, hence the expected 2
    result = step()
    assert result == 2
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # NOTE(review): presumably makes the pipeline and its steps execute in the local process;
    # confirm against the clearml PipelineDecorator documentation
    PipelineDecorator.run_locally()
    pipeline()
|
27
examples/pipeline/decorated_pipeline_step_functions.py
Normal file
27
examples/pipeline/decorated_pipeline_step_functions.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from clearml import PipelineController
|
||||||
|
|
||||||
|
|
||||||
|
def our_decorator(func):
    """Return a wrapper that calls ``func`` and adds 1 to its result."""
    def function_wrapper(*args, **kwargs):
        # forward every argument untouched, then bump the result
        outcome = func(*args, **kwargs)
        return outcome + 1

    return function_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@our_decorator
def step():
    # returns 1; `our_decorator` (applied above) adds 1 to whatever this function returns
    return 1
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(step_return):
    # `step` returns 1 and its `our_decorator` wrapper adds 1, hence the expected 2
    assert step_return == 2
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    pipeline = PipelineController(name="test_decorated", project="test_decorated")
    # first step: the decorated `step` function, its return value is named "step_return"
    pipeline.add_function_step(name="step", function=step, function_return=["step_return"])
    # second step: `evaluate` receives its argument through the '${step.step_return}' reference
    # NOTE(review): presumably resolved by clearml to the first step's output; confirm in the docs
    pipeline.add_function_step(
        name="evaluate",
        function=evaluate,
        function_kwargs=dict(step_return='${step.step_return}')
    )
    pipeline.start_locally(run_pipeline_steps_locally=True)
|
Loading…
Reference in New Issue
Block a user