mirror of
https://github.com/clearml/clearml
synced 2025-01-31 09:07:00 +00:00
Allow custom serialization/deserialization in pipelines
This commit is contained in:
parent
ecd905e518
commit
a9d2abae5b
@ -146,6 +146,8 @@ class PipelineController(object):
|
||||
repo=None, # type: Optional[str]
|
||||
repo_branch=None, # type: Optional[str]
|
||||
repo_commit=None # type: Optional[str]
|
||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]]
|
||||
):
|
||||
# type: (...) -> None
|
||||
"""
|
||||
@ -204,6 +206,28 @@ class PipelineController(object):
|
||||
repo url and commit ID based on the locally cloned copy
|
||||
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
|
||||
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
|
||||
:param artifact_serialization_function: A serialization function that takes one
|
||||
parameter of any type which is the object to be serialized. The function should return
|
||||
a `bytes` or `bytearray` object, which represents the serialized object. All parameter/return
|
||||
artifacts uploaded by the pipeline will be serialized using this function.
|
||||
All relevant imports must be done in this function. For example:
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
def serialize(obj):
|
||||
import dill
|
||||
return dill.dumps(obj)
|
||||
|
||||
:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
|
||||
which represents the serialized object. This function should return the deserialized object.
|
||||
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
|
||||
All relevant imports must be done in this function. For example:
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
def deserialize(bytes_):
|
||||
import dill
|
||||
return dill.loads(bytes_)
|
||||
"""
|
||||
self._nodes = {}
|
||||
self._running_nodes = []
|
||||
@ -236,6 +260,8 @@ class PipelineController(object):
|
||||
self._mock_execution = False # used for nested pipelines (eager execution)
|
||||
self._pipeline_as_sub_project = bool(Session.check_min_api_server_version("2.17"))
|
||||
self._last_progress_update_time = 0
|
||||
self._artifact_serialization_function = artifact_serialization_function
|
||||
self._artifact_deserialization_function = artifact_deserialization_function
|
||||
if not self._task:
|
||||
task_name = name or project or '{}'.format(datetime.now())
|
||||
if self._pipeline_as_sub_project:
|
||||
@ -990,6 +1016,7 @@ class PipelineController(object):
|
||||
auto_pickle=True, # type: bool
|
||||
preview=None, # type: Any
|
||||
wait_on_upload=False, # type: bool
|
||||
serialization_function=None # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
):
|
||||
# type: (...) -> bool
|
||||
"""
|
||||
@ -1032,6 +1059,13 @@ class PipelineController(object):
|
||||
:param bool wait_on_upload: Whether the upload should be synchronous, forcing the upload to complete
|
||||
before continuing.
|
||||
|
||||
:param Callable[[Any], Union[bytes, bytearray]] serialization_function: A serialization function that takes one
|
||||
parameter of any type which is the object to be serialized. The function should return
|
||||
a `bytes` or `bytearray` object, which represents the serialized object. Note that the object will be
|
||||
immediately serialized using this function, thus other serialization methods will not be used
|
||||
(e.g. `pandas.DataFrame.to_csv`), even if possible. To deserialize this artifact when getting
|
||||
it using the `Artifact.get` method, use its `deserialization_function` argument.
|
||||
|
||||
:return: The status of the upload.
|
||||
|
||||
- ``True`` - Upload succeeded.
|
||||
@ -1041,8 +1075,15 @@ class PipelineController(object):
|
||||
"""
|
||||
task = cls._get_pipeline_task()
|
||||
return task.upload_artifact(
|
||||
name=name, artifact_object=artifact_object, metadata=metadata, delete_after_upload=delete_after_upload,
|
||||
auto_pickle=auto_pickle, preview=preview, wait_on_upload=wait_on_upload)
|
||||
name=name,
|
||||
artifact_object=artifact_object,
|
||||
metadata=metadata,
|
||||
delete_after_upload=delete_after_upload,
|
||||
auto_pickle=auto_pickle,
|
||||
preview=preview,
|
||||
wait_on_upload=wait_on_upload,
|
||||
serialization_function=serialization_function
|
||||
)
|
||||
|
||||
def stop(self, timeout=None, mark_failed=False, mark_aborted=False):
|
||||
# type: (Optional[float], bool, bool) -> ()
|
||||
@ -1236,7 +1277,9 @@ class PipelineController(object):
|
||||
output_uri=None,
|
||||
helper_functions=helper_functions,
|
||||
dry_run=True,
|
||||
task_template_header=self._task_template_header
|
||||
task_template_header=self._task_template_header,
|
||||
artifact_serialization_function=self._artifact_serialization_function,
|
||||
artifact_deserialization_function=self._artifact_deserialization_function
|
||||
)
|
||||
return task_definition
|
||||
|
||||
@ -3021,7 +3064,9 @@ class PipelineDecorator(PipelineController):
|
||||
packages=None, # type: Optional[Union[str, Sequence[str]]]
|
||||
repo=None, # type: Optional[str]
|
||||
repo_branch=None, # type: Optional[str]
|
||||
repo_commit=None # type: Optional[str]
|
||||
repo_commit=None, # type: Optional[str]
|
||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]]
|
||||
):
|
||||
# type: (...) -> ()
|
||||
"""
|
||||
@ -3076,6 +3121,28 @@ class PipelineDecorator(PipelineController):
|
||||
repo url and commit ID based on the locally cloned copy
|
||||
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
|
||||
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
|
||||
:param artifact_serialization_function: A serialization function that takes one
|
||||
parameter of any type which is the object to be serialized. The function should return
|
||||
a `bytes` or `bytearray` object, which represents the serialized object. All parameter/return
|
||||
artifacts uploaded by the pipeline will be serialized using this function.
|
||||
All relevant imports must be done in this function. For example:
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
def serialize(obj):
|
||||
import dill
|
||||
return dill.dumps(obj)
|
||||
|
||||
:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
|
||||
which represents the serialized object. This function should return the deserialized object.
|
||||
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
|
||||
All relevant imports must be done in this function. For example:
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
def deserialize(bytes_):
|
||||
import dill
|
||||
return dill.loads(bytes_)
|
||||
"""
|
||||
super(PipelineDecorator, self).__init__(
|
||||
name=name,
|
||||
@ -3093,7 +3160,9 @@ class PipelineDecorator(PipelineController):
|
||||
packages=packages,
|
||||
repo=repo,
|
||||
repo_branch=repo_branch,
|
||||
repo_commit=repo_commit
|
||||
repo_commit=repo_commit,
|
||||
artifact_serialization_function=artifact_serialization_function,
|
||||
artifact_deserialization_function=artifact_deserialization_function
|
||||
)
|
||||
|
||||
# if we are in eager execution, make sure parent class knows it
|
||||
@ -3356,7 +3425,9 @@ class PipelineDecorator(PipelineController):
|
||||
helper_functions=helper_functions,
|
||||
dry_run=True,
|
||||
task_template_header=self._task_template_header,
|
||||
_sanitize_function=sanitize
|
||||
_sanitize_function=sanitize,
|
||||
artifact_serialization_function=self._artifact_serialization_function,
|
||||
artifact_deserialization_function=self._artifact_deserialization_function
|
||||
)
|
||||
return task_definition
|
||||
|
||||
@ -3785,7 +3856,9 @@ class PipelineDecorator(PipelineController):
|
||||
|
||||
task = Task.get_task(_node.job.task_id())
|
||||
if return_name in task.artifacts:
|
||||
return task.artifacts[return_name].get()
|
||||
return task.artifacts[return_name].get(
|
||||
deserialization_function=cls._singleton._artifact_deserialization_function
|
||||
)
|
||||
return task.get_parameters(cast=True)[CreateFromFunction.return_section + "/" + return_name]
|
||||
|
||||
return_w = [LazyEvalWrapper(
|
||||
@ -3828,7 +3901,9 @@ class PipelineDecorator(PipelineController):
|
||||
packages=None, # type: Optional[Union[str, Sequence[str]]]
|
||||
repo=None, # type: Optional[str]
|
||||
repo_branch=None, # type: Optional[str]
|
||||
repo_commit=None # type: Optional[str]
|
||||
repo_commit=None, # type: Optional[str]
|
||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]]
|
||||
):
|
||||
# type: (...) -> Callable
|
||||
"""
|
||||
@ -3912,6 +3987,28 @@ class PipelineDecorator(PipelineController):
|
||||
repo url and commit ID based on the locally cloned copy
|
||||
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
|
||||
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
|
||||
:param artifact_serialization_function: A serialization function that takes one
|
||||
parameter of any type which is the object to be serialized. The function should return
|
||||
a `bytes` or `bytearray` object, which represents the serialized object. All parameter/return
|
||||
artifacts uploaded by the pipeline will be serialized using this function.
|
||||
All relevant imports must be done in this function. For example:
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
def serialize(obj):
|
||||
import dill
|
||||
return dill.dumps(obj)
|
||||
|
||||
:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
|
||||
which represents the serialized object. This function should return the deserialized object.
|
||||
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
|
||||
All relevant imports must be done in this function. For example:
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
def deserialize(bytes_):
|
||||
import dill
|
||||
return dill.loads(bytes_)
|
||||
"""
|
||||
def decorator_wrap(func):
|
||||
|
||||
@ -3955,7 +4052,9 @@ class PipelineDecorator(PipelineController):
|
||||
packages=packages,
|
||||
repo=repo,
|
||||
repo_branch=repo_branch,
|
||||
repo_commit=repo_commit
|
||||
repo_commit=repo_commit,
|
||||
artifact_serialization_function=artifact_serialization_function,
|
||||
artifact_deserialization_function=artifact_deserialization_function
|
||||
)
|
||||
ret_val = func(**pipeline_kwargs)
|
||||
LazyEvalWrapper.trigger_all_remote_references()
|
||||
@ -4004,7 +4103,9 @@ class PipelineDecorator(PipelineController):
|
||||
packages=packages,
|
||||
repo=repo,
|
||||
repo_branch=repo_branch,
|
||||
repo_commit=repo_commit
|
||||
repo_commit=repo_commit,
|
||||
artifact_serialization_function=artifact_serialization_function,
|
||||
artifact_deserialization_function=artifact_deserialization_function
|
||||
)
|
||||
|
||||
a_pipeline._args_map = args_map or {}
|
||||
|
@ -481,6 +481,9 @@ from clearml.automation.controller import PipelineDecorator
|
||||
task_template = """{header}
|
||||
from clearml.utilities.proxy_object import get_basic_type
|
||||
|
||||
{artifact_serialization_function_source}
|
||||
|
||||
{artifact_deserialization_function_source}
|
||||
|
||||
{function_source}
|
||||
|
||||
@ -501,7 +504,7 @@ if __name__ == '__main__':
|
||||
task_id, artifact_name = v.split('.', 1)
|
||||
parent_task = Task.get_task(task_id=task_id)
|
||||
if artifact_name in parent_task.artifacts:
|
||||
kwargs[k] = parent_task.artifacts[artifact_name].get()
|
||||
kwargs[k] = parent_task.artifacts[artifact_name].get(deserialization_function={artifact_deserialization_function_name})
|
||||
else:
|
||||
kwargs[k] = parent_task.get_parameters(cast=True)[return_section + '/' + artifact_name]
|
||||
results = {function_name}(**kwargs)
|
||||
@ -519,7 +522,8 @@ if __name__ == '__main__':
|
||||
task.upload_artifact(
|
||||
name=name,
|
||||
artifact_object=artifact,
|
||||
extension_name='.pkl' if isinstance(artifact, dict) else None
|
||||
extension_name='.pkl' if isinstance(artifact, dict) else None,
|
||||
serialization_function={artifact_serialization_function_name}
|
||||
)
|
||||
if parameters:
|
||||
task._set_parameters(parameters, __parameters_types=parameters_types, __update=True)
|
||||
@ -548,6 +552,8 @@ if __name__ == '__main__':
|
||||
helper_functions=None, # type: Optional[Sequence[Callable]]
|
||||
dry_run=False, # type: bool
|
||||
task_template_header=None, # type: Optional[str]
|
||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||
_sanitize_function=None, # type: Optional[Callable[[str], str]]
|
||||
_sanitize_helper_functions=None, # type: Optional[Callable[[str], str]]
|
||||
):
|
||||
@ -609,6 +615,27 @@ if __name__ == '__main__':
|
||||
for the standalone function Task.
|
||||
:param dry_run: If True, do not create the Task, but return a dict of the Task's definitions
|
||||
:param task_template_header: A string placed at the top of the task's code
|
||||
:param artifact_serialization_function: A serialization function that takes one
|
||||
parameter of any type which is the object to be serialized. The function should return
|
||||
a `bytes` or `bytearray` object, which represents the serialized object. All parameter/return
|
||||
artifacts uploaded by the pipeline will be serialized using this function.
|
||||
All relevant imports must be done in this function. For example:
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
def serialize(obj):
|
||||
import dill
|
||||
return dill.dumps(obj)
|
||||
:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
|
||||
which represents the serialized object. This function should return the deserialized object.
|
||||
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
|
||||
All relevant imports must be done in this function. For example:
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
def deserialize(bytes_):
|
||||
import dill
|
||||
return dill.loads(bytes_)
|
||||
:param _sanitize_function: Sanitization function for the function string.
|
||||
:param _sanitize_helper_functions: Sanitization function for the helper function string.
|
||||
:return: Newly created Task object
|
||||
@ -622,18 +649,28 @@ if __name__ == '__main__':
|
||||
assert (not auto_connect_frameworks or isinstance(auto_connect_frameworks, (bool, dict)))
|
||||
assert (not auto_connect_arg_parser or isinstance(auto_connect_arg_parser, (bool, dict)))
|
||||
|
||||
function_name = str(a_function.__name__)
|
||||
function_source = inspect.getsource(a_function)
|
||||
if _sanitize_function:
|
||||
function_source = _sanitize_function(function_source)
|
||||
function_source = cls.__sanitize_remove_type_hints(function_source)
|
||||
function_source, function_name = CreateFromFunction.__extract_function_information(
|
||||
a_function, sanitize_function=_sanitize_function
|
||||
)
|
||||
|
||||
# add helper functions on top.
|
||||
for f in (helper_functions or []):
|
||||
f_source = inspect.getsource(f)
|
||||
if _sanitize_helper_functions:
|
||||
f_source = _sanitize_helper_functions(f_source)
|
||||
function_source = cls.__sanitize_remove_type_hints(f_source) + '\n\n' + function_source
|
||||
helper_function_source, _ = CreateFromFunction.__extract_function_information(
|
||||
f, sanitize_function=_sanitize_helper_functions
|
||||
)
|
||||
function_source = helper_function_source + "\n\n" + function_source
|
||||
|
||||
artifact_serialization_function_source, artifact_serialization_function_name = (
|
||||
CreateFromFunction.__extract_function_information(artifact_serialization_function)
|
||||
if artifact_serialization_function
|
||||
else ("", "None")
|
||||
)
|
||||
|
||||
artifact_deserialization_function_source, artifact_deserialization_function_name = (
|
||||
CreateFromFunction.__extract_function_information(artifact_deserialization_function)
|
||||
if artifact_deserialization_function
|
||||
else ("", "None")
|
||||
)
|
||||
|
||||
function_input_artifacts = function_input_artifacts or dict()
|
||||
# verify artifact kwargs:
|
||||
@ -686,6 +723,10 @@ if __name__ == '__main__':
|
||||
function_name=function_name,
|
||||
function_return=function_return,
|
||||
return_section=cls.return_section,
|
||||
artifact_serialization_function_source=artifact_serialization_function_source,
|
||||
artifact_serialization_function_name=artifact_serialization_function_name,
|
||||
artifact_deserialization_function_source=artifact_deserialization_function_source,
|
||||
artifact_deserialization_function_name=artifact_deserialization_function_name
|
||||
)
|
||||
|
||||
temp_dir = repo if repo and os.path.isdir(repo) else None
|
||||
@ -784,3 +825,15 @@ if __name__ == '__main__':
|
||||
except Exception:
|
||||
# just in case we failed parsing.
|
||||
return function_source
|
||||
|
||||
@staticmethod
def __extract_function_information(function, sanitize_function=None):
    # type: (Callable, Optional[Callable]) -> (str, str)
    """
    Retrieve the source code and the name of a function.

    The name is read first (``function.__name__`` converted to ``str``), then the
    source is fetched with ``inspect.getsource``. The source is optionally passed
    through ``sanitize_function`` and then always passed through
    ``CreateFromFunction.__sanitize_remove_type_hints``.

    :param function: The callable whose source code and name are extracted
    :param sanitize_function: Optional callable that receives the raw source string
        and returns the string to use instead. Skipped when falsy.
    :return: A tuple of (processed function source, function name)
    """
    # resolve the name before touching the source, same order as the lookups below rely on
    name = str(function.__name__)

    source = inspect.getsource(function)
    # caller-provided clean-up runs before the type-hint sanitization step
    source = sanitize_function(source) if sanitize_function else source

    return CreateFromFunction.__sanitize_remove_type_hints(source), name
|
Loading…
Reference in New Issue
Block a user