mirror of
https://github.com/clearml/clearml
synced 2025-03-13 07:08:24 +00:00
Refactor pipeline code
This commit is contained in:
parent
4a91843559
commit
2f4f11aadb
@ -5,6 +5,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import six
|
||||
import warnings
|
||||
from copy import copy, deepcopy
|
||||
from datetime import datetime
|
||||
from logging import getLogger
|
||||
@ -31,6 +32,7 @@ from ..storage.util import hash_dict
|
||||
from ..task import Task
|
||||
from ..utilities.process.mp import leave_process
|
||||
from ..utilities.proxy_object import LazyEvalWrapper, flatten_dictionary, walk_nested_dict_tuple_list
|
||||
from ..utilities.version import Version
|
||||
|
||||
|
||||
class PipelineController(object):
|
||||
@ -66,6 +68,7 @@ class PipelineController(object):
|
||||
_status_change_callbacks = {} # Node.name: Callable[PipelineController, PipelineController.Node, str]
|
||||
_final_failure = {} # Node.name: bool
|
||||
_task_template_header = CreateFromFunction.default_task_template_header
|
||||
_default_pipeline_version = "1.0.0"
|
||||
|
||||
valid_job_status = ["failed", "cached", "completed", "aborted", "queued", "running", "skipped", "pending"]
|
||||
|
||||
@ -132,11 +135,11 @@ class PipelineController(object):
|
||||
self,
|
||||
name, # type: str
|
||||
project, # type: str
|
||||
version, # type: str
|
||||
version=None, # type: Optional[str]
|
||||
pool_frequency=0.2, # type: float
|
||||
add_pipeline_tags=False, # type: bool
|
||||
target_project=True, # type: Optional[Union[str, bool]]
|
||||
auto_version_bump=True, # type: bool
|
||||
auto_version_bump=None, # type: Optional[bool]
|
||||
abort_on_failure=False, # type: bool
|
||||
add_run_number=True, # type: bool
|
||||
retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa
|
||||
@ -156,14 +159,16 @@ class PipelineController(object):
|
||||
|
||||
:param name: Provide pipeline name (if main Task exists it overrides its name)
|
||||
:param project: Provide project storing the pipeline (if main Task exists it overrides its project)
|
||||
:param version: Must provide pipeline version. This version allows to uniquely identify the pipeline
|
||||
template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2'
|
||||
:param version: Pipeline version. This version allows to uniquely identify the pipeline
|
||||
template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2'.
|
||||
If not set, find the latest version of the pipeline and increment it. If no such version is found,
|
||||
default to '1.0.0'
|
||||
:param float pool_frequency: The pooling frequency (in minutes) for monitoring experiments / states.
|
||||
:param bool add_pipeline_tags: (default: False) if True, add `pipe: <pipeline_task_id>` tag to all
|
||||
steps (Tasks) created by this pipeline.
|
||||
:param str target_project: If provided, all pipeline steps are cloned into the target project.
|
||||
If True, pipeline steps are stored into the pipeline project
|
||||
:param bool auto_version_bump: If True (default), if the same pipeline version already exists
|
||||
:param bool auto_version_bump: (Deprecated) If True, if the same pipeline version already exists
|
||||
(with any difference from the current one), the current pipeline version will be bumped to a new version
|
||||
version bump examples: 1.0.0 -> 1.0.1 , 1.2 -> 1.3, 10 -> 11 etc.
|
||||
:param bool abort_on_failure: If False (default), failed pipeline steps will not cause the pipeline
|
||||
@ -205,6 +210,7 @@ class PipelineController(object):
|
||||
Example remote url: 'https://github.com/user/repo.git'
|
||||
Example local repo copy: './repo' -> will automatically store the remote
|
||||
repo url and commit ID based on the locally cloned copy
|
||||
Use empty string ("") to disable any repository auto-detection
|
||||
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
|
||||
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
|
||||
:param artifact_serialization_function: A serialization function that takes one
|
||||
@ -229,16 +235,18 @@ class PipelineController(object):
|
||||
import dill
|
||||
return dill.loads(bytes_)
|
||||
"""
|
||||
if auto_version_bump is not None:
|
||||
warnings.warn("PipelineController.auto_version_bump is deprecated. It will be ignored", DeprecationWarning)
|
||||
self._nodes = {}
|
||||
self._running_nodes = []
|
||||
self._start_time = None
|
||||
self._pipeline_time_limit = None
|
||||
self._default_execution_queue = None
|
||||
self._version = str(version).strip()
|
||||
if not self._version or not all(i and i.isnumeric() for i in self._version.split('.')):
|
||||
self._version = str(version).strip() if version else None
|
||||
if self._version and not Version.is_valid_version_string(self._version):
|
||||
raise ValueError(
|
||||
"Pipeline version has to be in a semantic version form, "
|
||||
"examples: version='1.0.1', version='1.2', version='23'")
|
||||
"Setting non-semantic dataset version '{}'".format(self._version)
|
||||
)
|
||||
self._pool_frequency = pool_frequency * 60.
|
||||
self._thread = None
|
||||
self._pipeline_args = dict()
|
||||
@ -256,7 +264,6 @@ class PipelineController(object):
|
||||
self._step_ref_pattern = re.compile(self._step_pattern)
|
||||
self._reporting_lock = RLock()
|
||||
self._pipeline_task_status_failed = None
|
||||
self._auto_version_bump = bool(auto_version_bump)
|
||||
self._mock_execution = False # used for nested pipelines (eager execution)
|
||||
self._pipeline_as_sub_project = bool(Session.check_min_api_server_version("2.17"))
|
||||
self._last_progress_update_time = 0
|
||||
@ -271,6 +278,12 @@ class PipelineController(object):
|
||||
parent_project = None
|
||||
project_name = project or 'Pipelines'
|
||||
|
||||
# if user disabled the auto-repo, we force local script storage (repo="" or repo=False)
|
||||
set_force_local_repo = False
|
||||
if Task.running_locally() and repo is not None and not repo:
|
||||
Task.force_store_standalone_script(force=True)
|
||||
set_force_local_repo = True
|
||||
|
||||
self._task = Task.init(
|
||||
project_name=project_name,
|
||||
task_name=task_name,
|
||||
@ -278,6 +291,13 @@ class PipelineController(object):
|
||||
auto_resource_monitoring=False,
|
||||
reuse_last_task_id=False
|
||||
)
|
||||
|
||||
# if user disabled the auto-repo, set it back to False (just in case)
|
||||
if set_force_local_repo:
|
||||
# noinspection PyProtectedMember
|
||||
self._task._wait_for_repo_detection(timeout=300.)
|
||||
Task.force_store_standalone_script(force=False)
|
||||
|
||||
# make sure project is hidden
|
||||
if self._pipeline_as_sub_project:
|
||||
get_or_create_project(
|
||||
@ -287,7 +307,6 @@ class PipelineController(object):
|
||||
project_id=self._task.project, system_tags=self._project_system_tags)
|
||||
|
||||
self._task.set_system_tags((self._task.get_system_tags() or []) + [self._tag])
|
||||
self._task.set_user_properties(version=self._version)
|
||||
self._task.set_base_docker(
|
||||
docker_image=docker, docker_arguments=docker_args, docker_setup_bash_script=docker_bash_setup_script
|
||||
)
|
||||
@ -1259,7 +1278,7 @@ class PipelineController(object):
|
||||
:param name: String name of the parameter.
|
||||
:param default: Default value to be put as the default value (can be later changed in the UI)
|
||||
:param description: String description of the parameter and its usage in the pipeline
|
||||
:param param_type: Optional, parameter type information (to used as hint for casting and description)
|
||||
:param param_type: Optional, parameter type information (to be used as hint for casting and description)
|
||||
"""
|
||||
self._pipeline_args[str(name)] = default
|
||||
if description:
|
||||
@ -1443,7 +1462,7 @@ class PipelineController(object):
|
||||
|
||||
# make sure we have a unique version number (auto bump version if needed)
|
||||
# only needed when manually (from code) creating pipelines
|
||||
self._verify_pipeline_version()
|
||||
self._handle_pipeline_version()
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
pipeline_hash = self._get_task_hash()
|
||||
@ -1451,6 +1470,7 @@ class PipelineController(object):
|
||||
# noinspection PyProtectedMember
|
||||
self._task._set_runtime_properties({
|
||||
self._runtime_property_hash: "{}:{}".format(pipeline_hash, self._version),
|
||||
"version": self._version
|
||||
})
|
||||
else:
|
||||
self._task.connect_configuration(pipeline_dag, name=self._config_section)
|
||||
@ -1482,74 +1502,25 @@ class PipelineController(object):
|
||||
|
||||
return params, pipeline_dag
|
||||
|
||||
def _verify_pipeline_version(self):
|
||||
# if no version bump needed, just set the property
|
||||
if not self._auto_version_bump:
|
||||
self._task.set_user_properties(version=self._version)
|
||||
return
|
||||
|
||||
# check if pipeline version exists, if it does increase version
|
||||
pipeline_hash = self._get_task_hash()
|
||||
def _handle_pipeline_version(self):
|
||||
if not self._version:
|
||||
# noinspection PyProtectedMember
|
||||
existing_tasks = Task._query_tasks(
|
||||
project=[self._task.project], task_name=exact_match_regex(self._task.name),
|
||||
type=[str(self._task.task_type)],
|
||||
system_tags=["__$all", self._tag, "__$not", Task.archived_tag],
|
||||
_all_=dict(fields=['runtime.{}'.format(self._runtime_property_hash)],
|
||||
pattern=":{}".format(self._version)),
|
||||
only_fields=['id', 'runtime'],
|
||||
self._version = self._task._get_runtime_properties().get("version")
|
||||
if not self._version:
|
||||
previous_pipeline_tasks = Task._query_tasks(
|
||||
project=[self._task.project],
|
||||
fetch_only_first_page=True,
|
||||
only_fields=["runtime.version"],
|
||||
order_by=["-last_update"],
|
||||
system_tags=[self._tag],
|
||||
search_hidden=True,
|
||||
_allow_extra_fields_=True
|
||||
)
|
||||
if existing_tasks:
|
||||
# check if hash match the current version.
|
||||
matched = True
|
||||
for t in existing_tasks:
|
||||
h, _, v = t.runtime.get(self._runtime_property_hash, '').partition(':')
|
||||
if v == self._version:
|
||||
matched = bool(h == pipeline_hash)
|
||||
for previous_pipeline_task in previous_pipeline_tasks:
|
||||
if previous_pipeline_task.runtime.get("version"):
|
||||
self._version = str(Version(previous_pipeline_task.runtime.get("version")).get_next_version())
|
||||
break
|
||||
# if hash did not match, look for the highest version
|
||||
if not matched:
|
||||
# noinspection PyProtectedMember
|
||||
existing_tasks = Task._query_tasks(
|
||||
project=[self._task.project], task_name=exact_match_regex(self._task.name),
|
||||
type=[str(self._task.task_type)],
|
||||
system_tags=["__$all", self._tag, "__$not", Task.archived_tag],
|
||||
only_fields=['id', 'hyperparams', 'runtime'],
|
||||
)
|
||||
found_match_version = False
|
||||
existing_versions = set([self._version]) # noqa
|
||||
for t in existing_tasks:
|
||||
# exclude ourselves
|
||||
if t.id == self._task.id:
|
||||
continue
|
||||
if not t.hyperparams:
|
||||
continue
|
||||
v = t.hyperparams.get('properties', {}).get('version')
|
||||
if v:
|
||||
existing_versions.add(v.value)
|
||||
if t.runtime:
|
||||
h, _, _ = t.runtime.get(self._runtime_property_hash, '').partition(':')
|
||||
if h == pipeline_hash:
|
||||
self._version = v.value
|
||||
found_match_version = True
|
||||
break
|
||||
|
||||
# match to the version we found:
|
||||
if found_match_version:
|
||||
getLogger('clearml.automation.controller').info(
|
||||
'Existing Pipeline found, matching version to: {}'.format(self._version))
|
||||
else:
|
||||
# if we did not find a matched pipeline version, get the max one and bump the version by 1
|
||||
while True:
|
||||
v = self._version.split('.')
|
||||
self._version = '.'.join(v[:-1] + [str(int(v[-1]) + 1)])
|
||||
if self._version not in existing_versions:
|
||||
break
|
||||
|
||||
getLogger('clearml.automation.controller').info(
|
||||
'No matching Pipelines found, bump new version to: {}'.format(self._version))
|
||||
|
||||
self._task.set_user_properties(version=self._version)
|
||||
self._version = self._version or self._default_pipeline_version
|
||||
|
||||
def _get_task_hash(self):
|
||||
params_override = dict(**(self._task.get_parameters() or {}))
|
||||
@ -3054,7 +3025,8 @@ class PipelineController(object):
|
||||
name=artifact_name,
|
||||
artifact_object=artifact_object,
|
||||
wait_on_upload=True,
|
||||
extension_name=".pkl" if isinstance(artifact_object, dict) else None,
|
||||
extension_name=".pkl" if isinstance(artifact_object, dict) and not self._artifact_serialization_function else None,
|
||||
serialization_function=self._artifact_serialization_function
|
||||
)
|
||||
|
||||
|
||||
@ -3075,7 +3047,7 @@ class PipelineDecorator(PipelineController):
|
||||
self,
|
||||
name, # type: str
|
||||
project, # type: str
|
||||
version, # type: str
|
||||
version=None, # type: Optional[str]
|
||||
pool_frequency=0.2, # type: float
|
||||
add_pipeline_tags=False, # type: bool
|
||||
target_project=None, # type: Optional[str]
|
||||
@ -3098,8 +3070,10 @@ class PipelineDecorator(PipelineController):
|
||||
|
||||
:param name: Provide pipeline name (if main Task exists it overrides its name)
|
||||
:param project: Provide project storing the pipeline (if main Task exists it overrides its project)
|
||||
:param version: Must provide pipeline version. This version allows to uniquely identify the pipeline
|
||||
template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2'
|
||||
:param version: Pipeline version. This version allows to uniquely identify the pipeline
|
||||
template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2'.
|
||||
If not set, find the latest version of the pipeline and increment it. If no such version is found,
|
||||
default to '1.0.0'
|
||||
:param float pool_frequency: The pooling frequency (in minutes) for monitoring experiments / states.
|
||||
:param bool add_pipeline_tags: (default: False) if True, add `pipe: <pipeline_task_id>` tag to all
|
||||
steps (Tasks) created by this pipeline.
|
||||
@ -3143,6 +3117,7 @@ class PipelineDecorator(PipelineController):
|
||||
Example remote url: 'https://github.com/user/repo.git'
|
||||
Example local repo copy: './repo' -> will automatically store the remote
|
||||
repo url and commit ID based on the locally cloned copy
|
||||
Use empty string ("") to disable any repository auto-detection
|
||||
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
|
||||
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
|
||||
:param artifact_serialization_function: A serialization function that takes one
|
||||
@ -3156,7 +3131,6 @@ class PipelineDecorator(PipelineController):
|
||||
def serialize(obj):
|
||||
import dill
|
||||
return dill.dumps(obj)
|
||||
|
||||
:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
|
||||
which represents the serialized object. This function should return the deserialized object.
|
||||
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
|
||||
@ -3906,7 +3880,7 @@ class PipelineDecorator(PipelineController):
|
||||
_func=None, *, # noqa
|
||||
name, # type: str
|
||||
project, # type: str
|
||||
version, # type: str
|
||||
version=None, # type: Optional[str]
|
||||
return_value=None, # type: Optional[str]
|
||||
default_queue=None, # type: Optional[str]
|
||||
pool_frequency=0.2, # type: float
|
||||
@ -3935,8 +3909,10 @@ class PipelineDecorator(PipelineController):
|
||||
|
||||
:param name: Provide pipeline name (if main Task exists it overrides its name)
|
||||
:param project: Provide project storing the pipeline (if main Task exists it overrides its project)
|
||||
:param version: Must provide pipeline version. This version allows to uniquely identify the pipeline
|
||||
template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2'
|
||||
:param version: Pipeline version. This version allows to uniquely identify the pipeline
|
||||
template execution. Examples for semantic versions: version='1.0.1' , version='23', version='1.2'.
|
||||
If not set, find the latest version of the pipeline and increment it. If no such version is found,
|
||||
default to '1.0.0'
|
||||
:param return_value: Optional, Provide an artifact name to store the pipeline function return object
|
||||
Notice, If not provided the pipeline will not store the pipeline function return value.
|
||||
:param default_queue: default pipeline step queue
|
||||
@ -4009,6 +3985,7 @@ class PipelineDecorator(PipelineController):
|
||||
Example remote url: 'https://github.com/user/repo.git'
|
||||
Example local repo copy: './repo' -> will automatically store the remote
|
||||
repo url and commit ID based on the locally cloned copy
|
||||
Use empty string ("") to disable any repository auto-detection
|
||||
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
|
||||
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
|
||||
:param artifact_serialization_function: A serialization function that takes one
|
||||
@ -4022,7 +3999,6 @@ class PipelineDecorator(PipelineController):
|
||||
def serialize(obj):
|
||||
import dill
|
||||
return dill.dumps(obj)
|
||||
|
||||
:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
|
||||
which represents the serialized object. This function should return the deserialized object.
|
||||
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
|
||||
|
Loading…
Reference in New Issue
Block a user