Refactor CreateFromFunction

This commit is contained in:
allegroai 2023-01-23 14:04:13 +02:00
parent 8cb4ac2acb
commit 2120fc85f5
2 changed files with 13 additions and 4 deletions

View File

@ -63,6 +63,7 @@ class PipelineController(object):
_retries = {} # Node.name: int
_retries_callbacks = {} # Node.name: Callable[[PipelineController, PipelineController.Node, int], bool] # noqa
_final_failure = {} # Node.name: bool
_task_template_header = CreateFromFunction.default_task_template_header
valid_job_status = ["failed", "cached", "completed", "aborted", "queued", "running", "skipped", "pending"]
@ -1170,6 +1171,7 @@ class PipelineController(object):
output_uri=None,
helper_functions=helper_functions,
dry_run=True,
task_template_header=self._task_template_header
)
return task_definition
@ -3199,7 +3201,7 @@ class PipelineDecorator(PipelineController):
function, function_input_artifacts, function_kwargs, function_return,
auto_connect_frameworks, auto_connect_arg_parser,
packages, project_name, task_name, task_type, repo, branch, commit,
helper_functions,
helper_functions
):
def sanitize(function_source):
matched = re.match(r"[\s]*@[\w]*.component[\s\\]*\(", function_source)
@ -3240,7 +3242,8 @@ class PipelineDecorator(PipelineController):
output_uri=None,
helper_functions=helper_functions,
dry_run=True,
_sanitize_function=sanitize,
task_template_header=self._task_template_header,
_sanitize_function=sanitize
)
return task_definition

View File

@ -474,8 +474,11 @@ class CreateFromFunction(object):
kwargs_section = "kwargs"
return_section = "return"
input_artifact_section = "kwargs_artifacts"
task_template = """from clearml import Task, TaskTypes
default_task_template_header = """from clearml import Task
from clearml import TaskTypes
from clearml.automation.controller import PipelineDecorator
"""
task_template = """{header}
from clearml.utilities.proxy_object import get_basic_type
@ -509,7 +512,7 @@ if __name__ == '__main__':
parameters = dict()
parameters_types = dict()
for name, artifact in zip(result_names, results):
if isinstance(artifact, (float, int, bool, str)):
if type(artifact) in (float, int, bool, str):
parameters[return_section + '/' + name] = artifact
parameters_types[return_section + '/' + name] = get_basic_type(artifact)
else:
@ -544,6 +547,7 @@ if __name__ == '__main__':
output_uri=None, # type: Optional[str]
helper_functions=None, # type: Optional[Sequence[Callable]]
dry_run=False, # type: bool
task_template_header=None, # type: Optional[str]
_sanitize_function=None, # type: Optional[Callable[[str], str]]
_sanitize_helper_functions=None, # type: Optional[Callable[[str], str]]
):
@ -604,6 +608,7 @@ if __name__ == '__main__':
:param helper_functions: Optional, a list of helper functions to make available
for the standalone function Task.
:param dry_run: If True, do not create the Task, but return a dict of the Task's definitions
:param task_template_header: A string placed at the top of the task's code
:param _sanitize_function: Sanitization function for the function string.
:param _sanitize_helper_functions: Sanitization function for the helper function string.
:return: Newly created Task object
@ -670,6 +675,7 @@ if __name__ == '__main__':
if inspect_args.annotations[k] in supported_types}
task_template = cls.task_template.format(
header=task_template_header or cls.default_task_template_header,
auto_connect_frameworks=auto_connect_frameworks,
auto_connect_arg_parser=auto_connect_arg_parser,
kwargs_section=cls.kwargs_section,