Refactor pipeline code

This commit is contained in:
allegroai 2023-03-27 13:43:52 +03:00
parent 4a91843559
commit 2f4f11aadb

View File

@ -5,6 +5,7 @@ import json
import os import os
import re import re
import six import six
import warnings
from copy import copy, deepcopy from copy import copy, deepcopy
from datetime import datetime from datetime import datetime
from logging import getLogger from logging import getLogger
@ -31,6 +32,7 @@ from ..storage.util import hash_dict
from ..task import Task from ..task import Task
from ..utilities.process.mp import leave_process from ..utilities.process.mp import leave_process
from ..utilities.proxy_object import LazyEvalWrapper, flatten_dictionary, walk_nested_dict_tuple_list from ..utilities.proxy_object import LazyEvalWrapper, flatten_dictionary, walk_nested_dict_tuple_list
from ..utilities.version import Version
class PipelineController(object): class PipelineController(object):
@ -66,6 +68,7 @@ class PipelineController(object):
_status_change_callbacks = {} # Node.name: Callable[PipelineController, PipelineController.Node, str] _status_change_callbacks = {} # Node.name: Callable[PipelineController, PipelineController.Node, str]
_final_failure = {} # Node.name: bool _final_failure = {} # Node.name: bool
_task_template_header = CreateFromFunction.default_task_template_header _task_template_header = CreateFromFunction.default_task_template_header
_default_pipeline_version = "1.0.0"
valid_job_status = ["failed", "cached", "completed", "aborted", "queued", "running", "skipped", "pending"] valid_job_status = ["failed", "cached", "completed", "aborted", "queued", "running", "skipped", "pending"]
@ -132,11 +135,11 @@ class PipelineController(object):
self, self,
name, # type: str name, # type: str
project, # type: str project, # type: str
version, # type: str version=None, # type: Optional[str]
pool_frequency=0.2, # type: float pool_frequency=0.2, # type: float
add_pipeline_tags=False, # type: bool add_pipeline_tags=False, # type: bool
target_project=True, # type: Optional[Union[str, bool]] target_project=True, # type: Optional[Union[str, bool]]
auto_version_bump=True, # type: bool auto_version_bump=None, # type: Optional[bool]
abort_on_failure=False, # type: bool abort_on_failure=False, # type: bool
add_run_number=True, # type: bool add_run_number=True, # type: bool
retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa
@ -156,14 +159,16 @@ class PipelineController(object):
:param name: Provide pipeline name (if main Task exists it overrides its name) :param name: Provide pipeline name (if main Task exists it overrides its name)
:param project: Provide project storing the pipeline (if main Task exists it overrides its project) :param project: Provide project storing the pipeline (if main Task exists it overrides its project)
:param version: Must provide pipeline version. This version allows to uniquely identify the pipeline :param version: Pipeline version. This version allows to uniquely identify the pipeline
template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2' template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2'.
If not set, find the latest version of the pipeline and increment it. If no such version is found,
default to '1.0.0'
:param float pool_frequency: The pooling frequency (in minutes) for monitoring experiments / states. :param float pool_frequency: The pooling frequency (in minutes) for monitoring experiments / states.
:param bool add_pipeline_tags: (default: False) if True, add `pipe: <pipeline_task_id>` tag to all :param bool add_pipeline_tags: (default: False) if True, add `pipe: <pipeline_task_id>` tag to all
steps (Tasks) created by this pipeline. steps (Tasks) created by this pipeline.
:param str target_project: If provided, all pipeline steps are cloned into the target project. :param str target_project: If provided, all pipeline steps are cloned into the target project.
If True, pipeline steps are stored into the pipeline project If True, pipeline steps are stored into the pipeline project
:param bool auto_version_bump: If True (default), if the same pipeline version already exists :param bool auto_version_bump: (Deprecated) If True, if the same pipeline version already exists
(with any difference from the current one), the current pipeline version will be bumped to a new version (with any difference from the current one), the current pipeline version will be bumped to a new version
version bump examples: 1.0.0 -> 1.0.1 , 1.2 -> 1.3, 10 -> 11 etc. version bump examples: 1.0.0 -> 1.0.1 , 1.2 -> 1.3, 10 -> 11 etc.
:param bool abort_on_failure: If False (default), failed pipeline steps will not cause the pipeline :param bool abort_on_failure: If False (default), failed pipeline steps will not cause the pipeline
@ -205,6 +210,7 @@ class PipelineController(object):
Example remote url: 'https://github.com/user/repo.git' Example remote url: 'https://github.com/user/repo.git'
Example local repo copy: './repo' -> will automatically store the remote Example local repo copy: './repo' -> will automatically store the remote
repo url and commit ID based on the locally cloned copy repo url and commit ID based on the locally cloned copy
Use empty string ("") to disable any repository auto-detection
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used) :param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used) :param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
:param artifact_serialization_function: A serialization function that takes one :param artifact_serialization_function: A serialization function that takes one
@ -229,16 +235,18 @@ class PipelineController(object):
import dill import dill
return dill.loads(bytes_) return dill.loads(bytes_)
""" """
if auto_version_bump is not None:
warnings.warn("PipelineController.auto_version_bump is deprecated. It will be ignored", DeprecationWarning)
self._nodes = {} self._nodes = {}
self._running_nodes = [] self._running_nodes = []
self._start_time = None self._start_time = None
self._pipeline_time_limit = None self._pipeline_time_limit = None
self._default_execution_queue = None self._default_execution_queue = None
self._version = str(version).strip() self._version = str(version).strip() if version else None
if not self._version or not all(i and i.isnumeric() for i in self._version.split('.')): if self._version and not Version.is_valid_version_string(self._version):
raise ValueError( raise ValueError(
"Pipeline version has to be in a semantic version form, " "Setting non-semantic dataset version '{}'".format(self._version)
"examples: version='1.0.1', version='1.2', version='23'") )
self._pool_frequency = pool_frequency * 60. self._pool_frequency = pool_frequency * 60.
self._thread = None self._thread = None
self._pipeline_args = dict() self._pipeline_args = dict()
@ -256,7 +264,6 @@ class PipelineController(object):
self._step_ref_pattern = re.compile(self._step_pattern) self._step_ref_pattern = re.compile(self._step_pattern)
self._reporting_lock = RLock() self._reporting_lock = RLock()
self._pipeline_task_status_failed = None self._pipeline_task_status_failed = None
self._auto_version_bump = bool(auto_version_bump)
self._mock_execution = False # used for nested pipelines (eager execution) self._mock_execution = False # used for nested pipelines (eager execution)
self._pipeline_as_sub_project = bool(Session.check_min_api_server_version("2.17")) self._pipeline_as_sub_project = bool(Session.check_min_api_server_version("2.17"))
self._last_progress_update_time = 0 self._last_progress_update_time = 0
@ -271,6 +278,12 @@ class PipelineController(object):
parent_project = None parent_project = None
project_name = project or 'Pipelines' project_name = project or 'Pipelines'
# if user disabled the auto-repo, we force local script storage (repo="" or repo=False)
set_force_local_repo = False
if Task.running_locally() and repo is not None and not repo:
Task.force_store_standalone_script(force=True)
set_force_local_repo = True
self._task = Task.init( self._task = Task.init(
project_name=project_name, project_name=project_name,
task_name=task_name, task_name=task_name,
@ -278,6 +291,13 @@ class PipelineController(object):
auto_resource_monitoring=False, auto_resource_monitoring=False,
reuse_last_task_id=False reuse_last_task_id=False
) )
# if user disabled the auto-repo, set it back to False (just in case)
if set_force_local_repo:
# noinspection PyProtectedMember
self._task._wait_for_repo_detection(timeout=300.)
Task.force_store_standalone_script(force=False)
# make sure project is hidden # make sure project is hidden
if self._pipeline_as_sub_project: if self._pipeline_as_sub_project:
get_or_create_project( get_or_create_project(
@ -287,7 +307,6 @@ class PipelineController(object):
project_id=self._task.project, system_tags=self._project_system_tags) project_id=self._task.project, system_tags=self._project_system_tags)
self._task.set_system_tags((self._task.get_system_tags() or []) + [self._tag]) self._task.set_system_tags((self._task.get_system_tags() or []) + [self._tag])
self._task.set_user_properties(version=self._version)
self._task.set_base_docker( self._task.set_base_docker(
docker_image=docker, docker_arguments=docker_args, docker_setup_bash_script=docker_bash_setup_script docker_image=docker, docker_arguments=docker_args, docker_setup_bash_script=docker_bash_setup_script
) )
@ -1259,7 +1278,7 @@ class PipelineController(object):
:param name: String name of the parameter. :param name: String name of the parameter.
:param default: Default value to be put as the default value (can be later changed in the UI) :param default: Default value to be put as the default value (can be later changed in the UI)
:param description: String description of the parameter and its usage in the pipeline :param description: String description of the parameter and its usage in the pipeline
:param param_type: Optional, parameter type information (to used as hint for casting and description) :param param_type: Optional, parameter type information (to be used as hint for casting and description)
""" """
self._pipeline_args[str(name)] = default self._pipeline_args[str(name)] = default
if description: if description:
@ -1443,7 +1462,7 @@ class PipelineController(object):
# make sure we have a unique version number (auto bump version if needed) # make sure we have a unique version number (auto bump version if needed)
# only needed when manually (from code) creating pipelines # only needed when manually (from code) creating pipelines
self._verify_pipeline_version() self._handle_pipeline_version()
# noinspection PyProtectedMember # noinspection PyProtectedMember
pipeline_hash = self._get_task_hash() pipeline_hash = self._get_task_hash()
@ -1451,6 +1470,7 @@ class PipelineController(object):
# noinspection PyProtectedMember # noinspection PyProtectedMember
self._task._set_runtime_properties({ self._task._set_runtime_properties({
self._runtime_property_hash: "{}:{}".format(pipeline_hash, self._version), self._runtime_property_hash: "{}:{}".format(pipeline_hash, self._version),
"version": self._version
}) })
else: else:
self._task.connect_configuration(pipeline_dag, name=self._config_section) self._task.connect_configuration(pipeline_dag, name=self._config_section)
@ -1482,74 +1502,25 @@ class PipelineController(object):
return params, pipeline_dag return params, pipeline_dag
def _verify_pipeline_version(self): def _handle_pipeline_version(self):
# if no version bump needed, just set the property if not self._version:
if not self._auto_version_bump: # noinspection PyProtectedMember
self._task.set_user_properties(version=self._version) self._version = self._task._get_runtime_properties().get("version")
return if not self._version:
previous_pipeline_tasks = Task._query_tasks(
# check if pipeline version exists, if it does increase version project=[self._task.project],
pipeline_hash = self._get_task_hash() fetch_only_first_page=True,
# noinspection PyProtectedMember only_fields=["runtime.version"],
existing_tasks = Task._query_tasks( order_by=["-last_update"],
project=[self._task.project], task_name=exact_match_regex(self._task.name), system_tags=[self._tag],
type=[str(self._task.task_type)], search_hidden=True,
system_tags=["__$all", self._tag, "__$not", Task.archived_tag], _allow_extra_fields_=True
_all_=dict(fields=['runtime.{}'.format(self._runtime_property_hash)],
pattern=":{}".format(self._version)),
only_fields=['id', 'runtime'],
)
if existing_tasks:
# check if hash match the current version.
matched = True
for t in existing_tasks:
h, _, v = t.runtime.get(self._runtime_property_hash, '').partition(':')
if v == self._version:
matched = bool(h == pipeline_hash)
break
# if hash did not match, look for the highest version
if not matched:
# noinspection PyProtectedMember
existing_tasks = Task._query_tasks(
project=[self._task.project], task_name=exact_match_regex(self._task.name),
type=[str(self._task.task_type)],
system_tags=["__$all", self._tag, "__$not", Task.archived_tag],
only_fields=['id', 'hyperparams', 'runtime'],
) )
found_match_version = False for previous_pipeline_task in previous_pipeline_tasks:
existing_versions = set([self._version]) # noqa if previous_pipeline_task.runtime.get("version"):
for t in existing_tasks: self._version = str(Version(previous_pipeline_task.runtime.get("version")).get_next_version())
# exclude ourselves break
if t.id == self._task.id: self._version = self._version or self._default_pipeline_version
continue
if not t.hyperparams:
continue
v = t.hyperparams.get('properties', {}).get('version')
if v:
existing_versions.add(v.value)
if t.runtime:
h, _, _ = t.runtime.get(self._runtime_property_hash, '').partition(':')
if h == pipeline_hash:
self._version = v.value
found_match_version = True
break
# match to the version we found:
if found_match_version:
getLogger('clearml.automation.controller').info(
'Existing Pipeline found, matching version to: {}'.format(self._version))
else:
# if we did not find a matched pipeline version, get the max one and bump the version by 1
while True:
v = self._version.split('.')
self._version = '.'.join(v[:-1] + [str(int(v[-1]) + 1)])
if self._version not in existing_versions:
break
getLogger('clearml.automation.controller').info(
'No matching Pipelines found, bump new version to: {}'.format(self._version))
self._task.set_user_properties(version=self._version)
def _get_task_hash(self): def _get_task_hash(self):
params_override = dict(**(self._task.get_parameters() or {})) params_override = dict(**(self._task.get_parameters() or {}))
@ -3054,7 +3025,8 @@ class PipelineController(object):
name=artifact_name, name=artifact_name,
artifact_object=artifact_object, artifact_object=artifact_object,
wait_on_upload=True, wait_on_upload=True,
extension_name=".pkl" if isinstance(artifact_object, dict) else None, extension_name=".pkl" if isinstance(artifact_object, dict) and not self._artifact_serialization_function else None,
serialization_function=self._artifact_serialization_function
) )
@ -3075,7 +3047,7 @@ class PipelineDecorator(PipelineController):
self, self,
name, # type: str name, # type: str
project, # type: str project, # type: str
version, # type: str version=None, # type: Optional[str]
pool_frequency=0.2, # type: float pool_frequency=0.2, # type: float
add_pipeline_tags=False, # type: bool add_pipeline_tags=False, # type: bool
target_project=None, # type: Optional[str] target_project=None, # type: Optional[str]
@ -3098,8 +3070,10 @@ class PipelineDecorator(PipelineController):
:param name: Provide pipeline name (if main Task exists it overrides its name) :param name: Provide pipeline name (if main Task exists it overrides its name)
:param project: Provide project storing the pipeline (if main Task exists it overrides its project) :param project: Provide project storing the pipeline (if main Task exists it overrides its project)
:param version: Must provide pipeline version. This version allows to uniquely identify the pipeline :param version: Pipeline version. This version allows to uniquely identify the pipeline
template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2' template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2'.
If not set, find the latest version of the pipeline and increment it. If no such version is found,
default to '1.0.0'
:param float pool_frequency: The pooling frequency (in minutes) for monitoring experiments / states. :param float pool_frequency: The pooling frequency (in minutes) for monitoring experiments / states.
:param bool add_pipeline_tags: (default: False) if True, add `pipe: <pipeline_task_id>` tag to all :param bool add_pipeline_tags: (default: False) if True, add `pipe: <pipeline_task_id>` tag to all
steps (Tasks) created by this pipeline. steps (Tasks) created by this pipeline.
@ -3143,6 +3117,7 @@ class PipelineDecorator(PipelineController):
Example remote url: 'https://github.com/user/repo.git' Example remote url: 'https://github.com/user/repo.git'
Example local repo copy: './repo' -> will automatically store the remote Example local repo copy: './repo' -> will automatically store the remote
repo url and commit ID based on the locally cloned copy repo url and commit ID based on the locally cloned copy
Use empty string ("") to disable any repository auto-detection
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used) :param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used) :param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
:param artifact_serialization_function: A serialization function that takes one :param artifact_serialization_function: A serialization function that takes one
@ -3156,7 +3131,6 @@ class PipelineDecorator(PipelineController):
def serialize(obj): def serialize(obj):
import dill import dill
return dill.dumps(obj) return dill.dumps(obj)
:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`, :param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
which represents the serialized object. This function should return the deserialized object. which represents the serialized object. This function should return the deserialized object.
All parameter/return artifacts fetched by the pipeline will be deserialized using this function. All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
@ -3906,7 +3880,7 @@ class PipelineDecorator(PipelineController):
_func=None, *, # noqa _func=None, *, # noqa
name, # type: str name, # type: str
project, # type: str project, # type: str
version, # type: str version=None, # type: Optional[str]
return_value=None, # type: Optional[str] return_value=None, # type: Optional[str]
default_queue=None, # type: Optional[str] default_queue=None, # type: Optional[str]
pool_frequency=0.2, # type: float pool_frequency=0.2, # type: float
@ -3935,8 +3909,10 @@ class PipelineDecorator(PipelineController):
:param name: Provide pipeline name (if main Task exists it overrides its name) :param name: Provide pipeline name (if main Task exists it overrides its name)
:param project: Provide project storing the pipeline (if main Task exists it overrides its project) :param project: Provide project storing the pipeline (if main Task exists it overrides its project)
:param version: Must provide pipeline version. This version allows to uniquely identify the pipeline :param version: Pipeline version. This version allows to uniquely identify the pipeline
template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2' template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2'.
If not set, find the latest version of the pipeline and increment it. If no such version is found,
default to '1.0.0'
:param return_value: Optional, Provide an artifact name to store the pipeline function return object :param return_value: Optional, Provide an artifact name to store the pipeline function return object
Notice, If not provided the pipeline will not store the pipeline function return value. Notice, If not provided the pipeline will not store the pipeline function return value.
:param default_queue: default pipeline step queue :param default_queue: default pipeline step queue
@ -4009,6 +3985,7 @@ class PipelineDecorator(PipelineController):
Example remote url: 'https://github.com/user/repo.git' Example remote url: 'https://github.com/user/repo.git'
Example local repo copy: './repo' -> will automatically store the remote Example local repo copy: './repo' -> will automatically store the remote
repo url and commit ID based on the locally cloned copy repo url and commit ID based on the locally cloned copy
Use empty string ("") to disable any repository auto-detection
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used) :param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used) :param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
:param artifact_serialization_function: A serialization function that takes one :param artifact_serialization_function: A serialization function that takes one
@ -4022,7 +3999,6 @@ class PipelineDecorator(PipelineController):
def serialize(obj): def serialize(obj):
import dill import dill
return dill.dumps(obj) return dill.dumps(obj)
:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`, :param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
which represents the serialized object. This function should return the deserialized object. which represents the serialized object. This function should return the deserialized object.
All parameter/return artifacts fetched by the pipeline will be deserialized using this function. All parameter/return artifacts fetched by the pipeline will be deserialized using this function.