mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Improve create_task_from_function default argument detection
This commit is contained in:
parent
79ec0f67c3
commit
fca1aac93f
@ -469,11 +469,11 @@ if __name__ == '__main__':
|
|||||||
kwargs = {function_kwargs}
|
kwargs = {function_kwargs}
|
||||||
task.connect(kwargs, name='{kwargs_section}')
|
task.connect(kwargs, name='{kwargs_section}')
|
||||||
function_input_artifacts = {function_input_artifacts}
|
function_input_artifacts = {function_input_artifacts}
|
||||||
if function_input_artifacts:
|
params = task.get_parameters() or dict()
|
||||||
task.connect(function_input_artifacts, name='{input_artifact_section}')
|
for k, v in params.items():
|
||||||
for k, v in function_input_artifacts.items():
|
if not v or not k.startswith('{input_artifact_section}/'):
|
||||||
if not v:
|
|
||||||
continue
|
continue
|
||||||
|
k = k.replace('{input_artifact_section}/', '', 1)
|
||||||
task_id, artifact_name = v.split('.', 1)
|
task_id, artifact_name = v.split('.', 1)
|
||||||
kwargs[k] = Task.get_task(task_id=task_id).artifacts[artifact_name].get()
|
kwargs[k] = Task.get_task(task_id=task_id).artifacts[artifact_name].get()
|
||||||
results = {function_name}(**kwargs)
|
results = {function_name}(**kwargs)
|
||||||
@ -501,6 +501,7 @@ if __name__ == '__main__':
|
|||||||
docker_bash_setup_script=None, # type: Optional[str]
|
docker_bash_setup_script=None, # type: Optional[str]
|
||||||
output_uri=None, # type: Optional[str]
|
output_uri=None, # type: Optional[str]
|
||||||
dry_run=False, # type: bool
|
dry_run=False, # type: bool
|
||||||
|
_sanitize_function=None, # type: Optional[Callable[[str], str]]
|
||||||
):
|
):
|
||||||
# type: (...) -> Optional[Dict, Task]
|
# type: (...) -> Optional[Dict, Task]
|
||||||
"""
|
"""
|
||||||
@ -548,10 +549,14 @@ if __name__ == '__main__':
|
|||||||
:param output_uri: Optional, set the Tasks's output_uri (Storage destination).
|
:param output_uri: Optional, set the Tasks's output_uri (Storage destination).
|
||||||
examples: 's3://bucket/folder', 'https://server/' , 'gs://bucket/folder', 'azure://bucket', '/folder/'
|
examples: 's3://bucket/folder', 'https://server/' , 'gs://bucket/folder', 'azure://bucket', '/folder/'
|
||||||
:param dry_run: If True do not create the Task, but return a dict of the Task's definitions
|
:param dry_run: If True do not create the Task, but return a dict of the Task's definitions
|
||||||
|
:param _sanitize_function: Sanitization function for the function string.
|
||||||
:return: Newly created Task object
|
:return: Newly created Task object
|
||||||
"""
|
"""
|
||||||
function_name = str(a_function.__name__)
|
function_name = str(a_function.__name__)
|
||||||
function_source = inspect.getsource(a_function)
|
function_source = inspect.getsource(a_function)
|
||||||
|
if _sanitize_function:
|
||||||
|
function_source = _sanitize_function(function_source)
|
||||||
|
|
||||||
function_input_artifacts = function_input_artifacts or dict()
|
function_input_artifacts = function_input_artifacts or dict()
|
||||||
# verify artifact kwargs:
|
# verify artifact kwargs:
|
||||||
if not all(len(v.split('.', 1)) == 2 for v in function_input_artifacts.values()):
|
if not all(len(v.split('.', 1)) == 2 for v in function_input_artifacts.values()):
|
||||||
@ -567,8 +572,9 @@ if __name__ == '__main__':
|
|||||||
inspect_defaults_vals = inspect_args.defaults
|
inspect_defaults_vals = inspect_args.defaults
|
||||||
inspect_defaults_args = inspect_args.args
|
inspect_defaults_args = inspect_args.args
|
||||||
|
|
||||||
|
# adjust the defaults so they match the args (match from the end)
|
||||||
if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args):
|
if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args):
|
||||||
inspect_defaults_args = [a for a in inspect_defaults_args if a not in function_input_artifacts]
|
inspect_defaults_args = inspect_defaults_args[-len(inspect_defaults_vals):]
|
||||||
|
|
||||||
if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args):
|
if inspect_defaults_vals and len(inspect_defaults_vals) != len(inspect_defaults_args):
|
||||||
getLogger().warning(
|
getLogger().warning(
|
||||||
|
Loading…
Reference in New Issue
Block a user