Improve create_task_from_function default argument detection

This commit is contained in:
allegroai 2021-09-05 00:29:43 +03:00
parent 79ec0f67c3
commit fca1aac93f

View File

@ -469,17 +469,17 @@ if __name__ == '__main__':
kwargs = {function_kwargs}
task.connect(kwargs, name='{kwargs_section}')
function_input_artifacts = {function_input_artifacts}
if function_input_artifacts:
task.connect(function_input_artifacts, name='{input_artifact_section}')
for k, v in function_input_artifacts.items():
if not v:
continue
task_id, artifact_name = v.split('.', 1)
kwargs[k] = Task.get_task(task_id=task_id).artifacts[artifact_name].get()
params = task.get_parameters() or dict()
for k, v in params.items():
if not v or not k.startswith('{input_artifact_section}/'):
continue
k = k.replace('{input_artifact_section}/', '', 1)
task_id, artifact_name = v.split('.', 1)
kwargs[k] = Task.get_task(task_id=task_id).artifacts[artifact_name].get()
results = {function_name}(**kwargs)
result_names = {function_return}
if result_names:
if not isinstance(results, (tuple, list)) or (len(result_names)==1 and len(results) != 1):
if not isinstance(results, (tuple, list)) or (len(result_names) == 1 and len(results) != 1):
results = [results]
for name, artifact in zip(result_names, results):
task.upload_artifact(name=name, artifact_object=artifact)
@ -501,6 +501,7 @@ if __name__ == '__main__':
docker_bash_setup_script=None, # type: Optional[str]
output_uri=None, # type: Optional[str]
dry_run=False, # type: bool
_sanitize_function=None, # type: Optional[Callable[[str], str]]
):
# type: (...) -> Optional[Dict, Task]
"""
@ -548,10 +549,14 @@ if __name__ == '__main__':
:param output_uri: Optional, set the Tasks's output_uri (Storage destination).
examples: 's3://bucket/folder', 'https://server/' , 'gs://bucket/folder', 'azure://bucket', '/folder/'
:param dry_run: If True do not create the Task, but return a dict of the Task's definitions
:param _sanitize_function: Sanitization function for the function string.
:return: Newly created Task object
"""
function_name = str(a_function.__name__)
function_source = inspect.getsource(a_function)
if _sanitize_function:
function_source = _sanitize_function(function_source)
function_input_artifacts = function_input_artifacts or dict()
# verify artifact kwargs:
if not all(len(v.split('.', 1)) == 2 for v in function_input_artifacts.values()):
@ -567,8 +572,9 @@ if __name__ == '__main__':
inspect_defaults_vals = inspect_args.defaults
inspect_defaults_args = inspect_args.args
# adjust the defaults so they match the args (match from the end)
if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args):
inspect_defaults_args = [a for a in inspect_defaults_args if a not in function_input_artifacts]
inspect_defaults_args = inspect_defaults_args[-len(inspect_defaults_vals):]
if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args):
getLogger().warning(