mirror of
https://github.com/clearml/clearml
synced 2025-04-21 23:04:42 +00:00
Improve create_task_from_function default argument detection
This commit is contained in:
parent
79ec0f67c3
commit
fca1aac93f
@ -469,17 +469,17 @@ if __name__ == '__main__':
|
||||
kwargs = {function_kwargs}
|
||||
task.connect(kwargs, name='{kwargs_section}')
|
||||
function_input_artifacts = {function_input_artifacts}
|
||||
if function_input_artifacts:
|
||||
task.connect(function_input_artifacts, name='{input_artifact_section}')
|
||||
for k, v in function_input_artifacts.items():
|
||||
if not v:
|
||||
continue
|
||||
task_id, artifact_name = v.split('.', 1)
|
||||
kwargs[k] = Task.get_task(task_id=task_id).artifacts[artifact_name].get()
|
||||
params = task.get_parameters() or dict()
|
||||
for k, v in params.items():
|
||||
if not v or not k.startswith('{input_artifact_section}/'):
|
||||
continue
|
||||
k = k.replace('{input_artifact_section}/', '', 1)
|
||||
task_id, artifact_name = v.split('.', 1)
|
||||
kwargs[k] = Task.get_task(task_id=task_id).artifacts[artifact_name].get()
|
||||
results = {function_name}(**kwargs)
|
||||
result_names = {function_return}
|
||||
if result_names:
|
||||
if not isinstance(results, (tuple, list)) or (len(result_names)==1 and len(results) != 1):
|
||||
if not isinstance(results, (tuple, list)) or (len(result_names) == 1 and len(results) != 1):
|
||||
results = [results]
|
||||
for name, artifact in zip(result_names, results):
|
||||
task.upload_artifact(name=name, artifact_object=artifact)
|
||||
@ -501,6 +501,7 @@ if __name__ == '__main__':
|
||||
docker_bash_setup_script=None, # type: Optional[str]
|
||||
output_uri=None, # type: Optional[str]
|
||||
dry_run=False, # type: bool
|
||||
_sanitize_function=None, # type: Optional[Callable[[str], str]]
|
||||
):
|
||||
# type: (...) -> Optional[Dict, Task]
|
||||
"""
|
||||
@ -548,10 +549,14 @@ if __name__ == '__main__':
|
||||
:param output_uri: Optional, set the Tasks's output_uri (Storage destination).
|
||||
examples: 's3://bucket/folder', 'https://server/' , 'gs://bucket/folder', 'azure://bucket', '/folder/'
|
||||
:param dry_run: If True do not create the Task, but return a dict of the Task's definitions
|
||||
:param _sanitize_function: Sanitization function for the function string.
|
||||
:return: Newly created Task object
|
||||
"""
|
||||
function_name = str(a_function.__name__)
|
||||
function_source = inspect.getsource(a_function)
|
||||
if _sanitize_function:
|
||||
function_source = _sanitize_function(function_source)
|
||||
|
||||
function_input_artifacts = function_input_artifacts or dict()
|
||||
# verify artifact kwargs:
|
||||
if not all(len(v.split('.', 1)) == 2 for v in function_input_artifacts.values()):
|
||||
@ -567,8 +572,9 @@ if __name__ == '__main__':
|
||||
inspect_defaults_vals = inspect_args.defaults
|
||||
inspect_defaults_args = inspect_args.args
|
||||
|
||||
# adjust the defaults so they match the args (match from the end)
|
||||
if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args):
|
||||
inspect_defaults_args = [a for a in inspect_defaults_args if a not in function_input_artifacts]
|
||||
inspect_defaults_args = inspect_defaults_args[-len(inspect_defaults_vals):]
|
||||
|
||||
if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args):
|
||||
getLogger().warning(
|
||||
|
Loading…
Reference in New Issue
Block a user