Improve create_task_from_function default argument detection

This commit is contained in:
allegroai 2021-09-05 00:29:43 +03:00
parent 79ec0f67c3
commit fca1aac93f

View File

@ -469,11 +469,11 @@ if __name__ == '__main__':
kwargs = {function_kwargs} kwargs = {function_kwargs}
task.connect(kwargs, name='{kwargs_section}') task.connect(kwargs, name='{kwargs_section}')
function_input_artifacts = {function_input_artifacts} function_input_artifacts = {function_input_artifacts}
if function_input_artifacts: params = task.get_parameters() or dict()
task.connect(function_input_artifacts, name='{input_artifact_section}') for k, v in params.items():
for k, v in function_input_artifacts.items(): if not v or not k.startswith('{input_artifact_section}/'):
if not v:
continue continue
k = k.replace('{input_artifact_section}/', '', 1)
task_id, artifact_name = v.split('.', 1) task_id, artifact_name = v.split('.', 1)
kwargs[k] = Task.get_task(task_id=task_id).artifacts[artifact_name].get() kwargs[k] = Task.get_task(task_id=task_id).artifacts[artifact_name].get()
results = {function_name}(**kwargs) results = {function_name}(**kwargs)
@ -501,6 +501,7 @@ if __name__ == '__main__':
docker_bash_setup_script=None, # type: Optional[str] docker_bash_setup_script=None, # type: Optional[str]
output_uri=None, # type: Optional[str] output_uri=None, # type: Optional[str]
dry_run=False, # type: bool dry_run=False, # type: bool
_sanitize_function=None, # type: Optional[Callable[[str], str]]
): ):
# type: (...) -> Optional[Dict, Task] # type: (...) -> Optional[Dict, Task]
""" """
@ -548,10 +549,14 @@ if __name__ == '__main__':
:param output_uri: Optional, set the Tasks's output_uri (Storage destination). :param output_uri: Optional, set the Tasks's output_uri (Storage destination).
examples: 's3://bucket/folder', 'https://server/' , 'gs://bucket/folder', 'azure://bucket', '/folder/' examples: 's3://bucket/folder', 'https://server/' , 'gs://bucket/folder', 'azure://bucket', '/folder/'
:param dry_run: If True do not create the Task, but return a dict of the Task's definitions :param dry_run: If True do not create the Task, but return a dict of the Task's definitions
:param _sanitize_function: Sanitization function for the function string.
:return: Newly created Task object :return: Newly created Task object
""" """
function_name = str(a_function.__name__) function_name = str(a_function.__name__)
function_source = inspect.getsource(a_function) function_source = inspect.getsource(a_function)
if _sanitize_function:
function_source = _sanitize_function(function_source)
function_input_artifacts = function_input_artifacts or dict() function_input_artifacts = function_input_artifacts or dict()
# verify artifact kwargs: # verify artifact kwargs:
if not all(len(v.split('.', 1)) == 2 for v in function_input_artifacts.values()): if not all(len(v.split('.', 1)) == 2 for v in function_input_artifacts.values()):
@ -567,8 +572,9 @@ if __name__ == '__main__':
inspect_defaults_vals = inspect_args.defaults inspect_defaults_vals = inspect_args.defaults
inspect_defaults_args = inspect_args.args inspect_defaults_args = inspect_args.args
# adjust the defaults so they match the args (match from the end)
if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args): if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args):
inspect_defaults_args = [a for a in inspect_defaults_args if a not in function_input_artifacts] inspect_defaults_args = inspect_defaults_args[-len(inspect_defaults_vals):]
if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args): if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args):
getLogger().warning( getLogger().warning(