Allow custom serialization/deserialization in pipelines

Alex Burlacu 2023-03-23 17:38:09 +02:00
parent ecd905e518
commit a9d2abae5b
2 changed files with 175 additions and 21 deletions
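For orientation, a hedged usage sketch of what this change enables: the two new constructor arguments accept a custom serializer/deserializer pair that the pipeline uses for all parameter/return artifacts. The dill-based helpers and the pipeline name/project below are illustrative, not taken from this commit.

.. code-block:: py

    from clearml.automation.controller import PipelineController

    def serialize(obj):
        # keep all imports inside the function: only its source is
        # embedded into the generated step scripts
        import dill
        return dill.dumps(obj)

    def deserialize(bytes_):
        import dill
        return dill.loads(bytes_)

    pipe = PipelineController(
        name="custom-serialization-demo",  # illustrative name
        project="examples",                # illustrative project
        artifact_serialization_function=serialize,
        artifact_deserialization_function=deserialize,
    )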

@@ -146,6 +146,8 @@ class PipelineController(object):
repo=None, # type: Optional[str]
repo_branch=None, # type: Optional[str]
repo_commit=None, # type: Optional[str]
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]]
):
# type: (...) -> None
"""
@@ -204,6 +206,28 @@ class PipelineController(object):
repo url and commit ID based on the locally cloned copy
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
:param artifact_serialization_function: A serialization function that takes one
parameter of any type, which is the object to be serialized. The function should return
a `bytes` or `bytearray` object, which represents the serialized object. All parameter/return
artifacts uploaded by the pipeline will be serialized using this function.
All relevant imports must be done in this function. For example:

.. code-block:: py

    def serialize(obj):
        import dill
        return dill.dumps(obj)

:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
which represents the serialized object. This function should return the deserialized object.
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
All relevant imports must be done in this function. For example:

.. code-block:: py

    def deserialize(bytes_):
        import dill
        return dill.loads(bytes_)
"""
self._nodes = {}
self._running_nodes = []
@@ -236,6 +260,8 @@ class PipelineController(object):
self._mock_execution = False # used for nested pipelines (eager execution)
self._pipeline_as_sub_project = bool(Session.check_min_api_server_version("2.17"))
self._last_progress_update_time = 0
self._artifact_serialization_function = artifact_serialization_function
self._artifact_deserialization_function = artifact_deserialization_function
if not self._task:
task_name = name or project or '{}'.format(datetime.now())
if self._pipeline_as_sub_project:
@@ -990,6 +1016,7 @@ class PipelineController(object):
auto_pickle=True, # type: bool
preview=None, # type: Any
wait_on_upload=False, # type: bool
serialization_function=None # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
):
# type: (...) -> bool
"""
@@ -1032,6 +1059,13 @@ class PipelineController(object):
:param bool wait_on_upload: Whether the upload should be synchronous, forcing the upload to complete
before continuing.
:param Callable[[Any], Union[bytes, bytearray]] serialization_function: A serialization function that takes one
parameter of any type, which is the object to be serialized. The function should return
a `bytes` or `bytearray` object, which represents the serialized object. Note that the object is
serialized immediately using this function, so other serialization methods
(e.g. `pandas.DataFrame.to_csv`) will not be used, even when applicable. To deserialize this artifact
when getting it using the `Artifact.get` method, use its `deserialization_function` argument.
:return: The status of the upload.
- ``True`` - Upload succeeded.
@@ -1041,8 +1075,15 @@ class PipelineController(object):
"""
task = cls._get_pipeline_task()
return task.upload_artifact(
name=name, artifact_object=artifact_object, metadata=metadata, delete_after_upload=delete_after_upload,
auto_pickle=auto_pickle, preview=preview, wait_on_upload=wait_on_upload)
name=name,
artifact_object=artifact_object,
metadata=metadata,
delete_after_upload=delete_after_upload,
auto_pickle=auto_pickle,
preview=preview,
wait_on_upload=wait_on_upload,
serialization_function=serialization_function
)
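A hedged sketch of calling the extended classmethod above from inside a running pipeline step; the artifact name, the uploaded object, and the dill-based helper are illustrative, not part of this commit.

.. code-block:: py

    from clearml.automation.controller import PipelineController

    def serialize(obj):
        import dill
        return dill.dumps(obj)

    # inside a running pipeline step, upload with the custom serializer
    # instead of the default pickle/CSV handling
    PipelineController.upload_artifact(
        name="model_state",                       # illustrative artifact name
        artifact_object={"weights": [0.1, 0.2]},  # illustrative object
        serialization_function=serialize,
    )

    # when fetching it later, pass the matching deserializer to Artifact.get(),
    # as the docstring above suggests:
    #   obj = task.artifacts["model_state"].get(deserialization_function=deserialize)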
def stop(self, timeout=None, mark_failed=False, mark_aborted=False):
# type: (Optional[float], bool, bool) -> ()
@@ -1236,7 +1277,9 @@ class PipelineController(object):
output_uri=None,
helper_functions=helper_functions,
dry_run=True,
task_template_header=self._task_template_header
task_template_header=self._task_template_header,
artifact_serialization_function=self._artifact_serialization_function,
artifact_deserialization_function=self._artifact_deserialization_function
)
return task_definition
@@ -3021,7 +3064,9 @@ class PipelineDecorator(PipelineController):
packages=None, # type: Optional[Union[str, Sequence[str]]]
repo=None, # type: Optional[str]
repo_branch=None, # type: Optional[str]
repo_commit=None # type: Optional[str]
repo_commit=None, # type: Optional[str]
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]]
):
# type: (...) -> ()
"""
@@ -3076,6 +3121,28 @@ class PipelineDecorator(PipelineController):
repo url and commit ID based on the locally cloned copy
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
:param artifact_serialization_function: A serialization function that takes one
parameter of any type, which is the object to be serialized. The function should return
a `bytes` or `bytearray` object, which represents the serialized object. All parameter/return
artifacts uploaded by the pipeline will be serialized using this function.
All relevant imports must be done in this function. For example:

.. code-block:: py

    def serialize(obj):
        import dill
        return dill.dumps(obj)

:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
which represents the serialized object. This function should return the deserialized object.
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
All relevant imports must be done in this function. For example:

.. code-block:: py

    def deserialize(bytes_):
        import dill
        return dill.loads(bytes_)
"""
super(PipelineDecorator, self).__init__(
name=name,
@@ -3093,7 +3160,9 @@ class PipelineDecorator(PipelineController):
packages=packages,
repo=repo,
repo_branch=repo_branch,
repo_commit=repo_commit
repo_commit=repo_commit,
artifact_serialization_function=artifact_serialization_function,
artifact_deserialization_function=artifact_deserialization_function
)
# if we are in eager execution, make sure parent class knows it
@@ -3356,7 +3425,9 @@ class PipelineDecorator(PipelineController):
helper_functions=helper_functions,
dry_run=True,
task_template_header=self._task_template_header,
_sanitize_function=sanitize
_sanitize_function=sanitize,
artifact_serialization_function=self._artifact_serialization_function,
artifact_deserialization_function=self._artifact_deserialization_function
)
return task_definition
@@ -3785,7 +3856,9 @@ class PipelineDecorator(PipelineController):
task = Task.get_task(_node.job.task_id())
if return_name in task.artifacts:
return task.artifacts[return_name].get()
return task.artifacts[return_name].get(
deserialization_function=cls._singleton._artifact_deserialization_function
)
return task.get_parameters(cast=True)[CreateFromFunction.return_section + "/" + return_name]
return_w = [LazyEvalWrapper(
@@ -3828,7 +3901,9 @@ class PipelineDecorator(PipelineController):
packages=None, # type: Optional[Union[str, Sequence[str]]]
repo=None, # type: Optional[str]
repo_branch=None, # type: Optional[str]
repo_commit=None # type: Optional[str]
repo_commit=None, # type: Optional[str]
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]]
):
# type: (...) -> Callable
"""
@@ -3912,6 +3987,28 @@ class PipelineDecorator(PipelineController):
repo url and commit ID based on the locally cloned copy
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
:param artifact_serialization_function: A serialization function that takes one
parameter of any type, which is the object to be serialized. The function should return
a `bytes` or `bytearray` object, which represents the serialized object. All parameter/return
artifacts uploaded by the pipeline will be serialized using this function.
All relevant imports must be done in this function. For example:

.. code-block:: py

    def serialize(obj):
        import dill
        return dill.dumps(obj)

:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
which represents the serialized object. This function should return the deserialized object.
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
All relevant imports must be done in this function. For example:

.. code-block:: py

    def deserialize(bytes_):
        import dill
        return dill.loads(bytes_)
"""
def decorator_wrap(func):
@@ -3955,7 +4052,9 @@ class PipelineDecorator(PipelineController):
packages=packages,
repo=repo,
repo_branch=repo_branch,
repo_commit=repo_commit
repo_commit=repo_commit,
artifact_serialization_function=artifact_serialization_function,
artifact_deserialization_function=artifact_deserialization_function
)
ret_val = func(**pipeline_kwargs)
LazyEvalWrapper.trigger_all_remote_references()
@@ -4004,7 +4103,9 @@ class PipelineDecorator(PipelineController):
packages=packages,
repo=repo,
repo_branch=repo_branch,
repo_commit=repo_commit
repo_commit=repo_commit,
artifact_serialization_function=artifact_serialization_function,
artifact_deserialization_function=artifact_deserialization_function
)
a_pipeline._args_map = args_map or {}

@@ -481,6 +481,9 @@ from clearml.automation.controller import PipelineDecorator
task_template = """{header}
from clearml.utilities.proxy_object import get_basic_type
{artifact_serialization_function_source}
{artifact_deserialization_function_source}
{function_source}
@@ -501,7 +504,7 @@ if __name__ == '__main__':
task_id, artifact_name = v.split('.', 1)
parent_task = Task.get_task(task_id=task_id)
if artifact_name in parent_task.artifacts:
kwargs[k] = parent_task.artifacts[artifact_name].get()
kwargs[k] = parent_task.artifacts[artifact_name].get(deserialization_function={artifact_deserialization_function_name})
else:
kwargs[k] = parent_task.get_parameters(cast=True)[return_section + '/' + artifact_name]
results = {function_name}(**kwargs)
@@ -519,7 +522,8 @@ if __name__ == '__main__':
task.upload_artifact(
name=name,
artifact_object=artifact,
extension_name='.pkl' if isinstance(artifact, dict) else None
extension_name='.pkl' if isinstance(artifact, dict) else None,
serialization_function={artifact_serialization_function_name}
)
if parameters:
task._set_parameters(parameters, __parameters_types=parameters_types, __update=True)
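The generated script above fetches step inputs with `deserialization_function={artifact_deserialization_function_name}` and uploads step outputs with `serialization_function={artifact_serialization_function_name}`, so the two user-supplied functions must round-trip cleanly. A minimal, self-contained check of that contract, assuming the dill package used in the docstring examples:

.. code-block:: py

    import dill

    def serialize(obj):
        # must return bytes/bytearray
        return dill.dumps(obj)

    def deserialize(bytes_):
        # must invert serialize()
        return dill.loads(bytes_)

    payload = {"weights": [0.1, 0.2], "epoch": 3}
    blob = serialize(payload)
    assert isinstance(blob, (bytes, bytearray))
    assert deserialize(blob) == payload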
@@ -548,6 +552,8 @@ if __name__ == '__main__':
helper_functions=None, # type: Optional[Sequence[Callable]]
dry_run=False, # type: bool
task_template_header=None, # type: Optional[str]
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
_sanitize_function=None, # type: Optional[Callable[[str], str]]
_sanitize_helper_functions=None, # type: Optional[Callable[[str], str]]
):
@@ -609,6 +615,27 @@ if __name__ == '__main__':
for the standalone function Task.
:param dry_run: If True, do not create the Task, but return a dict of the Task's definitions
:param task_template_header: A string placed at the top of the task's code
:param artifact_serialization_function: A serialization function that takes one
parameter of any type, which is the object to be serialized. The function should return
a `bytes` or `bytearray` object, which represents the serialized object. All parameter/return
artifacts uploaded by the pipeline will be serialized using this function.
All relevant imports must be done in this function. For example:

.. code-block:: py

    def serialize(obj):
        import dill
        return dill.dumps(obj)

:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
which represents the serialized object. This function should return the deserialized object.
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
All relevant imports must be done in this function. For example:

.. code-block:: py

    def deserialize(bytes_):
        import dill
        return dill.loads(bytes_)

:param _sanitize_function: Sanitization function for the function string.
:param _sanitize_helper_functions: Sanitization function for the helper function string.
:return: Newly created Task object
@@ -622,18 +649,28 @@ if __name__ == '__main__':
assert (not auto_connect_frameworks or isinstance(auto_connect_frameworks, (bool, dict)))
assert (not auto_connect_arg_parser or isinstance(auto_connect_arg_parser, (bool, dict)))
function_name = str(a_function.__name__)
function_source = inspect.getsource(a_function)
if _sanitize_function:
function_source = _sanitize_function(function_source)
function_source = cls.__sanitize_remove_type_hints(function_source)
function_source, function_name = CreateFromFunction.__extract_function_information(
a_function, sanitize_function=_sanitize_function
)
# add helper functions on top.
for f in (helper_functions or []):
f_source = inspect.getsource(f)
if _sanitize_helper_functions:
f_source = _sanitize_helper_functions(f_source)
function_source = cls.__sanitize_remove_type_hints(f_source) + '\n\n' + function_source
helper_function_source, _ = CreateFromFunction.__extract_function_information(
f, sanitize_function=_sanitize_helper_functions
)
function_source = helper_function_source + "\n\n" + function_source
artifact_serialization_function_source, artifact_serialization_function_name = (
CreateFromFunction.__extract_function_information(artifact_serialization_function)
if artifact_serialization_function
else ("", "None")
)
artifact_deserialization_function_source, artifact_deserialization_function_name = (
CreateFromFunction.__extract_function_information(artifact_deserialization_function)
if artifact_deserialization_function
else ("", "None")
)
function_input_artifacts = function_input_artifacts or dict()
# verify artifact kwargs:
@@ -686,6 +723,10 @@ if __name__ == '__main__':
function_name=function_name,
function_return=function_return,
return_section=cls.return_section,
artifact_serialization_function_source=artifact_serialization_function_source,
artifact_serialization_function_name=artifact_serialization_function_name,
artifact_deserialization_function_source=artifact_deserialization_function_source,
artifact_deserialization_function_name=artifact_deserialization_function_name
)
temp_dir = repo if repo and os.path.isdir(repo) else None
@@ -784,3 +825,15 @@ if __name__ == '__main__':
except Exception:
# just in case we failed parsing.
return function_source
@staticmethod
def __extract_function_information(function, sanitize_function=None):
# type: (Callable, Optional[Callable]) -> (str, str)
function_name = str(function.__name__)
function_source = inspect.getsource(function)
if sanitize_function:
function_source = sanitize_function(function_source)
function_source = CreateFromFunction.__sanitize_remove_type_hints(function_source)
return function_source, function_name