Fix pipeline argument becomes None if default value is not set

This commit is contained in:
allegroai 2022-10-14 10:09:35 +03:00
parent 452d3c5750
commit dd17fca080
4 changed files with 29 additions and 8 deletions

View File

@ -29,7 +29,7 @@ from ..model import BaseModel, OutputModel
from ..storage.util import hash_dict
from ..task import Task
from ..utilities.process.mp import leave_process
from ..utilities.proxy_object import LazyEvalWrapper, flatten_dictionary, walk_nested_dict_tuple_list, verify_basic_type
from ..utilities.proxy_object import LazyEvalWrapper, flatten_dictionary, walk_nested_dict_tuple_list
class PipelineController(object):
@ -701,7 +701,9 @@ class PipelineController(object):
function_input_artifacts = {}
# go over function_kwargs, split it into string and input artifacts
for k, v in function_kwargs.items():
if v is not None and self._step_ref_pattern.match(str(v)):
if v is None:
continue
if self._step_ref_pattern.match(str(v)):
# check for step artifacts
step, _, artifact = v[2:-1].partition('.')
if step in self._nodes and artifact in self._nodes[step].return_artifacts:
@ -711,9 +713,14 @@ class PipelineController(object):
# steps from tasks the _nodes is till empty, only after deserializing we will have the full DAG)
if self._task.running_locally():
self.__verify_step_reference(node=self.Node(name=name), step_ref_string=v)
elif not verify_basic_type(v):
elif not isinstance(v, (float, int, bool, six.string_types)):
function_input_artifacts[k] = "{}.{}.{}".format(self._task.id, name, k)
self._task.upload_artifact("{}.{}".format(name, k), artifact_object=v, wait_on_upload=True)
self._task.upload_artifact(
"{}.{}".format(name, k),
artifact_object=v,
wait_on_upload=True,
extension_name=".pkl" if isinstance(v, dict) else None,
)
function_kwargs = {k: v for k, v in function_kwargs.items() if k not in function_input_artifacts}
parameters = {"{}/{}".format(CreateFromFunction.kwargs_section, k): v for k, v in function_kwargs.items()}
@ -3500,7 +3507,10 @@ class PipelineDecorator(PipelineController):
# store the pipeline result of we have any:
if return_value and pipeline_result is not None:
a_pipeline._task.upload_artifact(
name=str(return_value), artifact_object=pipeline_result, wait_on_upload=True
name=str(return_value),
artifact_object=pipeline_result,
wait_on_upload=True,
extension_name=".pkl" if isinstance(pipeline_result, dict) else None,
)
# now we can stop the pipeline
@ -3625,13 +3635,17 @@ class PipelineDecorator(PipelineController):
x for x in cls._evaluated_return_values.get(tid, []) if x in leaves
]
for k, v in kwargs.items():
if v is None or verify_basic_type(v):
if v is None or isinstance(v, (float, int, bool, six.string_types)):
_node.parameters["{}/{}".format(CreateFromFunction.kwargs_section, k)] = v
else:
# we need to create an artifact
artifact_name = 'result_{}_{}'.format(re.sub(r'\W+', '', _node.name), k)
cls._singleton._task.upload_artifact(
name=artifact_name, artifact_object=v, wait_on_upload=True)
name=artifact_name,
artifact_object=v,
wait_on_upload=True,
extension_name=".pkl" if isinstance(v, dict) else None,
)
_node.parameters["{}/{}".format(CreateFromFunction.input_artifact_section, k)] = \
"{}.{}".format(cls._singleton._task.id, artifact_name)

View File

@ -500,7 +500,11 @@ if __name__ == '__main__':
if not isinstance(results, (tuple, list)) or len(result_names) == 1:
results = [results]
for name, artifact in zip(result_names, results):
task.upload_artifact(name=name, artifact_object=artifact)
task.upload_artifact(
name=name,
artifact_object=artifact,
extension_name='.pkl' if isinstance(artifact, dict) else None
)
"""
@classmethod

View File

@ -439,6 +439,8 @@ class Artifacts(object):
os.unlink(local_filename)
raise
artifact_type_data.content_type = mimetypes.guess_type(local_filename)[0]
elif extension_name == ".pkl":
store_as_pickle = True
elif np and isinstance(artifact_object, np.ndarray):
artifact_type = 'numpy'
artifact_type_data.preview = preview or str(artifact_object.__repr__())

View File

@ -1930,6 +1930,7 @@ class Task(_Task):
- pandas.DataFrame - ``.csv.gz``, ``.parquet``, ``.feather``, ``.pickle`` (default ``.csv.gz``)
- numpy.ndarray - ``.npz``, ``.csv.gz`` (default ``.npz``)
- PIL.Image - whatever extensions PIL supports (default ``.png``)
- Any object - ``.pkl`` (if this extension is passed, the object will be pickled)
- In case the ``serialization_function`` argument is set - any extension is supported
:param Callable[Any, Union[bytes, bytearray]] serialization_function: A serialization function that takes one