mirror of
https://github.com/clearml/clearml
synced 2025-06-03 19:37:48 +00:00
Fix PipelineDecorator.component()
ignores *args
and crashes with **kwargs
This commit is contained in:
parent
88a828cd78
commit
fa119672d8
@ -4691,8 +4691,13 @@ class PipelineDecorator(PipelineController):
|
||||
|
||||
# resolve all lazy objects if we have any:
|
||||
kwargs_artifacts = {}
|
||||
star_args_index = 0
|
||||
for i, v in enumerate(args):
|
||||
kwargs[inspect_func.args[i]] = v
|
||||
if not inspect_func.args or i >= len(inspect_func.args):
|
||||
kwargs[str(star_args_index)] = v
|
||||
star_args_index += 1
|
||||
else:
|
||||
kwargs[inspect_func.args[i]] = v
|
||||
|
||||
# We need to remember when a pipeline step's return value is evaluated by the pipeline
|
||||
# controller, but not when it's done here (as we would remember the step every time).
|
||||
|
@ -680,6 +680,7 @@ class CreateFromFunction(object):
|
||||
default_task_template_header = """from clearml import Task
|
||||
from clearml import TaskTypes
|
||||
from clearml.automation.controller import PipelineDecorator
|
||||
import inspect
|
||||
"""
|
||||
task_template = """{header}
|
||||
from clearml.utilities.proxy_object import get_basic_type
|
||||
@ -699,6 +700,11 @@ if __name__ == '__main__':
|
||||
task.connect(kwargs, name='{kwargs_section}')
|
||||
function_input_artifacts = {function_input_artifacts}
|
||||
params = task.get_parameters() or dict()
|
||||
argspec = inspect.getfullargspec({function_name})
|
||||
if argspec.varkw is not None or argspec.varargs is not None:
|
||||
for k, v in params.items():
|
||||
if k.startswith('{kwargs_section}/'):
|
||||
kwargs[k.replace('{kwargs_section}/', '', 1)] = v
|
||||
return_section = '{return_section}'
|
||||
for k, v in params.items():
|
||||
if not v or not k.startswith('{input_artifact_section}/'):
|
||||
@ -710,7 +716,15 @@ if __name__ == '__main__':
|
||||
kwargs[k] = parent_task.artifacts[artifact_name].get(deserialization_function={artifact_deserialization_function_name})
|
||||
else:
|
||||
kwargs[k] = parent_task.get_parameters(cast=True).get(return_section + '/' + artifact_name)
|
||||
results = {function_name}(**kwargs)
|
||||
if '0' in kwargs: # *args arguments are present
|
||||
pos_args = [kwargs.pop(arg, None) for arg in (argspec.args or [])]
|
||||
other_pos_args_index = 0
|
||||
while str(other_pos_args_index) in kwargs:
|
||||
pos_args.append(kwargs.pop(str(other_pos_args_index)))
|
||||
other_pos_args_index += 1
|
||||
results = {function_name}(*pos_args, **kwargs)
|
||||
else:
|
||||
results = {function_name}(**kwargs)
|
||||
result_names = {function_return}
|
||||
if result_names:
|
||||
if not isinstance(results, (tuple, list)) or len(result_names) == 1:
|
||||
|
Loading…
Reference in New Issue
Block a user