diff --git a/clearml/automation/controller.py b/clearml/automation/controller.py index cc7864ca..37f75f4e 100644 --- a/clearml/automation/controller.py +++ b/clearml/automation/controller.py @@ -146,6 +146,8 @@ class PipelineController(object): repo=None, # type: Optional[str] repo_branch=None, # type: Optional[str] - repo_commit=None # type: Optional[str] + repo_commit=None, # type: Optional[str] + artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]] + artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]] ): # type: (...) -> None """ @@ -204,6 +206,28 @@ class PipelineController(object): repo url and commit ID based on the locally cloned copy :param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used) :param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used) + :param artifact_serialization_function: A serialization function that takes one + parameter of any type which is the object to be serialized. The function should return + a `bytes` or `bytearray` object, which represents the serialized object. All parameter/return + artifacts uploaded by the pipeline will be serialized using this function. + All relevant imports must be done in this function. For example: + + .. code-block:: py + + def serialize(obj): + import dill + return dill.dumps(obj) + + :param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`, + which represents the serialized object. This function should return the deserialized object. + All parameter/return artifacts fetched by the pipeline will be deserialized using this function. + All relevant imports must be done in this function. For example: + + .. 
code-block:: py + + def deserialize(bytes_): + import dill + return dill.loads(bytes_) """ self._nodes = {} self._running_nodes = [] @@ -236,6 +260,8 @@ class PipelineController(object): self._mock_execution = False # used for nested pipelines (eager execution) self._pipeline_as_sub_project = bool(Session.check_min_api_server_version("2.17")) self._last_progress_update_time = 0 + self._artifact_serialization_function = artifact_serialization_function + self._artifact_deserialization_function = artifact_deserialization_function if not self._task: task_name = name or project or '{}'.format(datetime.now()) if self._pipeline_as_sub_project: @@ -990,6 +1016,7 @@ class PipelineController(object): auto_pickle=True, # type: bool preview=None, # type: Any wait_on_upload=False, # type: bool + serialization_function=None # type: Optional[Callable[[Any], Union[bytes, bytearray]]] ): # type: (...) -> bool """ @@ -1032,6 +1059,13 @@ class PipelineController(object): :param bool wait_on_upload: Whether the upload should be synchronous, forcing the upload to complete before continuing. + :param Callable[Any, Union[bytes, bytearray]] serialization_function: A serialization function that takes one + parameter of any type which is the object to be serialized. The function should return + a `bytes` or `bytearray` object, which represents the serialized object. Note that the object will be + immediately serialized using this function, thus other serialization methods will not be used + (e.g. `pandas.DataFrame.to_csv`), even if possible. To deserialize this artifact when getting + it using the `Artifact.get` method, use its `deserialization_function` argument. + :return: The status of the upload. - ``True`` - Upload succeeded. 
@@ -1041,8 +1075,15 @@ class PipelineController(object): """ task = cls._get_pipeline_task() return task.upload_artifact( - name=name, artifact_object=artifact_object, metadata=metadata, delete_after_upload=delete_after_upload, - auto_pickle=auto_pickle, preview=preview, wait_on_upload=wait_on_upload) + name=name, + artifact_object=artifact_object, + metadata=metadata, + delete_after_upload=delete_after_upload, + auto_pickle=auto_pickle, + preview=preview, + wait_on_upload=wait_on_upload, + serialization_function=serialization_function + ) def stop(self, timeout=None, mark_failed=False, mark_aborted=False): # type: (Optional[float], bool, bool) -> () @@ -1236,7 +1277,9 @@ class PipelineController(object): output_uri=None, helper_functions=helper_functions, dry_run=True, - task_template_header=self._task_template_header + task_template_header=self._task_template_header, + artifact_serialization_function=self._artifact_serialization_function, + artifact_deserialization_function=self._artifact_deserialization_function ) return task_definition @@ -3021,7 +3064,9 @@ class PipelineDecorator(PipelineController): packages=None, # type: Optional[Union[str, Sequence[str]]] repo=None, # type: Optional[str] repo_branch=None, # type: Optional[str] - repo_commit=None # type: Optional[str] + repo_commit=None, # type: Optional[str] + artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]] + artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]] ): # type: (...) 
-> () """ @@ -3076,6 +3121,28 @@ class PipelineDecorator(PipelineController): repo url and commit ID based on the locally cloned copy :param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used) :param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used) + :param artifact_serialization_function: A serialization function that takes one + parameter of any type which is the object to be serialized. The function should return + a `bytes` or `bytearray` object, which represents the serialized object. All parameter/return + artifacts uploaded by the pipeline will be serialized using this function. + All relevant imports must be done in this function. For example: + + .. code-block:: py + + def serialize(obj): + import dill + return dill.dumps(obj) + + :param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`, + which represents the serialized object. This function should return the deserialized object. + All parameter/return artifacts fetched by the pipeline will be deserialized using this function. + All relevant imports must be done in this function. For example: + + .. 
code-block:: py + + def deserialize(bytes_): + import dill + return dill.loads(bytes_) """ super(PipelineDecorator, self).__init__( name=name, @@ -3093,7 +3160,9 @@ class PipelineDecorator(PipelineController): packages=packages, repo=repo, repo_branch=repo_branch, - repo_commit=repo_commit + repo_commit=repo_commit, + artifact_serialization_function=artifact_serialization_function, + artifact_deserialization_function=artifact_deserialization_function ) # if we are in eager execution, make sure parent class knows it @@ -3356,7 +3425,9 @@ class PipelineDecorator(PipelineController): helper_functions=helper_functions, dry_run=True, task_template_header=self._task_template_header, - _sanitize_function=sanitize + _sanitize_function=sanitize, + artifact_serialization_function=self._artifact_serialization_function, + artifact_deserialization_function=self._artifact_deserialization_function ) return task_definition @@ -3785,7 +3856,9 @@ class PipelineDecorator(PipelineController): task = Task.get_task(_node.job.task_id()) if return_name in task.artifacts: - return task.artifacts[return_name].get() + return task.artifacts[return_name].get( + deserialization_function=cls._singleton._artifact_deserialization_function + ) return task.get_parameters(cast=True)[CreateFromFunction.return_section + "/" + return_name] return_w = [LazyEvalWrapper( @@ -3828,7 +3901,9 @@ class PipelineDecorator(PipelineController): packages=None, # type: Optional[Union[str, Sequence[str]]] repo=None, # type: Optional[str] repo_branch=None, # type: Optional[str] - repo_commit=None # type: Optional[str] + repo_commit=None, # type: Optional[str] + artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]] + artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]] ): # type: (...) 
-> Callable """ @@ -3912,6 +3987,28 @@ class PipelineDecorator(PipelineController): repo url and commit ID based on the locally cloned copy :param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used) :param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used) + :param artifact_serialization_function: A serialization function that takes one + parameter of any type which is the object to be serialized. The function should return + a `bytes` or `bytearray` object, which represents the serialized object. All parameter/return + artifacts uploaded by the pipeline will be serialized using this function. + All relevant imports must be done in this function. For example: + + .. code-block:: py + + def serialize(obj): + import dill + return dill.dumps(obj) + + :param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`, + which represents the serialized object. This function should return the deserialized object. + All parameter/return artifacts fetched by the pipeline will be deserialized using this function. + All relevant imports must be done in this function. For example: + + .. 
code-block:: py + + def deserialize(bytes_): + import dill + return dill.loads(bytes_) """ def decorator_wrap(func): @@ -3955,7 +4052,9 @@ class PipelineDecorator(PipelineController): packages=packages, repo=repo, repo_branch=repo_branch, - repo_commit=repo_commit + repo_commit=repo_commit, + artifact_serialization_function=artifact_serialization_function, + artifact_deserialization_function=artifact_deserialization_function ) ret_val = func(**pipeline_kwargs) LazyEvalWrapper.trigger_all_remote_references() @@ -4004,7 +4103,9 @@ class PipelineDecorator(PipelineController): packages=packages, repo=repo, repo_branch=repo_branch, - repo_commit=repo_commit + repo_commit=repo_commit, + artifact_serialization_function=artifact_serialization_function, + artifact_deserialization_function=artifact_deserialization_function ) a_pipeline._args_map = args_map or {} diff --git a/clearml/backend_interface/task/populate.py b/clearml/backend_interface/task/populate.py index d21055d4..d45c591b 100644 --- a/clearml/backend_interface/task/populate.py +++ b/clearml/backend_interface/task/populate.py @@ -481,6 +481,9 @@ from clearml.automation.controller import PipelineDecorator task_template = """{header} from clearml.utilities.proxy_object import get_basic_type +{artifact_serialization_function_source} + +{artifact_deserialization_function_source} {function_source} @@ -501,7 +504,7 @@ if __name__ == '__main__': task_id, artifact_name = v.split('.', 1) parent_task = Task.get_task(task_id=task_id) if artifact_name in parent_task.artifacts: - kwargs[k] = parent_task.artifacts[artifact_name].get() + kwargs[k] = parent_task.artifacts[artifact_name].get(deserialization_function={artifact_deserialization_function_name}) else: kwargs[k] = parent_task.get_parameters(cast=True)[return_section + '/' + artifact_name] results = {function_name}(**kwargs) @@ -519,7 +522,8 @@ if __name__ == '__main__': task.upload_artifact( name=name, artifact_object=artifact, - extension_name='.pkl' if 
isinstance(artifact, dict) else None + extension_name='.pkl' if isinstance(artifact, dict) else None, + serialization_function={artifact_serialization_function_name} ) if parameters: task._set_parameters(parameters, __parameters_types=parameters_types, __update=True) @@ -548,6 +552,8 @@ if __name__ == '__main__': helper_functions=None, # type: Optional[Sequence[Callable]] dry_run=False, # type: bool task_template_header=None, # type: Optional[str] + artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]] + artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]] _sanitize_function=None, # type: Optional[Callable[[str], str]] _sanitize_helper_functions=None, # type: Optional[Callable[[str], str]] ): @@ -609,6 +615,27 @@ if __name__ == '__main__': for the standalone function Task. :param dry_run: If True, do not create the Task, but return a dict of the Task's definitions :param task_template_header: A string placed at the top of the task's code + :param artifact_serialization_function: A serialization function that takes one + parameter of any type which is the object to be serialized. The function should return + a `bytes` or `bytearray` object, which represents the serialized object. All parameter/return + artifacts uploaded by the pipeline will be serialized using this function. + All relevant imports must be done in this function. For example: + + .. code-block:: py + + def serialize(obj): + import dill + return dill.dumps(obj) + :param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`, + which represents the serialized object. This function should return the deserialized object. + All parameter/return artifacts fetched by the pipeline will be deserialized using this function. + All relevant imports must be done in this function. For example: + + .. 
code-block:: py + + def deserialize(bytes_): + import dill + return dill.loads(bytes_) :param _sanitize_function: Sanitization function for the function string. :param _sanitize_helper_functions: Sanitization function for the helper function string. :return: Newly created Task object @@ -622,18 +649,28 @@ if __name__ == '__main__': assert (not auto_connect_frameworks or isinstance(auto_connect_frameworks, (bool, dict))) assert (not auto_connect_arg_parser or isinstance(auto_connect_arg_parser, (bool, dict))) - function_name = str(a_function.__name__) - function_source = inspect.getsource(a_function) - if _sanitize_function: - function_source = _sanitize_function(function_source) - function_source = cls.__sanitize_remove_type_hints(function_source) + function_source, function_name = CreateFromFunction.__extract_function_information( + a_function, sanitize_function=_sanitize_function + ) # add helper functions on top. for f in (helper_functions or []): - f_source = inspect.getsource(f) - if _sanitize_helper_functions: - f_source = _sanitize_helper_functions(f_source) - function_source = cls.__sanitize_remove_type_hints(f_source) + '\n\n' + function_source + helper_function_source, _ = CreateFromFunction.__extract_function_information( + f, sanitize_function=_sanitize_helper_functions + ) + function_source = helper_function_source + "\n\n" + function_source + + artifact_serialization_function_source, artifact_serialization_function_name = ( + CreateFromFunction.__extract_function_information(artifact_serialization_function) + if artifact_serialization_function + else ("", "None") + ) + + artifact_deserialization_function_source, artifact_deserialization_function_name = ( + CreateFromFunction.__extract_function_information(artifact_deserialization_function) + if artifact_deserialization_function + else ("", "None") + ) function_input_artifacts = function_input_artifacts or dict() # verify artifact kwargs: @@ -686,6 +723,10 @@ if __name__ == '__main__': 
function_name=function_name, function_return=function_return, return_section=cls.return_section, + artifact_serialization_function_source=artifact_serialization_function_source, + artifact_serialization_function_name=artifact_serialization_function_name, + artifact_deserialization_function_source=artifact_deserialization_function_source, + artifact_deserialization_function_name=artifact_deserialization_function_name ) temp_dir = repo if repo and os.path.isdir(repo) else None @@ -784,3 +825,15 @@ if __name__ == '__main__': except Exception: # just in case we failed parsing. return function_source + + @staticmethod + def __extract_function_information(function, sanitize_function=None): + # type: (Callable, Optional[Callable]) -> (str, str) + function_name = str(function.__name__) + function_source = inspect.getsource(function) + if sanitize_function: + function_source = sanitize_function(function_source) + + function_source = CreateFromFunction.__sanitize_remove_type_hints(function_source) + + return function_source, function_name \ No newline at end of file