Refactor pipeline code

This commit is contained in:
allegroai 2023-03-27 13:43:52 +03:00
parent 4a91843559
commit 2f4f11aadb

View File

@ -5,6 +5,7 @@ import json
import os
import re
import six
import warnings
from copy import copy, deepcopy
from datetime import datetime
from logging import getLogger
@ -31,6 +32,7 @@ from ..storage.util import hash_dict
from ..task import Task
from ..utilities.process.mp import leave_process
from ..utilities.proxy_object import LazyEvalWrapper, flatten_dictionary, walk_nested_dict_tuple_list
from ..utilities.version import Version
class PipelineController(object):
@ -66,6 +68,7 @@ class PipelineController(object):
_status_change_callbacks = {} # Node.name: Callable[PipelineController, PipelineController.Node, str]
_final_failure = {} # Node.name: bool
_task_template_header = CreateFromFunction.default_task_template_header
_default_pipeline_version = "1.0.0"
valid_job_status = ["failed", "cached", "completed", "aborted", "queued", "running", "skipped", "pending"]
@ -132,11 +135,11 @@ class PipelineController(object):
self,
name, # type: str
project, # type: str
version, # type: str
version=None, # type: Optional[str]
pool_frequency=0.2, # type: float
add_pipeline_tags=False, # type: bool
target_project=True, # type: Optional[Union[str, bool]]
auto_version_bump=True, # type: bool
auto_version_bump=None, # type: Optional[bool]
abort_on_failure=False, # type: bool
add_run_number=True, # type: bool
retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa
@ -156,14 +159,16 @@ class PipelineController(object):
:param name: Provide pipeline name (if main Task exists it overrides its name)
:param project: Provide project storing the pipeline (if main Task exists it overrides its project)
:param version: Must provide pipeline version. This version allows to uniquely identify the pipeline
template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2'
:param version: Pipeline version. This version uniquely identifies the pipeline
template execution. Examples for semantic versions: version='1.0.1', version='23', version='1.2'.
If not set, find the latest version of the pipeline and increment it. If no such version is found,
default to '1.0.0'
:param float pool_frequency: The polling frequency (in minutes) for monitoring experiments / states.
:param bool add_pipeline_tags: (default: False) if True, add `pipe: <pipeline_task_id>` tag to all
steps (Tasks) created by this pipeline.
:param str target_project: If provided, all pipeline steps are cloned into the target project.
If True, pipeline steps are stored into the pipeline project
:param bool auto_version_bump: If True (default), if the same pipeline version already exists
:param bool auto_version_bump: (Deprecated) If True, if the same pipeline version already exists
(with any difference from the current one), the current pipeline version will be bumped to a new version.
Version bump examples: 1.0.0 -> 1.0.1, 1.2 -> 1.3, 10 -> 11, etc.
:param bool abort_on_failure: If False (default), failed pipeline steps will not cause the pipeline
@ -205,6 +210,7 @@ class PipelineController(object):
Example remote url: 'https://github.com/user/repo.git'
Example local repo copy: './repo' -> will automatically store the remote
repo url and commit ID based on the locally cloned copy
Use empty string ("") to disable any repository auto-detection
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
:param artifact_serialization_function: A serialization function that takes one
@ -229,16 +235,18 @@ class PipelineController(object):
import dill
return dill.loads(bytes_)
"""
if auto_version_bump is not None:
warnings.warn("PipelineController.auto_version_bump is deprecated. It will be ignored", DeprecationWarning)
self._nodes = {}
self._running_nodes = []
self._start_time = None
self._pipeline_time_limit = None
self._default_execution_queue = None
self._version = str(version).strip()
if not self._version or not all(i and i.isnumeric() for i in self._version.split('.')):
self._version = str(version).strip() if version else None
if self._version and not Version.is_valid_version_string(self._version):
raise ValueError(
"Pipeline version has to be in a semantic version form, "
"examples: version='1.0.1', version='1.2', version='23'")
"Setting non-semantic dataset version '{}'".format(self._version)
)
self._pool_frequency = pool_frequency * 60.
self._thread = None
self._pipeline_args = dict()
@ -256,7 +264,6 @@ class PipelineController(object):
self._step_ref_pattern = re.compile(self._step_pattern)
self._reporting_lock = RLock()
self._pipeline_task_status_failed = None
self._auto_version_bump = bool(auto_version_bump)
self._mock_execution = False # used for nested pipelines (eager execution)
self._pipeline_as_sub_project = bool(Session.check_min_api_server_version("2.17"))
self._last_progress_update_time = 0
@ -271,6 +278,12 @@ class PipelineController(object):
parent_project = None
project_name = project or 'Pipelines'
# if user disabled the auto-repo, we force local script storage (repo="" or repo=False)
set_force_local_repo = False
if Task.running_locally() and repo is not None and not repo:
Task.force_store_standalone_script(force=True)
set_force_local_repo = True
self._task = Task.init(
project_name=project_name,
task_name=task_name,
@ -278,6 +291,13 @@ class PipelineController(object):
auto_resource_monitoring=False,
reuse_last_task_id=False
)
# if user disabled the auto-repo, set it back to False (just in case)
if set_force_local_repo:
# noinspection PyProtectedMember
self._task._wait_for_repo_detection(timeout=300.)
Task.force_store_standalone_script(force=False)
# make sure project is hidden
if self._pipeline_as_sub_project:
get_or_create_project(
@ -287,7 +307,6 @@ class PipelineController(object):
project_id=self._task.project, system_tags=self._project_system_tags)
self._task.set_system_tags((self._task.get_system_tags() or []) + [self._tag])
self._task.set_user_properties(version=self._version)
self._task.set_base_docker(
docker_image=docker, docker_arguments=docker_args, docker_setup_bash_script=docker_bash_setup_script
)
@ -1259,7 +1278,7 @@ class PipelineController(object):
:param name: String name of the parameter.
:param default: Default value to be put as the default value (can be later changed in the UI)
:param description: String description of the parameter and its usage in the pipeline
:param param_type: Optional, parameter type information (to used as hint for casting and description)
:param param_type: Optional, parameter type information (to be used as hint for casting and description)
"""
self._pipeline_args[str(name)] = default
if description:
@ -1443,7 +1462,7 @@ class PipelineController(object):
# make sure we have a unique version number (auto bump version if needed)
# only needed when manually (from code) creating pipelines
self._verify_pipeline_version()
self._handle_pipeline_version()
# noinspection PyProtectedMember
pipeline_hash = self._get_task_hash()
@ -1451,6 +1470,7 @@ class PipelineController(object):
# noinspection PyProtectedMember
self._task._set_runtime_properties({
self._runtime_property_hash: "{}:{}".format(pipeline_hash, self._version),
"version": self._version
})
else:
self._task.connect_configuration(pipeline_dag, name=self._config_section)
@ -1482,74 +1502,25 @@ class PipelineController(object):
return params, pipeline_dag
def _verify_pipeline_version(self):
# if no version bump needed, just set the property
if not self._auto_version_bump:
self._task.set_user_properties(version=self._version)
return
# check if pipeline version exists, if it does increase version
pipeline_hash = self._get_task_hash()
# noinspection PyProtectedMember
existing_tasks = Task._query_tasks(
project=[self._task.project], task_name=exact_match_regex(self._task.name),
type=[str(self._task.task_type)],
system_tags=["__$all", self._tag, "__$not", Task.archived_tag],
_all_=dict(fields=['runtime.{}'.format(self._runtime_property_hash)],
pattern=":{}".format(self._version)),
only_fields=['id', 'runtime'],
)
if existing_tasks:
# check if hash match the current version.
matched = True
for t in existing_tasks:
h, _, v = t.runtime.get(self._runtime_property_hash, '').partition(':')
if v == self._version:
matched = bool(h == pipeline_hash)
break
# if hash did not match, look for the highest version
if not matched:
# noinspection PyProtectedMember
existing_tasks = Task._query_tasks(
project=[self._task.project], task_name=exact_match_regex(self._task.name),
type=[str(self._task.task_type)],
system_tags=["__$all", self._tag, "__$not", Task.archived_tag],
only_fields=['id', 'hyperparams', 'runtime'],
def _handle_pipeline_version(self):
if not self._version:
# noinspection PyProtectedMember
self._version = self._task._get_runtime_properties().get("version")
if not self._version:
previous_pipeline_tasks = Task._query_tasks(
project=[self._task.project],
fetch_only_first_page=True,
only_fields=["runtime.version"],
order_by=["-last_update"],
system_tags=[self._tag],
search_hidden=True,
_allow_extra_fields_=True
)
found_match_version = False
existing_versions = set([self._version]) # noqa
for t in existing_tasks:
# exclude ourselves
if t.id == self._task.id:
continue
if not t.hyperparams:
continue
v = t.hyperparams.get('properties', {}).get('version')
if v:
existing_versions.add(v.value)
if t.runtime:
h, _, _ = t.runtime.get(self._runtime_property_hash, '').partition(':')
if h == pipeline_hash:
self._version = v.value
found_match_version = True
break
# match to the version we found:
if found_match_version:
getLogger('clearml.automation.controller').info(
'Existing Pipeline found, matching version to: {}'.format(self._version))
else:
# if we did not find a matched pipeline version, get the max one and bump the version by 1
while True:
v = self._version.split('.')
self._version = '.'.join(v[:-1] + [str(int(v[-1]) + 1)])
if self._version not in existing_versions:
break
getLogger('clearml.automation.controller').info(
'No matching Pipelines found, bump new version to: {}'.format(self._version))
self._task.set_user_properties(version=self._version)
for previous_pipeline_task in previous_pipeline_tasks:
if previous_pipeline_task.runtime.get("version"):
self._version = str(Version(previous_pipeline_task.runtime.get("version")).get_next_version())
break
self._version = self._version or self._default_pipeline_version
def _get_task_hash(self):
params_override = dict(**(self._task.get_parameters() or {}))
@ -3054,7 +3025,8 @@ class PipelineController(object):
name=artifact_name,
artifact_object=artifact_object,
wait_on_upload=True,
extension_name=".pkl" if isinstance(artifact_object, dict) else None,
extension_name=".pkl" if isinstance(artifact_object, dict) and not self._artifact_serialization_function else None,
serialization_function=self._artifact_serialization_function
)
@ -3075,7 +3047,7 @@ class PipelineDecorator(PipelineController):
self,
name, # type: str
project, # type: str
version, # type: str
version=None, # type: Optional[str]
pool_frequency=0.2, # type: float
add_pipeline_tags=False, # type: bool
target_project=None, # type: Optional[str]
@ -3098,8 +3070,10 @@ class PipelineDecorator(PipelineController):
:param name: Provide pipeline name (if main Task exists it overrides its name)
:param project: Provide project storing the pipeline (if main Task exists it overrides its project)
:param version: Must provide pipeline version. This version allows to uniquely identify the pipeline
template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2'
:param version: Pipeline version. This version uniquely identifies the pipeline
template execution. Examples for semantic versions: version='1.0.1', version='23', version='1.2'.
If not set, find the latest version of the pipeline and increment it. If no such version is found,
default to '1.0.0'
:param float pool_frequency: The polling frequency (in minutes) for monitoring experiments / states.
:param bool add_pipeline_tags: (default: False) if True, add `pipe: <pipeline_task_id>` tag to all
steps (Tasks) created by this pipeline.
@ -3143,6 +3117,7 @@ class PipelineDecorator(PipelineController):
Example remote url: 'https://github.com/user/repo.git'
Example local repo copy: './repo' -> will automatically store the remote
repo url and commit ID based on the locally cloned copy
Use empty string ("") to disable any repository auto-detection
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
:param artifact_serialization_function: A serialization function that takes one
@ -3156,7 +3131,6 @@ class PipelineDecorator(PipelineController):
def serialize(obj):
import dill
return dill.dumps(obj)
:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
which represents the serialized object. This function should return the deserialized object.
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
@ -3906,7 +3880,7 @@ class PipelineDecorator(PipelineController):
_func=None, *, # noqa
name, # type: str
project, # type: str
version, # type: str
version=None, # type: Optional[str]
return_value=None, # type: Optional[str]
default_queue=None, # type: Optional[str]
pool_frequency=0.2, # type: float
@ -3935,8 +3909,10 @@ class PipelineDecorator(PipelineController):
:param name: Provide pipeline name (if main Task exists it overrides its name)
:param project: Provide project storing the pipeline (if main Task exists it overrides its project)
:param version: Must provide pipeline version. This version allows to uniquely identify the pipeline
template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2'
:param version: Pipeline version. This version uniquely identifies the pipeline
template execution. Examples for semantic versions: version='1.0.1', version='23', version='1.2'.
If not set, find the latest version of the pipeline and increment it. If no such version is found,
default to '1.0.0'
:param return_value: Optional, provide an artifact name to store the pipeline function return object.
Note: if not provided, the pipeline will not store the pipeline function's return value.
:param default_queue: default pipeline step queue
@ -4009,6 +3985,7 @@ class PipelineDecorator(PipelineController):
Example remote url: 'https://github.com/user/repo.git'
Example local repo copy: './repo' -> will automatically store the remote
repo url and commit ID based on the locally cloned copy
Use empty string ("") to disable any repository auto-detection
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
:param artifact_serialization_function: A serialization function that takes one
@ -4022,7 +3999,6 @@ class PipelineDecorator(PipelineController):
def serialize(obj):
import dill
return dill.dumps(obj)
:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
which represents the serialized object. This function should return the deserialized object.
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.