Add support for .get-ing pipelines and enqueuing them
parent 9d57dad652
commit c8c8a1224e
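In short, the commit adds two classmethods to PipelineController: get() to look up an existing pipeline controller, and enqueue() to push it onto an execution queue. It also extends Task.enqueue() with a force flag. A minimal usage sketch, not part of the diff; the project, pipeline, and queue names below are placeholders:

# Usage sketch for the API added in this commit (names are placeholders).
from clearml import PipelineController

# Look up an existing pipeline controller by project and name. When several
# runs match, the one with the highest semantic version wins; otherwise the
# most recently updated run is returned.
pipeline = PipelineController.get(
    pipeline_project="examples",
    pipeline_name="custom pipeline logic",
)

# Enqueue it for execution. force=True resets the controller first if a
# plain enqueue attempt is rejected (e.g. the run already completed).
PipelineController.enqueue(pipeline, queue_name="services", force=True)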
@@ -23,7 +23,7 @@ from .. import Logger
 from ..automation import ClearmlJob
 from ..backend_api import Session
 from ..backend_interface.task.populate import CreateFromFunction
-from ..backend_interface.util import get_or_create_project
+from ..backend_interface.util import get_or_create_project, mutually_exclusive
 from ..config import get_remote_task_id
 from ..debugging.log import LoggerRoot
 from ..errors import UsageError
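The only import change above adds mutually_exclusive next to get_or_create_project; it is used further down to reject conflicting lookup arguments in PipelineController.get(). The helper's implementation is not part of this diff; a rough sketch of the contract it presumably enforces (an assumption, not clearml's actual code):

# Hypothetical sketch of a mutually_exclusive-style check (not clearml's code):
# raise if more than one of the given keyword arguments is set, and optionally
# require that at least one of them is set.
def mutually_exclusive(_require_at_least_one=True, **kwargs):
    provided = [name for name, value in kwargs.items() if value is not None]
    if len(provided) > 1:
        raise ValueError("only one of {} may be set".format(", ".join(sorted(kwargs))))
    if _require_at_least_one and not provided:
        raise ValueError("at least one of {} must be set".format(", ".join(sorted(kwargs))))

# e.g. this passes, while setting both pipeline_id and pipeline_name would raise
mutually_exclusive(pipeline_id="1234", pipeline_name=None, _require_at_least_one=False)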
@@ -1292,6 +1292,136 @@ class PipelineController(object):
         """
         return self._pipeline_args
 
+    @classmethod
+    def enqueue(cls, pipeline_controller, queue_name=None, queue_id=None, force=False):
+        # type: (Union[PipelineController, str], Optional[str], Optional[str], bool) -> Any
+        """
+        Enqueue a PipelineController for execution, by adding it to an execution queue.
+
+        .. note::
+           A worker daemon must be listening at the queue for the worker to fetch the Task and execute it,
+           see `ClearML Agent <../clearml_agent>`_ in the ClearML Documentation.
+
+        :param pipeline_controller: The PipelineController to enqueue. Specify a PipelineController object or PipelineController ID
+        :param queue_name: The name of the queue. If not specified, then ``queue_id`` must be specified.
+        :param queue_id: The ID of the queue. If not specified, then ``queue_name`` must be specified.
+        :param bool force: If True, reset the PipelineController if necessary before enqueuing it
+
+        :return: An enqueue JSON response.
+
+        .. code-block:: javascript
+
+            {
+                "queued": 1,
+                "updated": 1,
+                "fields": {
+                    "status": "queued",
+                    "status_reason": "",
+                    "status_message": "",
+                    "status_changed": "2020-02-24T15:05:35.426770+00:00",
+                    "last_update": "2020-02-24T15:05:35.426770+00:00",
+                    "execution.queue": "2bd96ab2d9e54b578cc2fb195e52c7cf"
+                }
+            }
+
+        - ``queued`` - The number of Tasks enqueued (an integer or ``null``).
+        - ``updated`` - The number of Tasks updated (an integer or ``null``).
+        - ``fields``
+
+            - ``status`` - The status of the experiment.
+            - ``status_reason`` - The reason for the last status change.
+            - ``status_message`` - Information about the status.
+            - ``status_changed`` - The last status change date and time (ISO 8601 format).
+            - ``last_update`` - The last Task update time, including Task creation, update, change, or events for this task (ISO 8601 format).
+            - ``execution.queue`` - The ID of the queue where the Task is enqueued. ``null`` indicates not enqueued.
+        """
+        pipeline_controller = (
+            pipeline_controller
+            if isinstance(pipeline_controller, PipelineController)
+            else cls.get(pipeline_id=pipeline_controller)
+        )
+        return Task.enqueue(pipeline_controller._task, queue_name=queue_name, queue_id=queue_id, force=force)
+
+    @classmethod
+    def get(
+        cls,
+        pipeline_id=None,  # type: Optional[str]
+        pipeline_project=None,  # type: Optional[str]
+        pipeline_name=None,  # type: Optional[str]
+        pipeline_version=None,  # type: Optional[str]
+        pipeline_tags=None,  # type: Optional[Sequence[str]]
+        shallow_search=False  # type: bool
+    ):
+        # type: (...) -> "PipelineController"
"""
|
||||||
|
Get a specific PipelineController. If multiple pipeline controllers are found, the pipeline controller
|
||||||
|
with the highest semantic version is returned. If no semantic version is found, the most recently
|
||||||
|
updated pipeline controller is returned. This function raises aan Exception if no pipeline controller
|
||||||
|
was found
|
||||||
|
|
||||||
|
Note: In order to run the pipeline controller returned by this function, use PipelineController.enqueue
|
||||||
|
|
||||||
|
:param pipeline_id: Requested PipelineController ID
|
||||||
|
:param pipeline_project: Requested PipelineController project
|
||||||
|
:param pipeline_name: Requested PipelineController name
|
||||||
|
:param pipeline_tags: Requested PipelineController tags (list of tag strings)
|
||||||
|
:param shallow_search: If True, search only the first 500 results (first page)
|
||||||
|
"""
|
||||||
|
+        mutually_exclusive(pipeline_id=pipeline_id, pipeline_project=pipeline_project, _require_at_least_one=False)
+        mutually_exclusive(pipeline_id=pipeline_id, pipeline_name=pipeline_name, _require_at_least_one=False)
+        if not pipeline_id:
+            pipeline_project_hidden = "{}/.pipelines/{}".format(pipeline_project, pipeline_name)
+            name_with_runtime_number_regex = r"^{}( #[0-9]+)*$".format(re.escape(pipeline_name))
+            pipelines = Task._query_tasks(
+                pipeline_project=[pipeline_project_hidden],
+                task_name=name_with_runtime_number_regex,
+                fetch_only_first_page=False if not pipeline_version else shallow_search,
+                only_fields=["id"] if not pipeline_version else ["id", "runtime.version"],
+                system_tags=[cls._tag],
+                order_by=["-last_update"],
+                tags=pipeline_tags,
+                search_hidden=True,
+                _allow_extra_fields_=True,
+            )
+            if pipelines:
+                if not pipeline_version:
+                    pipeline_id = pipelines[0].id
+                    current_version = None
+                    for pipeline in pipelines:
+                        if not pipeline.runtime:
+                            continue
+                        candidate_version = pipeline.runtime.get("version")
+                        if not candidate_version or not Version.is_valid_version_string(candidate_version):
+                            continue
+                        if not current_version or Version(candidate_version) > current_version:
+                            current_version = Version(candidate_version)
+                            pipeline_id = pipeline.id
+                else:
+                    for pipeline in pipelines:
+                        if pipeline.runtime.get("version") == pipeline_version:
+                            pipeline_id = pipeline.id
+                            break
+            if not pipeline_id:
+                error_msg = "Could not find pipeline with pipeline_project={}, pipeline_name={}".format(pipeline_project, pipeline_name)
+                if pipeline_version:
+                    error_msg += ", pipeline_version={}".format(pipeline_version)
+                raise ValueError(error_msg)
+        pipeline_task = Task.get_task(task_id=pipeline_id)
+        pipeline_object = cls.__new__(cls)
+        pipeline_object._task = pipeline_task
+        pipeline_object._nodes = {}
+        pipeline_object._running_nodes = []
+        try:
+            pipeline_object._deserialize(pipeline_task._get_configuration_dict(cls._config_section))
+        except Exception:
+            pass
+        return pipeline_object
+
+    @property
+    def id(self):
+        # type: () -> str
+        return self._task.id
+
     @property
     def tags(self):
         # type: () -> List[str]
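The version-selection loop in get() above prefers the highest valid semantic version and otherwise falls back to the most recently updated run (the first query result, since results are ordered by -last_update). A standalone sketch of that rule, using packaging.version in place of clearml's internal Version helper (an assumption made purely for illustration):

# Standalone sketch of the selection rule in PipelineController.get(): skip
# entries whose version string does not parse, otherwise keep the highest one.
from packaging.version import InvalidVersion, Version

def pick_latest(candidates):
    # candidates: list of (task_id, version_string), ordered newest-first
    best_id = candidates[0][0] if candidates else None  # fallback: newest run
    best_version = None
    for task_id, version_string in candidates:
        try:
            version = Version(version_string or "")
        except InvalidVersion:
            continue
        if best_version is None or version > best_version:
            best_id, best_version = task_id, version
    return best_id

print(pick_latest([("a", "1.0.1"), ("b", "not-a-version"), ("c", "1.2.0")]))  # prints "c"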
@@ -1213,8 +1213,8 @@ class Task(_Task):
         return cloned_task
 
     @classmethod
-    def enqueue(cls, task, queue_name=None, queue_id=None):
-        # type: (Union[Task, str], Optional[str], Optional[str]) -> Any
+    def enqueue(cls, task, queue_name=None, queue_id=None, force=False):
+        # type: (Union[Task, str], Optional[str], Optional[str], bool) -> Any
         """
         Enqueue a Task for execution, by adding it to an execution queue.
 
@@ -1225,6 +1225,7 @@ class Task(_Task):
         :param Task/str task: The Task to enqueue. Specify a Task object or Task ID.
         :param str queue_name: The name of the queue. If not specified, then ``queue_id`` must be specified.
         :param str queue_id: The ID of the queue. If not specified, then ``queue_name`` must be specified.
+        :param bool force: If True, reset the Task if necessary before enqueuing it
 
         :return: An enqueue JSON response.
 
@@ -1271,9 +1272,25 @@ class Task(_Task):
             raise ValueError('Could not find queue named "{}"'.format(queue_name))
 
         req = tasks.EnqueueRequest(task=task_id, queue=queue_id)
-        res = cls._send(session=session, req=req)
-        if not res.ok():
-            raise ValueError(res.response)
+        exception = None
+        res = None
+        try:
+            res = cls._send(session=session, req=req)
+            ok = res.ok()
+        except Exception as e:
+            exception = e
+            ok = False
+        if not ok:
+            if not force:
+                if res:
+                    raise ValueError(res.response)
+                raise exception
+            task = cls.get_task(task_id=task) if isinstance(task, str) else task
+            task.reset(set_started_on_success=False, force=True)
+            req = tasks.EnqueueRequest(task=task_id, queue=queue_id)
+            res = cls._send(session=session, req=req)
+            if not res.ok():
+                raise ValueError(res.response)
         resp = res.response
         return resp
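With the new force flag, a rejected or failed enqueue no longer raises immediately: the task is reset and the enqueue request is retried once. A minimal call sketch (the task ID and queue name below are placeholders):

# Hypothetical call site; the ID and queue name are placeholders.
from clearml import Task

response = Task.enqueue(
    "aabbccddeeff00112233445566778899",  # a Task object or a Task ID string
    queue_name="default",
    force=True,  # reset the task and retry if the first enqueue attempt fails
)
print(response)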
@@ -5,20 +5,21 @@ from clearml import TaskTypes
 # Make the following function an independent pipeline component step
 # notice all package imports inside the function will be automatically logged as
 # required packages for the pipeline execution step
-@PipelineDecorator.component(return_values=['data_frame'], cache=True, task_type=TaskTypes.data_processing)
+@PipelineDecorator.component(return_values=["data_frame"], cache=True, task_type=TaskTypes.data_processing)
 def step_one(pickle_data_url: str, extra: int = 43):
-    print('step_one')
+    print("step_one")
     # make sure we have scikit-learn for this step, we need it to use to unpickle the object
     import sklearn  # noqa
     import pickle
     import pandas as pd
     from clearml import StorageManager
 
     local_iris_pkl = StorageManager.get_local_copy(remote_url=pickle_data_url)
-    with open(local_iris_pkl, 'rb') as f:
+    with open(local_iris_pkl, "rb") as f:
         iris = pickle.load(f)
-    data_frame = pd.DataFrame(iris['data'], columns=iris['feature_names'])
-    data_frame.columns += ['target']
-    data_frame['target'] = iris['target']
+    data_frame = pd.DataFrame(iris["data"], columns=iris["feature_names"])
+    data_frame.columns += ["target"]
+    data_frame["target"] = iris["target"]
     return data_frame
 
 
@@ -28,18 +29,17 @@ def step_one(pickle_data_url: str, extra: int = 43):
 # Specifying `return_values` makes sure the function step can return an object to the pipeline logic
 # In this case, the returned tuple will be stored as an artifact named "X_train, X_test, y_train, y_test"
 @PipelineDecorator.component(
-    return_values=['X_train, X_test, y_train, y_test'], cache=True, task_type=TaskTypes.data_processing
+    return_values=["X_train", "X_test", "y_train", "y_test"], cache=True, task_type=TaskTypes.data_processing
 )
 def step_two(data_frame, test_size=0.2, random_state=42):
-    print('step_two')
+    print("step_two")
     # make sure we have pandas for this step, we need it to use the data_frame
     import pandas as pd  # noqa
     from sklearn.model_selection import train_test_split
-    y = data_frame['target']
-    X = data_frame[(c for c in data_frame.columns if c != 'target')]
-    X_train, X_test, y_train, y_test = train_test_split(
-        X, y, test_size=test_size, random_state=random_state
-    )
+
+    y = data_frame["target"]
+    X = data_frame[(c for c in data_frame.columns if c != "target")]
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
 
     return X_train, X_test, y_train, y_test
 
@@ -49,37 +49,41 @@ def step_two(data_frame, test_size=0.2, random_state=42):
 # required packages for the pipeline execution step
 # Specifying `return_values` makes sure the function step can return an object to the pipeline logic
 # In this case, the returned object will be stored as an artifact named "model"
-@PipelineDecorator.component(return_values=['model'], cache=True, task_type=TaskTypes.training)
+@PipelineDecorator.component(return_values=["model"], cache=True, task_type=TaskTypes.training)
 def step_three(X_train, y_train):
-    print('step_three')
+    print("step_three")
     # make sure we have pandas for this step, we need it to use the data_frame
     import pandas as pd  # noqa
     from sklearn.linear_model import LogisticRegression
-    model = LogisticRegression(solver='liblinear', multi_class='auto')
+
+    model = LogisticRegression(solver="liblinear", multi_class="auto")
     model.fit(X_train, y_train)
     return model
 
 
 # Make the following function an independent pipeline component step
 # notice all package imports inside the function will be automatically logged as
 # required packages for the pipeline execution step
 # Specifying `return_values` makes sure the function step can return an object to the pipeline logic
 # In this case, the returned object will be stored as an artifact named "accuracy"
-@PipelineDecorator.component(return_values=['accuracy'], cache=True, task_type=TaskTypes.qc)
+@PipelineDecorator.component(return_values=["accuracy"], cache=True, task_type=TaskTypes.qc)
 def step_four(model, X_data, Y_data):
     from sklearn.linear_model import LogisticRegression  # noqa
     from sklearn.metrics import accuracy_score
 
     Y_pred = model.predict(X_data)
     return accuracy_score(Y_data, Y_pred, normalize=True)
 
 
 # The actual pipeline execution context
 # notice that all pipeline component function calls are actually executed remotely
 # Only when a return value is used, the pipeline logic will wait for the component execution to complete
-@PipelineDecorator.pipeline(name='custom pipeline logic', project='examples', version='0.0.5')
-def executing_pipeline(pickle_url, mock_parameter='mock'):
-    print('pipeline args:', pickle_url, mock_parameter)
+@PipelineDecorator.pipeline(name="custom pipeline logic", project="examples", version="0.0.5")
+def executing_pipeline(pickle_url, mock_parameter="mock"):
+    print("pipeline args:", pickle_url, mock_parameter)
 
     # Use the pipeline argument to start the pipeline and pass it to the first step
-    print('launch step one')
+    print("launch step one")
     data_frame = step_one(pickle_url)
 
     # Use the returned data from the first step (`step_one`), and pass it to the next step (`step_two`)
@@ -87,17 +91,17 @@ def executing_pipeline(pickle_url, mock_parameter='mock'):
     # the pipeline logic does not actually load the artifact itself.
     # When actually passing the `data_frame` object into a new step,
     # It waits for the creating step/function (`step_one`) to complete the execution
-    print('launch step two')
+    print("launch step two")
     X_train, X_test, y_train, y_test = step_two(data_frame)
 
-    print('launch step three')
+    print("launch step three")
     model = step_three(X_train, y_train)
 
     # Notice since we are "printing" the `model` object,
     # we actually deserialize the object from the third step, and thus wait for the third step to complete.
-    print('returned model: {}'.format(model))
+    print("returned model: {}".format(model))
 
-    print('launch step four')
+    print("launch step four")
     accuracy = 100 * step_four(model, X_data=X_test, Y_data=y_test)
 
     # Notice since we are "printing" the `accuracy` object,
@@ -105,7 +109,7 @@ def executing_pipeline(pickle_url, mock_parameter='mock'):
     print(f"Accuracy={accuracy}%")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # set the pipeline steps default execution queue (per specific step we can override it with the decorator)
     # PipelineDecorator.set_default_execution_queue('default')
     # Run the pipeline steps as subprocesses on the current machine, great for local executions
@@ -114,7 +118,7 @@ if __name__ == '__main__':
 
     # Start the pipeline execution logic.
     executing_pipeline(
-        pickle_url='https://github.com/allegroai/events/raw/master/odsc20-east/generic/iris_dataset.pkl',
+        pickle_url="https://github.com/allegroai/events/raw/master/odsc20-east/generic/iris_dataset.pkl",
    )
 
-    print('process completed')
+    print("process completed")
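Once the decorator example above has produced a pipeline run, that run can be re-executed through the new API. A sketch assuming the example kept its default project ("examples"), name ("custom pipeline logic"), and version ("0.0.5"); the queue name is a placeholder:

# Sketch only: re-run the decorator example's controller via the new methods.
from clearml import PipelineController

controller = PipelineController.get(
    pipeline_project="examples",
    pipeline_name="custom pipeline logic",
    pipeline_version="0.0.5",
)
PipelineController.enqueue(controller, queue_name="services")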