Add eager decorated pipeline execution
Support pipeline monitoring for scalars/models/artifacts

parent c2ae7333b8
commit 400c6ec103
@@ -2,16 +2,17 @@ import functools
 import inspect
 import json
 import re
-from copy import copy
+from copy import copy, deepcopy
 from datetime import datetime
 from logging import getLogger
 from threading import Thread, Event, RLock
 from time import time
-from typing import Sequence, Optional, Mapping, Callable, Any, List, Dict, Union
+from typing import Sequence, Optional, Mapping, Callable, Any, List, Dict, Union, Tuple

 from attr import attrib, attrs

-from .job import LocalClearmlJob
+from .job import LocalClearmlJob, RunningJob
+from .. import Logger
 from ..automation import ClearmlJob
 from ..backend_interface.task.populate import CreateFromFunction
 from ..backend_interface.util import get_or_create_project, exact_match_regex
@@ -38,6 +39,8 @@ class PipelineController(object):
     _reserved_pipeline_names = (_pipeline_step_ref, )
     _task_project_lookup = {}
     _clearml_job_class = ClearmlJob
+    _update_execution_plot_interval = 5.*60
+    _monitor_node_interval = 5.*60

     @attrs
     class Node(object):
@@ -55,6 +58,23 @@ class PipelineController(object):
         skip_job = attrib(type=bool, default=False)  # if True, this step should be skipped
         cache_executed_step = attrib(type=bool, default=False)  # if True this pipeline step should be cached
+        return_artifacts = attrib(type=list, default=[])  # List of artifact names returned by the step
+        monitor_metrics = attrib(type=list, default=[])  # List of metric title/series to monitor
+        monitor_artifacts = attrib(type=list, default=[])  # List of artifact names to monitor
+        monitor_models = attrib(type=list, default=[])  # List of models to monitor
+
+        def copy(self):
+            # type: () -> PipelineController.Node
+            """
+            Return a copy of the current Node, excluding the `job` and `executed` fields
+            :return: new Node copy
+            """
+            new_copy = PipelineController.Node(
+                name=self.name,
+                **dict((k, deepcopy(v)) for k, v in self.__dict__.items()
+                       if k not in ('name', 'job', 'executed', 'task_factory_func'))
+            )
+            new_copy.task_factory_func = self.task_factory_func
+            return new_copy

     def __init__(
             self,
@@ -108,6 +128,7 @@ class PipelineController(object):
         self._reporting_lock = RLock()
         self._pipeline_task_status_failed = None
         self._auto_version_bump = bool(auto_version_bump)
+        self._mock_execution = False  # used for nested pipelines (eager execution)
         if not self._task:
             self._task = Task.init(
                 project_name=project or 'Pipelines',
@@ -124,6 +145,8 @@ class PipelineController(object):
         if self._task:
             self._task.add_tags([self._tag])

+        self._monitored_nodes = {}  # type: Dict[str, dict]
+
     def set_default_execution_queue(self, default_execution_queue):
         # type: (Optional[str]) -> None
         """
@@ -151,6 +174,9 @@ class PipelineController(object):
             parameter_override=None,  # type: Optional[Mapping[str, Any]]
             task_overrides=None,  # type: Optional[Mapping[str, Any]]
             execution_queue=None,  # type: Optional[str]
+            monitor_metrics=None,  # type: Optional[List[Union[Tuple[str, str], Tuple[(str, str), (str, str)]]]]
+            monitor_artifacts=None,  # type: Optional[List[Union[str, Tuple[str, str]]]]
+            monitor_models=None,  # type: Optional[List[Union[str, Tuple[str, str]]]]
             time_limit=None,  # type: Optional[float]
             base_task_project=None,  # type: Optional[str]
             base_task_name=None,  # type: Optional[str]
@@ -193,6 +219,27 @@ class PipelineController(object):
                 parameter_override={'container.image': '${stage1.container.image}' }
         :param execution_queue: Optional, the queue to use for executing this specific step.
             If not provided, the task will be sent to the default execution queue, as defined on the class
+        :param monitor_metrics: Optional, log the step's metrics on the pipeline Task.
+            The format is a list of (title, series) metric pairs to log:
+                [(step_metric_title, step_metric_series), ]
+                Example: [('test', 'accuracy'), ]
+            Or a list of tuple pairs, to specify a different target metric to use on the pipeline Task:
+                [((step_metric_title, step_metric_series), (target_metric_title, target_metric_series)), ]
+                Example: [[('test', 'accuracy'), ('model', 'accuracy')], ]
+        :param monitor_artifacts: Optional, log the step's artifacts on the pipeline Task.
+            Provided a list of artifact names existing on the step's Task, they will also appear on the Pipeline itself.
+                Example: [('processed_data', 'final_processed_data'), ]
+            Alternatively, the user can provide a list of artifacts to monitor
+            (the target artifact name will be the same as the original artifact name):
+                Example: ['processed_data', ]
+        :param monitor_models: Optional, log the step's output models on the pipeline Task.
+            Provided a list of model names existing on the step's Task, they will also appear on the Pipeline itself.
+                Example: [('model_weights', 'final_model_weights'), ]
+            Alternatively, the user can provide a list of models to monitor
+            (the target model name will be the same as the original model name):
+                Example: ['model_weights', ]
+            To select the latest (lexicographic) model use "model_*", or the last created model with just "*".
+                Example: ['model_weights_*', ]
         :param time_limit: Default None, no time limit.
             Step execution time limit, if exceeded the Task is aborted and the pipeline is stopped and marked failed.
         :param base_task_project: If base_task_id is not given,
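Taken together, the three new `add_step` arguments let the controller mirror a step's metrics, artifacts, and models onto the pipeline Task. A minimal usage sketch based on the docstring above; the project, task names, and queue are illustrative assumptions, not taken from the repo:

    from clearml.automation.controller import PipelineController

    pipe = PipelineController(name='pipeline demo', project='examples', version='0.0.1')
    pipe.set_default_execution_queue('default')
    pipe.add_step(
        name='stage_train',
        base_task_project='examples',   # hypothetical base task location
        base_task_name='step one',
        monitor_metrics=[('test', 'accuracy')],                # re-report this scalar on the pipeline Task
        monitor_artifacts=[('processed_data', 'final_data')],  # expose the artifact under a new name
        monitor_models=['model_weights_*'],                    # latest (lexicographic) matching output model
    )
    pipe.start()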
@@ -288,6 +335,9 @@ class PipelineController(object):
             task_overrides=task_overrides,
             cache_executed_step=cache_executed_step,
             task_factory_func=base_task_factory,
+            monitor_metrics=monitor_metrics or [],
+            monitor_artifacts=monitor_artifacts or [],
+            monitor_models=monitor_models or [],
         )

         if self._task and not self._task.running_locally():
@@ -308,11 +358,15 @@ class PipelineController(object):
             repo=None,  # type: Optional[str]
             repo_branch=None,  # type: Optional[str]
             repo_commit=None,  # type: Optional[str]
+            helper_functions=None,  # type: Optional[Sequence[Callable]]
             docker=None,  # type: Optional[str]
             docker_args=None,  # type: Optional[str]
             docker_bash_setup_script=None,  # type: Optional[str]
             parents=None,  # type: Optional[Sequence[str]]
             execution_queue=None,  # type: Optional[str]
+            monitor_metrics=None,  # type: Optional[List[Union[Tuple[str, str], Tuple[(str, str), (str, str)]]]]
+            monitor_artifacts=None,  # type: Optional[List[Union[str, Tuple[str, str]]]]
+            monitor_models=None,  # type: Optional[List[Union[str, Tuple[str, str]]]]
             time_limit=None,  # type: Optional[float]
             pre_execute_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]]  # noqa
             post_execute_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node], None]]  # noqa
@@ -368,6 +422,8 @@ class PipelineController(object):
             repo url and commit ID based on the locally cloned copy
         :param repo_branch: Optional, specify the remote repository branch (ignored if local repo path is used)
         :param repo_commit: Optional, specify the repository commit id (ignored if local repo path is used)
+        :param helper_functions: Optional, a list of helper functions to make available
+            for the standalone function Task.
         :param docker: Select the docker image to be executed in by the remote session
         :param docker_args: Add docker arguments, pass a single string
         :param docker_bash_setup_script: Add bash script to be executed
@@ -377,6 +433,27 @@ class PipelineController(object):
             have been executed successfully.
         :param execution_queue: Optional, the queue to use for executing this specific step.
             If not provided, the task will be sent to the default execution queue, as defined on the class
+        :param monitor_metrics: Optional, log the step's metrics on the pipeline Task.
+            The format is a list of (title, series) metric pairs to log:
+                [(step_metric_title, step_metric_series), ]
+                Example: [('test', 'accuracy'), ]
+            Or a list of tuple pairs, to specify a different target metric to use on the pipeline Task:
+                [((step_metric_title, step_metric_series), (target_metric_title, target_metric_series)), ]
+                Example: [[('test', 'accuracy'), ('model', 'accuracy')], ]
+        :param monitor_artifacts: Optional, log the step's artifacts on the pipeline Task.
+            Provided a list of artifact names existing on the step's Task, they will also appear on the Pipeline itself.
+                Example: [('processed_data', 'final_processed_data'), ]
+            Alternatively, the user can provide a list of artifacts to monitor
+            (the target artifact name will be the same as the original artifact name):
+                Example: ['processed_data', ]
+        :param monitor_models: Optional, log the step's output models on the pipeline Task.
+            Provided a list of model names existing on the step's Task, they will also appear on the Pipeline itself.
+                Example: [('model_weights', 'final_model_weights'), ]
+            Alternatively, the user can provide a list of models to monitor
+            (the target model name will be the same as the original model name):
+                Example: ['model_weights', ]
+            To select the latest (lexicographic) model use "model_*", or the last created model with just "*".
+                Example: ['model_weights_*', ]
         :param time_limit: Default None, no time limit.
             Step execution time limit, if exceeded the Task is aborted and the pipeline is stopped and marked failed.
         :param pre_execute_callback: Callback function, called when the step (Task) is created
@@ -447,19 +524,31 @@ class PipelineController(object):
                 for k, v in function_input_artifacts.items()}
         )

-        if self._task.running_locally():
+        if self._mock_execution:
             project_name = project_name or self._target_project or self._task.get_project_name()

-            task_definition = self._create_task_from_function(docker, docker_args, docker_bash_setup_script, function,
-                                                              function_input_artifacts, function_kwargs,
-                                                              function_return, packages, project_name, task_name,
-                                                              task_type, repo, repo_branch, repo_commit)
+            task_definition = self._create_task_from_function(
+                docker, docker_args, docker_bash_setup_script, function,
+                function_input_artifacts, function_kwargs,
+                function_return, packages, project_name, task_name,
+                task_type, repo, repo_branch, repo_commit, helper_functions)
+
+        elif self._task.running_locally():
+            project_name = project_name or self._target_project or self._task.get_project_name()
+
+            task_definition = self._create_task_from_function(
+                docker, docker_args, docker_bash_setup_script, function,
+                function_input_artifacts, function_kwargs,
+                function_return, packages, project_name, task_name,
+                task_type, repo, repo_branch, repo_commit, helper_functions)
             # update configuration with the task definitions
             # noinspection PyProtectedMember
             self._task._set_configuration(
                 name=name, config_type='json',
                 config_text=json.dumps(task_definition, indent=1)
             )
         else:
             # load task definition from configuration
             # noinspection PyProtectedMember
             task_definition = json.loads(self._task._get_configuration_text(name=name))

@@ -481,38 +570,16 @@ class PipelineController(object):
             cache_executed_step=cache_executed_step,
             task_factory_func=_create_task,
+            return_artifacts=function_return,
+            monitor_artifacts=monitor_artifacts,
+            monitor_metrics=monitor_metrics,
+            monitor_models=monitor_models,
         )

-        if self._task and not self._task.running_locally():
+        if self._task and not self._task.running_locally() and not self._mock_execution:
             self.update_execution_plot()

         return True

-    def _create_task_from_function(
-            self, docker, docker_args, docker_bash_setup_script,
-            function, function_input_artifacts, function_kwargs, function_return,
-            packages, project_name, task_name, task_type, repo, branch, commit
-    ):
-        task_definition = CreateFromFunction.create_task_from_function(
-            a_function=function,
-            function_kwargs=function_kwargs or None,
-            function_input_artifacts=function_input_artifacts,
-            function_return=function_return,
-            project_name=project_name,
-            task_name=task_name,
-            task_type=task_type,
-            repo=repo,
-            branch=branch,
-            commit=commit,
-            packages=packages,
-            docker=docker,
-            docker_args=docker_args,
-            docker_bash_setup_script=docker_bash_setup_script,
-            output_uri=None,
-            dry_run=True,
-        )
-        return task_definition
-
     def start(
             self,
             queue='services',
@@ -522,7 +589,7 @@ class PipelineController(object):
     ):
         # type: (...) -> bool
         """
         Start the current pipeline remotely (on the selected services queue)
+        The current process will be stopped if exit_process is True.

         :param queue: queue name to launch the pipeline on
@@ -611,6 +678,97 @@ class PipelineController(object):

         self._start(wait=True)

+    def create_draft(self):
+        # type: () -> None
+        """
+        Optional, manually create & serialize the Pipeline Task.
+        After calling Pipeline.create(), users can edit the pipeline in the UI and enqueue it for execution.
+
+        Notice: this function should be used to programmatically create a pipeline for later use.
+        To automatically create and launch pipelines, call the `start()` method.
+        """
+        self._verify()
+        self._serialize_pipeline_task()
+        self._task.close()
+        self._task.reset()
+
+    @classmethod
+    def get_logger(cls):
+        # type: () -> Logger
+        """
+        Return a logger connected to the Pipeline Task.
+        The logger can be used by any function/task executed by the pipeline, in order to report
+        directly to the pipeline Task itself. It can also be called from the main pipeline control Task.
+
+        Raise ValueError if the main Pipeline task could not be located.
+
+        :return: Logger object for reporting metrics (scalars, plots, debug samples etc.)
+        """
+        return cls._get_pipeline_task().get_logger()
+
+    @classmethod
+    def upload_artifact(
+            cls,
+            name,  # type: str
+            artifact_object,  # type: Any
+            metadata=None,  # type: Optional[Mapping]
+            delete_after_upload=False,  # type: bool
+            auto_pickle=True,  # type: bool
+            preview=None,  # type: Any
+            wait_on_upload=False,  # type: bool
+    ):
+        # type: (...) -> bool
+        """
+        Upload (add) an artifact to the main Pipeline Task object.
+
+        The artifact can be uploaded by any function/task executed by the pipeline, in order to report
+        directly to the pipeline Task itself. It can also be called from the main pipeline control Task.
+
+        Raise ValueError if the main Pipeline task could not be located.
+
+        The currently supported upload artifact types include:
+        - string / Path - A path to artifact file. If a wildcard or a folder is specified, then ClearML
+          creates and uploads a ZIP file.
+        - dict - ClearML stores a dictionary as ``.json`` file and uploads it.
+        - pandas.DataFrame - ClearML stores a pandas.DataFrame as ``.csv.gz`` (compressed CSV) file and uploads it.
+        - numpy.ndarray - ClearML stores a numpy.ndarray as ``.npz`` file and uploads it.
+        - PIL.Image - ClearML stores a PIL.Image as ``.png`` file and uploads it.
+        - Any - If called with auto_pickle=True, the object will be pickled and uploaded.
+
+        :param str name: The artifact name.
+
+            .. warning::
+               If an artifact with the same name was previously uploaded, then it is overwritten.
+
+        :param object artifact_object: The artifact object.
+        :param dict metadata: A dictionary of key-value pairs for any metadata. This dictionary appears with the
+            experiment in the **ClearML Web-App (UI)**, **ARTIFACTS** tab.
+        :param bool delete_after_upload: After the upload, delete the local copy of the artifact
+
+            - ``True`` - Delete the local copy of the artifact.
+            - ``False`` - Do not delete. (default)
+
+        :param bool auto_pickle: If True (default) and the artifact_object is not one of the following types:
+            pathlib2.Path, dict, pandas.DataFrame, numpy.ndarray, PIL.Image, url (string), local_file (string),
+            the artifact_object will be pickled and uploaded as a pickle file artifact (with file extension .pkl)
+
+        :param Any preview: The artifact preview
+
+        :param bool wait_on_upload: Whether or not the upload should be synchronous, forcing the upload to complete
+            before continuing.
+
+        :return: The status of the upload.
+
+            - ``True`` - Upload succeeded.
+            - ``False`` - Upload failed.
+
+        :raise: If the artifact object type is not supported, raise a ``ValueError``.
+        """
+        task = cls._get_pipeline_task()
+        return task.upload_artifact(
+            name=name, artifact_object=artifact_object, metadata=metadata, delete_after_upload=delete_after_upload,
+            auto_pickle=auto_pickle, preview=preview, wait_on_upload=wait_on_upload)
+
     def stop(self, timeout=None):
         # type: (Optional[float]) -> ()
         """
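The two classmethods above let step code report into the pipeline Task without holding a controller reference. A minimal sketch of how a step function might use them; the metric names and artifact contents are illustrative only:

    from clearml.automation.controller import PipelineController

    def step_body(data):
        # report a scalar onto the pipeline Task itself (not the step's own Task)
        PipelineController.get_logger().report_scalar(
            title='data', series='rows', value=len(data), iteration=0)
        # attach an artifact to the pipeline Task; overwrites any artifact with the same name
        PipelineController.upload_artifact(name='row_sample', artifact_object=data[:10])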
@@ -732,6 +890,8 @@ class PipelineController(object):
         """
         with self._reporting_lock:
             self._update_execution_plot()
+        # also trigger node monitor scanning
+        self._scan_monitored_nodes()

     def add_parameter(self, name, default=None, description=None):
         # type: (str, Optional[Any], Optional[str]) -> None
@@ -758,6 +918,32 @@ class PipelineController(object):
         """
         return self._pipeline_args

+    def _create_task_from_function(
+            self, docker, docker_args, docker_bash_setup_script,
+            function, function_input_artifacts, function_kwargs, function_return,
+            packages, project_name, task_name, task_type, repo, branch, commit, helper_functions
+    ):
+        task_definition = CreateFromFunction.create_task_from_function(
+            a_function=function,
+            function_kwargs=function_kwargs or None,
+            function_input_artifacts=function_input_artifacts,
+            function_return=function_return,
+            project_name=project_name,
+            task_name=task_name,
+            task_type=task_type,
+            repo=repo,
+            branch=branch,
+            commit=commit,
+            packages=packages,
+            docker=docker,
+            docker_args=docker_args,
+            docker_bash_setup_script=docker_bash_setup_script,
+            output_uri=None,
+            helper_functions=helper_functions,
+            dry_run=True,
+        )
+        return task_definition
+
     def _start(
             self,
             step_task_created_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]]  # noqa
@@ -837,7 +1023,7 @@ class PipelineController(object):
             self._deserialize(pipeline_dag)
             # if we continue the pipeline, make sure that we re-execute failed tasks
             if params['_continue_pipeline_']:
-                for node in self._nodes.values():
+                for node in list(self._nodes.values()):
                     if node.executed is False:
                         node.executed = None
         if not self._verify():
@@ -943,7 +1129,7 @@ class PipelineController(object):
             only_fields=['id', 'hyperparams', 'runtime'],
         )
         found_match_version = False
-        existing_versions = set([self._version])
+        existing_versions = set([self._version])  # noqa
         for t in existing_tasks:
             if not t.hyperparams:
                 continue
@@ -990,7 +1176,7 @@ class PipelineController(object):
         """
         dag = {name: dict((k, v) for k, v in node.__dict__.items()
                           if k not in ('job', 'name', 'task_factory_func'))
-               for name, node in self._nodes.items()}
+               for name, node in list(self._nodes.items())}

         return dag

@@ -1003,7 +1189,7 @@ class PipelineController(object):
         """

         # if we do not clone the Task, only merge the parts we can override.
-        for name in self._nodes:
+        for name in list(self._nodes.keys()):
             if not self._nodes[name].clone_task and name in dag_dict and not dag_dict[name].get('clone_task'):
                 for k in ('queue', 'parents', 'timeout', 'parameters', 'task_overrides'):
                     setattr(self._nodes[name], k, dag_dict[name].get(k) or type(getattr(self._nodes[name], k))())
@@ -1034,7 +1220,7 @@ class PipelineController(object):
         :return: return True iff DAG has no errors
         """
         # verify nodes
-        for node in self._nodes.values():
+        for node in list(self._nodes.values()):
             # raise value error if not verified
             self._verify_node(node)

@@ -1084,6 +1270,43 @@ class PipelineController(object):
                 'Node "{}" missing parent reference, adding: {}'.format(node.name, parents))
             node.parents = (node.parents or []) + list(parents)

+        # verify and fix monitoring sections:
+        def _verify_monitors(monitors, monitor_type, nested_pairs=False):
+            if not monitors:
+                return monitors
+
+            if nested_pairs:
+                if not all(isinstance(x, (list, tuple)) and x for x in monitors):
+                    raise ValueError("{} should be a list of tuples, found: {}".format(monitor_type, monitors))
+                # convert a single pair into a pair of pairs:
+                conformed_monitors = [
+                    pair if isinstance(pair[0], (list, tuple)) else (pair, pair) for pair in monitors
+                ]
+                # verify the pairs of pairs
+                if not all(isinstance(x[0][0], str) and isinstance(x[0][1], str) and
+                           isinstance(x[1][0], str) and isinstance(x[1][1], str) for x in conformed_monitors):
+                    raise ValueError("{} should be a list of tuples, found: {}".format(monitor_type, monitors))
+            else:
+                # verify a list of tuples
+                if not all(isinstance(x, (list, tuple, str)) and x for x in monitors):
+                    raise ValueError(
+                        "{} should be a list of tuples, found: {}".format(monitor_type, monitors))
+                # convert a single str into a pair:
+                conformed_monitors = [
+                    pair if isinstance(pair, (list, tuple)) else (pair, pair) for pair in monitors
+                ]
+                # verify the pairs
+                if not all(isinstance(x[0], str) and isinstance(x[1], str) for x in conformed_monitors):
+                    raise ValueError(
+                        "{} should be a list of tuples, found: {}".format(monitor_type, monitors))
+
+            return conformed_monitors
+
+        # verify and fix monitoring sections:
+        node.monitor_metrics = _verify_monitors(node.monitor_metrics, 'monitor_metrics', nested_pairs=True)
+        node.monitor_artifacts = _verify_monitors(node.monitor_artifacts, 'monitor_artifacts')
+        node.monitor_models = _verify_monitors(node.monitor_models, 'monitor_models')
+
         return True

     def _verify_dag(self):
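The normalization performed by `_verify_monitors` is easiest to see on concrete inputs. A small illustration of the intended behavior; the values are chosen for the example:

    # monitor_artifacts / monitor_models: a bare name is expanded into a
    # (source, target) pair with the same name on both sides
    monitors = ['processed_data']             # -> [('processed_data', 'processed_data')]
    monitors = [('processed_data', 'final')]  # -> unchanged

    # monitor_metrics (nested_pairs=True): a single (title, series) pair is
    # expanded into ((source_title, source_series), (target_title, target_series))
    monitors = [('test', 'accuracy')]         # -> [(('test', 'accuracy'), ('test', 'accuracy'))]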
@@ -1095,7 +1318,7 @@ class PipelineController(object):
         prev_visited = None
         while prev_visited != visited:
             prev_visited = copy(visited)
-            for k, node in self._nodes.items():
+            for k, node in list(self._nodes.items()):
                 if k in visited:
                     continue
                 if any(p == node.name for p in node.parents or []):
@@ -1212,7 +1435,10 @@ class PipelineController(object):
             visited.append(node.name)
             idx = len(visited) - 1
             parents = [visited.index(p) for p in node.parents or []]
-            node_params.append(node.job.task_parameter_override if node.job else node.parameters) or {}
+            node_params.append(
+                (node.job.task_parameter_override
+                 if node.job and node.job.task_parameter_override
+                 else node.parameters) or {})
             # sankey_node['label'].append(node.name)
             # sankey_node['customdata'].append(
             #     '<br />'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items()))
@@ -1299,7 +1525,7 @@ class PipelineController(object):
         table_values = [["Pipeline Step", "Task ID", "Task Name", "Status", "Parameters"]]

         for name, param in zip(visited, node_params):
-            param_str = str(param)
+            param_str = str(param) if param else ''
             if len(param_str) > 3:
                 # remove {} from string
                 param_str = param_str[1:-1]
@@ -1342,6 +1568,8 @@ class PipelineController(object):
         elif node.job:
             if node.job.is_pending():
                 return "#bdf5bd"  # lightgreen, pending in queue
+            elif node.job.is_completed():
+                return "blue"  # completed job
             else:
                 return "green"  # running job
         elif node.skip_job:
@@ -1365,7 +1593,7 @@ class PipelineController(object):
         """
         pooling_counter = 0
         launched_nodes = set()
-        last_plot_report = time()
+        last_monitor_report = last_plot_report = time()
         while self._stop_event:
             # stop request
             if self._stop_event.wait(self._pool_frequency if pooling_counter else 0.01):
@@ -1408,9 +1636,13 @@ class PipelineController(object):
             # nothing changed, we can sleep
             if not completed_jobs and self._running_nodes:
                 # force updating the pipeline state (plot) at least every 5 min.
-                if force_execution_plot_update or time()-last_plot_report > 5.*60:
+                if force_execution_plot_update or time()-last_plot_report > self._update_execution_plot_interval:
                     last_plot_report = time()
+                    last_monitor_report = time()
                     self.update_execution_plot()
+                elif time()-last_monitor_report > self._monitor_node_interval:
+                    last_monitor_report = time()
+                    self._scan_monitored_nodes()
                 continue

             # callback on completed jobs
@@ -1426,7 +1658,7 @@ class PipelineController(object):

             # Pull the next jobs in the pipeline, based on the completed list
             next_nodes = []
-            for node in self._nodes.values():
+            for node in list(self._nodes.values()):
                 # check if already processed or needs to be skipped
                 if node.job or node.executed or node.skip_job:
                     continue
@@ -1461,7 +1693,7 @@ class PipelineController(object):
                 break

         # stop all currently running jobs:
-        for node in self._nodes.values():
+        for node in list(self._nodes.values()):
             if node.executed is False:
                 self._pipeline_task_status_failed = True
             if node.job and node.executed and not node.job.is_stopped():
@@ -1517,6 +1749,129 @@ class PipelineController(object):
         if name in self._reserved_pipeline_names:
             raise ValueError('Node named \'{}\' is a reserved keyword, use a different name'.format(name))

+    def _scan_monitored_nodes(self):
+        # type: () -> None
+        """
+        Scan all nodes and monitor their metrics/artifacts/models
+        """
+        for node in list(self._nodes.values()):
+            self._monitor_node(node)
+
+    def _monitor_node(self, node):
+        # type: (PipelineController.Node) -> None
+        """
+        If the Node is running, put the metrics from the node on the pipeline itself.
+        :param node: Node to test
+        """
+        if not node:
+            return
+
+        # verify we have the node
+        if node.name not in self._monitored_nodes:
+            self._monitored_nodes[node.name] = {}
+
+        # if we are done with this node, skip it
+        if self._monitored_nodes[node.name].get('completed'):
+            return
+
+        if node.job and node.job.task:
+            task = node.job.task
+        elif node.job and node.executed and isinstance(node.executed, str):
+            task = Task.get_task(task_id=node.executed)
+        else:
+            return
+
+        # update the metrics
+        if node.monitor_metrics:
+            metrics_state = self._monitored_nodes[node.name].get('metrics', {})
+            logger = self._task.get_logger()
+            scalars = task.get_reported_scalars(x_axis='iter')
+            for (s_title, s_series), (t_title, t_series) in node.monitor_metrics:
+                values = scalars.get(s_title, {}).get(s_series)
+                if values and values.get('x') is not None and values.get('y') is not None:
+                    x = values['x'][-1]
+                    y = values['y'][-1]
+                    last_y = metrics_state.get(s_title, {}).get(s_series)
+                    if last_y is None or y > last_y:
+                        logger.report_scalar(title=t_title, series=t_series, value=y, iteration=int(x))
+                        last_y = y
+                    if not metrics_state.get(s_title):
+                        metrics_state[s_title] = {}
+                    metrics_state[s_title][s_series] = last_y
+
+            self._monitored_nodes[node.name]['metrics'] = metrics_state
+
+        if node.monitor_artifacts:
+            task.reload()
+            artifacts = task.data.execution.artifacts
+            self._task.reload()
+            output_artifacts = []
+            for s_artifact, t_artifact in node.monitor_artifacts:
+                # find the source artifact
+                for a in artifacts:
+                    if a.key != s_artifact:
+                        continue
+
+                    new_a = copy(a)
+                    new_a.key = t_artifact
+                    output_artifacts.append(new_a)
+                    break
+
+            # update the artifacts directly on the Task
+            if output_artifacts:
+                # noinspection PyProtectedMember
+                self._task._add_artifacts(output_artifacts)
+
+        if node.monitor_models:
+            task.reload()
+            output_models = task.data.models.output
+            self._task.reload()
+            target_models = []
+            for s_model, t_model in node.monitor_models:
+                # find the source model
+                for a in output_models:
+                    if a.name != s_model:
+                        continue
+
+                    new_a = copy(a)
+                    new_a.name = t_model
+                    target_models.append(new_a)
+                    break
+
+            # update the models directly on the Task
+            if target_models:
+                self._task.reload()
+                models = self._task.data.models
+                keys = [a.name for a in target_models]  # names being re-published, replace existing entries
+                models.output = [a for a in models.output or [] if a.name not in keys] + target_models
+                # noinspection PyProtectedMember
+                self._task._edit(models=models)
+
+        # update the state (so that we do not scan the node twice)
+        if node.job.is_stopped():
+            self._monitored_nodes[node.name]['completed'] = True
+
+    @classmethod
+    def _get_pipeline_task(cls):
+        # type: () -> Task
+        """
+        Return the pipeline Task (either the current one, or the parent Task of the currently running Task)
+        Raise ValueError if we could not locate the pipeline Task
+
+        :return: Pipeline Task
+        """
+        # get the main Task.
+        task = Task.current_task()
+        if str(task.task_type) == str(Task.TaskTypes.controller) and cls._tag in task.get_system_tags():
+            return task
+        # get the parent Task, it should be the pipeline
+        if not task.parent:
+            raise ValueError("Could not locate parent Pipeline Task")
+        parent = Task.get_task(task_id=task.parent)
+        if str(parent.task_type) == str(Task.TaskTypes.controller) and cls._tag in parent.get_system_tags():
+            return parent
+        raise ValueError("Could not locate parent Pipeline Task")
+
     def __verify_step_reference(self, node, step_ref_string):
         # type: (PipelineController.Node, str) -> Optional[str]
         """
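The metric-forwarding loop above indexes the result of `Task.get_reported_scalars()` by title and series. For reference, a sketch of the structure the code assumes; the field names are taken from the lookups above, the values are made up:

    # shape assumed by _monitor_node when reading reported scalars
    scalars = {
        'test': {                        # metric title
            'accuracy': {                # metric series
                'x': [0, 1, 2],          # iterations
                'y': [0.71, 0.78, 0.81]  # values; only the last point is forwarded
            }
        }
    }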
@@ -1717,6 +2072,8 @@ class PipelineController(object):
 class PipelineDecorator(PipelineController):
     _added_decorator = []  # type: List[dict]
     _singleton = None  # type: Optional[PipelineDecorator]
+    _eager_step_artifact = 'eager_step'
+    _eager_execution_instance = False
     _debug_execute_step_process = False
     _debug_execute_step_function = False
     _default_execution_queue = None
@@ -1751,6 +2108,11 @@ class PipelineDecorator(PipelineController):
             add_pipeline_tags=add_pipeline_tags,
             target_project=target_project,
         )

+        # if we are in eager execution, make sure the parent class knows it
+        if self._eager_execution_instance:
+            self._mock_execution = True
+
         if PipelineDecorator._default_execution_queue:
             super(PipelineDecorator, self).set_default_execution_queue(
                 PipelineDecorator._default_execution_queue)
@@ -1760,6 +2122,8 @@ class PipelineDecorator(PipelineController):
         self._added_decorator.clear()
         PipelineDecorator._singleton = self
         self._reference_callback = []
+        # map eager steps task id to the new step name
+        self._eager_steps_task_id = {}  # type: Dict[str, str]

     def _daemon(self):
         # type: () -> ()
@@ -1771,7 +2135,7 @@ class PipelineDecorator(PipelineController):
         """
         pooling_counter = 0
         launched_nodes = set()
-        last_plot_report = time()
+        last_monitor_report = last_plot_report = time()
         while self._stop_event:
             # stop request
             if self._stop_event.wait(self._pool_frequency if pooling_counter else 0.01):
@@ -1814,9 +2178,13 @@ class PipelineDecorator(PipelineController):
             # nothing changed, we can sleep
             if not completed_jobs and self._running_nodes:
                 # force updating the pipeline state (plot) at least every 5 min.
-                if force_execution_plot_update or time()-last_plot_report > 5.*60:
+                if force_execution_plot_update or time()-last_plot_report > self._update_execution_plot_interval:
                     last_plot_report = time()
+                    last_monitor_report = time()
                     self.update_execution_plot()
+                elif time()-last_monitor_report > self._monitor_node_interval:
+                    last_monitor_report = time()
+                    self._scan_monitored_nodes()
                 continue

             # callback on completed jobs
@@ -1836,8 +2204,8 @@ class PipelineDecorator(PipelineController):
             # visualize pipeline state (plot)
             self.update_execution_plot()

-        # stop all currently running jobs:
-        for node in self._nodes.values():
+        # stop all currently running jobs (protect against changes while iterating):
+        for node in list(self._nodes.values()):
             if node.executed is False:
                 self._pipeline_task_status_failed = True
             if node.job and node.executed and not node.job.is_stopped():
@@ -1845,6 +2213,16 @@ class PipelineDecorator(PipelineController):
             elif not node.job and not node.executed:
                 # mark Node as skipped if it has no Job object and it is not executed
                 node.skip_job = True
+                # if this is a standalone node, we need to remove it from the graph
+                if not node.parents:
+                    # check if this node is anyone's parent
+                    found_parent = False
+                    for v in list(self._nodes.values()):
+                        if node.name in (v.parents or []):
+                            found_parent = True
+                            break
+                    if not found_parent:
+                        self._nodes.pop(node.name, None)

         # visualize pipeline state (plot)
         self.update_execution_plot()
@@ -1856,10 +2234,76 @@ class PipelineDecorator(PipelineController):
         except Exception:
             pass

+    def update_execution_plot(self):
+        # type: () -> ()
+        """
+        Update the sankey diagram of the current pipeline
+        """
+        self._update_eager_generated_steps()
+        super(PipelineDecorator, self).update_execution_plot()
+
+    def _update_eager_generated_steps(self):
+        # noinspection PyProtectedMember
+        self._task.reload()
+        artifacts = self._task.data.execution.artifacts
+        # check if we have a new step on the DAG
+        eager_artifacts = []
+        for a in artifacts:
+            if a.key and a.key.startswith('{}:'.format(self._eager_step_artifact)):
+                # expected value: 'eager_step:parent-node-task-id:eager-step-task-id'
+                eager_artifacts.append(a)
+
+        # verify we have the step, and if we do not, add it.
+        delete_artifact_keys = []
+        for artifact in eager_artifacts:
+            _, parent_step_task_id, eager_step_task_id = artifact.key.split(':', 2)
+
+            # deserialize the node definition
+            eager_node_def = json.loads(artifact.type_data.preview)
+            eager_node_name, eager_node_def = list(eager_node_def.items())[0]
+
+            # verify we do not have any new nodes on the DAG (i.e. a step generating a Node eagerly)
+            parent_node = None
+            for node in list(self._nodes.values()):
+                if not node.job and not node.executed:
+                    continue
+                t_id = node.executed or node.job.task_id
+                if t_id == parent_step_task_id:
+                    parent_node = node
+                    break
+
+            if not parent_node:
+                # should not happen
+                continue
+
+            new_step_node_name = '{}_{}'.format(parent_node.name, eager_node_name)
+            counter = 1
+            while new_step_node_name in self._nodes:
+                new_step_node_name = '{}_{}'.format(new_step_node_name, counter)
+                counter += 1
+
+            eager_node_def['name'] = new_step_node_name
+            eager_node_def['parents'] = [parent_node.name]
+            is_cached = eager_node_def.pop('is_cached', None)
+            self._nodes[new_step_node_name] = self.Node(**eager_node_def)
+            self._nodes[new_step_node_name].job = RunningJob(existing_task=eager_step_task_id)
+            if is_cached:
+                self._nodes[new_step_node_name].job.force_set_is_cached(is_cached)
+
+            # make sure we will not rescan it.
+            delete_artifact_keys.append(artifact.key)
+
+        # remove all processed eager step artifacts
+        if delete_artifact_keys:
+            # noinspection PyProtectedMember
+            self._task._delete_artifacts(delete_artifact_keys)
+            self._force_task_configuration_update()
+
     def _create_task_from_function(
             self, docker, docker_args, docker_bash_setup_script,
             function, function_input_artifacts, function_kwargs, function_return,
-            packages, project_name, task_name, task_type, repo, branch, commit
+            packages, project_name, task_name, task_type, repo, branch, commit,
+            helper_functions,
     ):
         def sanitize(function_source):
             matched = re.match(r"[\s]*@PipelineDecorator.component[\s\\]*\(", function_source)
@@ -1896,6 +2340,7 @@ class PipelineDecorator(PipelineController):
             docker_args=docker_args,
             docker_bash_setup_script=docker_bash_setup_script,
             output_uri=None,
+            helper_functions=helper_functions,
             dry_run=True,
             _sanitize_function=sanitize,
         )
@@ -1903,8 +2348,8 @@ class PipelineDecorator(PipelineController):

     def _find_executed_node_leaves(self):
         # type: () -> List[PipelineController.Node]
-        all_parents = set([p for n in self._nodes.values() if n.executed for p in n.parents])
-        executed_leaves = [name for name, n in self._nodes.items() if n.executed and name not in all_parents]
+        all_parents = set([p for n in list(self._nodes.values()) if n.executed for p in n.parents])
+        executed_leaves = [name for name, n in list(self._nodes.items()) if n.executed and name not in all_parents]
         return executed_leaves

     def _adjust_task_hashing(self, task_hash):
@@ -1943,7 +2388,11 @@ class PipelineDecorator(PipelineController):
             task_type=None,  # type: Optional[str]
             repo=None,  # type: Optional[str]
             repo_branch=None,  # type: Optional[str]
-            repo_commit=None  # type: Optional[str]
+            repo_commit=None,  # type: Optional[str]
+            helper_functions=None,  # type: Optional[Sequence[Callable]]
+            monitor_metrics=None,  # type: Optional[List[Union[Tuple[str, str], Tuple[(str, str), (str, str)]]]]
+            monitor_artifacts=None,  # type: Optional[List[Union[str, Tuple[str, str]]]]
+            monitor_models=None,  # type: Optional[List[Union[str, Tuple[str, str]]]]
     ):
         # type: (...) -> Callable
         """
@@ -1981,6 +2430,29 @@ class PipelineDecorator(PipelineController):
             repo url and commit ID based on the locally cloned copy
         :param repo_branch: Optional, specify the remote repository branch (ignored if local repo path is used)
         :param repo_commit: Optional, specify the repository commit id (ignored if local repo path is used)
+        :param helper_functions: Optional, a list of helper functions to make available
+            for the standalone pipeline step function Task.
+        :param monitor_metrics: Optional, log the step's metrics on the pipeline Task.
+            The format is a list of (title, series) metric pairs to log:
+                [(step_metric_title, step_metric_series), ]
+                Example: [('test', 'accuracy'), ]
+            Or a list of tuple pairs, to specify a different target metric to use on the pipeline Task:
+                [((step_metric_title, step_metric_series), (target_metric_title, target_metric_series)), ]
+                Example: [[('test', 'accuracy'), ('model', 'accuracy')], ]
+        :param monitor_artifacts: Optional, log the step's artifacts on the pipeline Task.
+            Provided a list of artifact names existing on the step's Task, they will also appear on the Pipeline itself.
+                Example: [('processed_data', 'final_processed_data'), ]
+            Alternatively, the user can provide a list of artifacts to monitor
+            (the target artifact name will be the same as the original artifact name):
+                Example: ['processed_data', ]
+        :param monitor_models: Optional, log the step's output models on the pipeline Task.
+            Provided a list of model names existing on the step's Task, they will also appear on the Pipeline itself.
+                Example: [('model_weights', 'final_model_weights'), ]
+            Alternatively, the user can provide a list of models to monitor
+            (the target model name will be the same as the original model name):
+                Example: ['model_weights', ]
+            To select the latest (lexicographic) model use "model_*", or the last created model with just "*".
+                Example: ['model_weights_*', ]

         :return: function wrapper
         """
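A usage sketch of the decorator with the new arguments; the function bodies, names, and values are illustrative, and `return_values` is assumed from this release's decorator API:

    from clearml.automation.controller import PipelineDecorator

    # illustrative helper, shipped alongside the standalone step Task
    def normalize(values):
        total = sum(values)
        return [v / total for v in values]

    @PipelineDecorator.component(
        return_values=['normalized'],
        helper_functions=[normalize],        # make the helper importable inside the step
        monitor_metrics=[('data', 'size')],  # forwarded onto the pipeline Task
    )
    def preprocess(values):
        normalized = normalize(values)
        return normalized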
@@ -2013,6 +2485,10 @@ class PipelineDecorator(PipelineController):
                 repo=repo,
                 repo_branch=repo_branch,
                 repo_commit=repo_commit,
+                helper_functions=helper_functions,
+                monitor_metrics=monitor_metrics,
+                monitor_models=monitor_models,
+                monitor_artifacts=monitor_artifacts,
             )

             if cls._singleton:
@@ -2055,7 +2531,52 @@ class PipelineDecorator(PipelineController):
             )
             kwargs = {k: v for k, v in kwargs.items() if not isinstance(v, LazyEvalWrapper)}

-            _node = cls._singleton._nodes[_name]
+            # check if we have the singleton
+            if not cls._singleton:
+                # todo: somehow make sure the generated tasks list the parent pipeline as parent
+                original_tags = Task.current_task().get_tags(), Task.current_task().get_system_tags()
+                # this is an ad-hoc pipeline step,
+                PipelineDecorator._eager_execution_instance = True
+                a_pipeline = PipelineDecorator(
+                    name=name,
+                    project='DevOps',  # it will not actually be used
+                    version='0.0.0',
+                    pool_frequency=111,
+                    add_pipeline_tags=False,
+                    target_project=None,
+                )
+
+                target_queue = \
+                    PipelineDecorator._default_execution_queue or \
+                    Task.current_task().data.execution.queue
+                if target_queue:
+                    PipelineDecorator.set_default_execution_queue(target_queue)
+                else:
+                    # if we are not running from a queue, we are probably in debug mode
+                    a_pipeline._clearml_job_class = LocalClearmlJob
+                    a_pipeline._default_execution_queue = 'mock'
+
+                # restore tags, the pipeline might add a few
+                Task.current_task().set_tags(original_tags[0])
+                Task.current_task().set_system_tags(original_tags[1])
+
+            # get the original node name
+            _node_name = _name
+            # get the node
+            _node = cls._singleton._nodes[_node_name]
+
+            # if we already have a Job on the node, it means we are calling the same function/task
+            # twice inside the pipeline, and we need to replicate the node.
+            if _node.job:
+                _node = _node.copy()
+                # find a new name
+                counter = 1
+                while _node.name in cls._singleton._nodes:
+                    _node.name = '{}_{}'.format(_node_name, counter)
+                    counter += 1
+                _node_name = _node.name
+                cls._singleton._nodes[_node.name] = _node

             # update artifacts kwargs
             for k, v in kwargs_artifacts.items():
                 if k in kwargs:
@@ -2086,8 +2607,32 @@ class PipelineDecorator(PipelineController):
                 set((_node.parents or []) + cls._singleton._find_executed_node_leaves())
                 - set(list(_node.name)))

+            # verify the new step
+            cls._singleton._verify_node(_node)
             # launch the new step
             cls._singleton._launch_node(_node)
+            # if this is an eagerly generated pipeline, update the parent with the new eager step
+            if PipelineDecorator._eager_execution_instance and _node.job:
+                # store the newly generated node, so we can later serialize it
+                pipeline_dag = cls._singleton._serialize()
+                # check if the node is cached
+                if _node.job.is_cached_task():
+                    pipeline_dag[_node_name]['is_cached'] = True
+                # store the entire definition on the parent pipeline
+                from clearml.backend_api.services import tasks
+                artifact = tasks.Artifact(
+                    key='{}:{}:{}'.format(cls._eager_step_artifact, Task.current_task().id, _node.job.task_id()),
+                    type="json",
+                    mode='output',
+                    type_data=tasks.ArtifactTypeData(
+                        preview=json.dumps({_node_name: pipeline_dag[_node_name]}),
+                        content_type='application/pipeline')
+                )
+                req = tasks.AddOrUpdateArtifactsRequest(
+                    task=Task.current_task().parent, artifacts=[artifact], force=True)
+                res = Task.current_task().send(req, raise_on_errors=False)
+                if not res or not res.response or not res.response.updated:
+                    pass
+
             def results_reference(return_name):
                 # wait until the job is completed
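The branch above is what makes "eager" nested steps work: when a component is invoked from inside a running step (no controller singleton in the process), an ad-hoc PipelineDecorator is created, the node is launched, and it is registered on the parent pipeline via an `eager_step:<parent-task-id>:<step-task-id>` artifact. A sketch of the calling pattern this enables; the component bodies and names are hypothetical:

    from clearml.automation.controller import PipelineDecorator

    @PipelineDecorator.component(return_values=['doubled'])
    def double(x):
        return 2 * x

    @PipelineDecorator.component(return_values=['result'])
    def outer(x):
        # calling another component from inside a running step triggers the
        # eager-execution path: a new Task is launched and reported back to the
        # parent pipeline, which picks it up in _update_eager_generated_steps()
        result = double(x)
        return result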
@@ -2130,7 +2675,7 @@ class PipelineDecorator(PipelineController):
             pool_frequency=0.2,  # type: float
             add_pipeline_tags=False,  # type: bool
             target_project=None,  # type: Optional[str]
-            pipeline_execution_queue='services'  # type: Optional[str]
+            pipeline_execution_queue='services',  # type: Optional[str]
     ):
         # type: (...) -> Callable
         """
@@ -2212,6 +2757,16 @@ class PipelineDecorator(PipelineController):
                 # this time the pipeline is executed only on the remote machine
                 func(**pipeline_kwargs)
                 LazyEvalWrapper.trigger_all_remote_references()
+                # make sure we wait for all nodes to finish
+                waited = True
+                while waited:
+                    waited = False
+                    for node in list(a_pipeline._nodes.values()):
+                        if node.executed or not node.job or node.job.is_stopped():
+                            continue
+                        node.job.wait(pool_period=15)
+                        waited = True
+                # now we can stop the pipeline
                 a_pipeline.stop()
                 return

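Before the shutdown logic above runs, the decorated pipeline body has fully executed and every referenced step has a job. A usage sketch of the decorator pair; all names and values here are illustrative:

    from clearml.automation.controller import PipelineDecorator

    @PipelineDecorator.component(return_values=['squared'])
    def square(x):
        return x * x

    @PipelineDecorator.pipeline(
        name='pipeline demo', project='examples', version='0.0.1',
        pipeline_execution_queue='services',  # the controller itself runs on this queue
    )
    def run_pipeline(x=3):
        print(square(x))

    if __name__ == '__main__':
        run_pipeline()

The remaining hunks are from the job module (note the `clearml.automation.job` logger name below).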
@@ -8,7 +8,7 @@ from copy import deepcopy
 from datetime import datetime
 from logging import getLogger
 from time import time, sleep
-from typing import Optional, Mapping, Sequence, Any, Callable
+from typing import Optional, Mapping, Sequence, Any, Callable, Union

 from pathlib2 import Path

@@ -22,125 +22,21 @@ from ..backend_api.services import tasks as tasks_service
 logger = getLogger('clearml.automation.job')


-class ClearmlJob(object):
+class BaseJob(object):
     _job_hash_description = 'job_hash={}'
     _job_hash_property = 'pipeline_job_hash'
     _hashing_callback = None

-    def __init__(
-            self,
-            base_task_id,  # type: str
-            parameter_override=None,  # type: Optional[Mapping[str, str]]
-            task_overrides=None,  # type: Optional[Mapping[str, str]]
-            tags=None,  # type: Optional[Sequence[str]]
-            parent=None,  # type: Optional[str]
-            disable_clone_task=False,  # type: bool
-            allow_caching=False,  # type: bool
-            target_project=None,  # type: Optional[str]
-            **kwargs  # type: Any
-    ):
-        # type: (...) -> ()
+    def __init__(self):
+        # type: () -> ()
         """
-        Create a new Task based on a base_task_id with a different set of parameters
-
-        :param str base_task_id: base task id to clone from
-        :param dict parameter_override: dictionary of parameters and values to set for the cloned task
-        :param dict task_overrides: Task object specific overrides,
-            for example {'script.version_num': None, 'script.branch': 'main'}
-        :param list tags: additional tags to add to the newly cloned task
-        :param str parent: Set the newly created Task parent task field, default: base_task_id.
-        :param dict kwargs: additional Task creation parameters
-        :param bool disable_clone_task: if False (default) clone the base task id.
-            If True, use the base_task_id directly (the base task must be in draft mode / created)
-        :param bool allow_caching: If True, check if we have a previously executed Task with the same specification.
-            If we do, use it and set the internal is_cached flag. Default False (always create a new Task).
-        :param str target_project: Optional, set the target project name to create the cloned Task in.
+        BaseJob is an abstract ClearML Job
         """
-        base_temp_task = Task.get_task(task_id=base_task_id)
-        if disable_clone_task:
-            self.task = base_temp_task
-            task_status = self.task.status
-            if task_status != Task.TaskStatusEnum.created:
-                logger.warning('Task cloning disabled but requested Task [{}] status={}. '
-                               'Reverting to clone Task'.format(base_task_id, task_status))
-                disable_clone_task = False
-                self.task = None
-            elif parent:
-                self.task.set_parent(parent)
-        else:
-            self.task = None
-
-        self.task_parameter_override = None
-        task_params = None
-        if parameter_override:
-            task_params = base_temp_task.get_parameters(backwards_compatibility=False)
-            task_params.update(parameter_override)
-            self.task_parameter_override = dict(**parameter_override)
-
-        sections = {}
-        if task_overrides:
-            # set values inside the Task
-            for k, v in task_overrides.items():
-                # notice we can allow ourselves to change the base-task object as we will not use it any further
-                # noinspection PyProtectedMember
-                base_temp_task._set_task_property(k, v, raise_on_error=False, log_on_error=True)
-                section = k.split('.')[0]
-                sections[section] = getattr(base_temp_task.data, section, None)
-
-        # check for a cached task
-        self._is_cached_task = False
-        task_hash = None
-        if allow_caching:
-            # look for a cached copy of the Task:
-            # get parameters + task_overrides as a dict and hash it.
-            task_hash = self._create_task_hash(
-                base_temp_task, section_overrides=sections, params_override=task_params)
-            task = self._get_cached_task(task_hash)
-            # if we found a task, just use it
-            if task:
-                if disable_clone_task and self.task and self.task.status == self.task.TaskStatusEnum.created:
-                    # if the base task is in draft mode, and we are using a cached task,
-                    # we assume the base Task was created ad hoc and we can delete it.
-                    pass  # self.task.delete()
-
-                self._is_cached_task = True
-                self.task = task
-                self.task_started = True
-                self._worker = None
-                return
-
-        # if we have target_project, remove project from kwargs if we have it.
-        if target_project and 'project' in kwargs:
-            logger.info(
-                'target_project={} and project={} passed, using target_project.'.format(
-                    target_project, kwargs['project']))
-            kwargs.pop('project', None)
-
-        # check again if we need to clone the Task
-        if not disable_clone_task:
-            # noinspection PyProtectedMember
-            self.task = Task.clone(
-                base_task_id, parent=parent or base_task_id,
-                project=get_or_create_project(
-                    session=Task._get_default_session(), project_name=target_project
-                ) if target_project else kwargs.pop('project', None),
-                **kwargs
-            )
-
-        if tags:
-            self.task.set_tags(list(set(self.task.get_tags()) | set(tags)))
-
-        if task_params:
-            self.task.set_parameters(task_params)
-
-        if task_overrides and sections:
-            # store Task parameters back into the backend
-            # noinspection PyProtectedMember
-            self.task._edit(**sections)
-
-        self._set_task_cache_hash(self.task, task_hash)
-        self.task_started = False
-        self._worker = None
+        self.task_parameter_override = None
+        self.task = None
+        self.task_started = False

     def get_metric(self, title, series):
         # type: (str, str) -> (float, float, float)
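After this refactor the job module has a small hierarchy: the shared hashing/metric/state machinery lives in `BaseJob`, while construction strategies differ per subclass. A schematic summary of the classes in this diff, with bodies reduced to their purpose:

    class BaseJob(object):
        """Abstract job: shared state, metrics, and caching helpers"""

    class ClearmlJob(BaseJob):
        """Clones (or reuses) a base Task and prepares it for enqueue"""

    class LocalClearmlJob(ClearmlJob):
        """Runs the job locally as a sub-process, for debugging"""

    class RunningJob(BaseJob):
        """Wraps an already running Task (used for eager pipeline steps)"""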
@ -505,6 +401,125 @@ class ClearmlJob(object):
|
||||
task.set_comment(task.comment + '\n' + hash_comment if task.comment else hash_comment)
|
||||
|
||||
|
||||
class ClearmlJob(BaseJob):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_task_id, # type: str
|
||||
parameter_override=None, # type: Optional[Mapping[str, str]]
|
||||
task_overrides=None, # type: Optional[Mapping[str, str]]
|
||||
tags=None, # type: Optional[Sequence[str]]
|
||||
parent=None, # type: Optional[str]
|
||||
disable_clone_task=False, # type: bool
|
||||
allow_caching=False, # type: bool
|
||||
target_project=None, # type: Optional[str]
|
||||
**kwargs # type: Any
|
||||
):
|
||||
# type: (...) -> ()
|
||||
"""
|
||||
Create a new Task based in a base_task_id with a different set of parameters
|
||||
|
||||
:param str base_task_id: base task id to clone from
|
||||
:param dict parameter_override: dictionary of parameters and values to set fo the cloned task
|
||||
:param dict task_overrides: Task object specific overrides.
|
||||
for example {'script.version_num': None, 'script.branch': 'main'}
|
||||
:param list tags: additional tags to add to the newly cloned task
|
||||
:param str parent: Set newly created Task parent task field, default: base_tak_id.
|
||||
:param dict kwargs: additional Task creation parameters
|
||||
:param bool disable_clone_task: if False (default) clone base task id.
|
||||
If True, use the base_task_id directly (base-task must be in draft-mode / created),
|
||||
:param bool allow_caching: If True check if we have a previously executed Task with the same specification
|
||||
If we do, use it and set internal is_cached flag. Default False (always create new Task).
|
||||
:param str target_project: Optional, Set the target project name to create the cloned Task in.
|
||||
"""
|
||||
super(ClearmlJob, self).__init__()
|
||||
base_temp_task = Task.get_task(task_id=base_task_id)
|
||||
if disable_clone_task:
|
||||
self.task = base_temp_task
|
||||
task_status = self.task.status
|
||||
if task_status != Task.TaskStatusEnum.created:
|
||||
logger.warning('Task cloning disabled but requested Task [{}] status={}. '
|
||||
'Reverting to clone Task'.format(base_task_id, task_status))
|
||||
disable_clone_task = False
|
||||
self.task = None
|
||||
elif parent:
|
||||
self.task.set_parent(parent)
|
||||
else:
|
||||
self.task = None
|
||||
|
||||
self.task_parameter_override = None
|
||||
task_params = None
|
||||
if parameter_override:
|
||||
task_params = base_temp_task.get_parameters(backwards_compatibility=False)
|
||||
task_params.update(parameter_override)
|
||||
self.task_parameter_override = dict(**parameter_override)
|
||||
|
||||
sections = {}
|
||||
if task_overrides:
|
||||
# set values inside the Task
|
||||
for k, v in task_overrides.items():
|
||||
# notice we can allow ourselves to change the base-task object as we will not use it any further
|
||||
# noinspection PyProtectedMember
|
||||
base_temp_task._set_task_property(k, v, raise_on_error=False, log_on_error=True)
|
||||
section = k.split('.')[0]
|
||||
sections[section] = getattr(base_temp_task.data, section, None)
|
||||
|
||||
# check cached task
|
||||
self._is_cached_task = False
|
||||
task_hash = None
|
||||
if allow_caching:
|
||||
# look for a cached copy of the Task
|
||||
# get parameters + task_overrides as a dict and hash it.
|
||||
task_hash = self._create_task_hash(
|
||||
base_temp_task, section_overrides=sections, params_override=task_params)
|
||||
task = self._get_cached_task(task_hash)
|
||||
# if we found a task, just use it
|
||||
if task:
|
||||
if disable_clone_task and self.task and self.task.status == self.task.TaskStatusEnum.created:
|
||||
# if the base task is in draft mode and we are using a cached task,
# we assume the base Task was created ad hoc and we can delete it.
|
||||
pass # self.task.delete()
|
||||
|
||||
self._is_cached_task = True
|
||||
self.task = task
|
||||
self.task_started = True
|
||||
self._worker = None
|
||||
return
|
||||
|
||||
# if we have target_project, remove project from kwargs if we have it.
|
||||
if target_project and 'project' in kwargs:
|
||||
logger.info(
|
||||
'target_project={} and project={} passed, using target_project.'.format(
|
||||
target_project, kwargs['project']))
|
||||
kwargs.pop('project', None)
|
||||
|
||||
# check again if we need to clone the Task
|
||||
if not disable_clone_task:
|
||||
# noinspection PyProtectedMember
|
||||
self.task = Task.clone(
|
||||
base_task_id, parent=parent or base_task_id,
|
||||
project=get_or_create_project(
|
||||
session=Task._get_default_session(), project_name=target_project
|
||||
) if target_project else kwargs.pop('project', None),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if tags:
|
||||
self.task.set_tags(list(set(self.task.get_tags()) | set(tags)))
|
||||
|
||||
if task_params:
|
||||
self.task.set_parameters(task_params)
|
||||
|
||||
if task_overrides and sections:
|
||||
# store back Task parameters into backend
|
||||
# noinspection PyProtectedMember
|
||||
self.task._edit(**sections)
|
||||
|
||||
self._set_task_cache_hash(self.task, task_hash)
|
||||
self.task_started = False
|
||||
self._worker = None
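A minimal usage sketch (not part of this commit) of the constructor above; the base task id and the overridden parameter name are placeholder assumptions:

# hypothetical example: clone a template Task and override one hyper-parameter
job = ClearmlJob(
    base_task_id='<template_task_id>',                      # placeholder id
    parameter_override={'General/learning_rate': '0.01'},   # assumed parameter name
    tags=['cloned'],
    allow_caching=True,
)
print(job.task.id, job.task_started)                        # cloned (or cached) Task and its state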
|
||||
|
||||
|
||||
class LocalClearmlJob(ClearmlJob):
|
||||
"""
|
||||
Run jobs locally as a sub-process, use for debugging purposes only
|
||||
@ -546,9 +561,8 @@ class LocalClearmlJob(ClearmlJob):
|
||||
levels = 0
|
||||
if working_dir:
|
||||
levels = 1 + sum(1 for c in working_dir if c == '/')
|
||||
if levels:
|
||||
cwd = os.path.abspath(os.path.join(cwd, os.sep.join(['..'] * levels)))
|
||||
cwd = os.path.join(cwd, self.task.data.script.working_dir)
|
||||
cwd = os.path.abspath(os.path.join(os.getcwd(), os.sep.join(['..'] * levels))) if levels else os.getcwd()
|
||||
cwd = os.path.join(cwd, working_dir)
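# hedged illustration of the new cwd computation above: for a script whose
# working_dir is 'examples/pipeline', levels == 2, so we step two directories
# up from os.getcwd() to reach the repository root and then re-append
# 'examples/pipeline' as the sub-process working directory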
|
||||
|
||||
python = sys.executable
|
||||
env = dict(**os.environ)
|
||||
@ -593,6 +607,22 @@ class LocalClearmlJob(ClearmlJob):
|
||||
return exit_code
|
||||
|
||||
|
||||
class RunningJob(BaseJob):
|
||||
"""
|
||||
Wrapper around an already running Task
|
||||
"""
|
||||
|
||||
def __init__(self, existing_task): # noqa
|
||||
# type: (Union[Task, str]) -> None
|
||||
super(RunningJob, self).__init__()
|
||||
self.task = existing_task if isinstance(existing_task, Task) else Task.get_task(task_id=existing_task)
|
||||
self.task_started = bool(self.task.status != Task.TaskStatusEnum.created)
|
||||
|
||||
def force_set_is_cached(self, cached):
|
||||
# type: (bool) -> ()
|
||||
self._is_cached_task = bool(cached)
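A short sketch (assumed usage, the task id is a placeholder) of the new RunningJob wrapper:

# wrap a Task that is already executing so the pipeline can monitor it
running = RunningJob('<existing_task_id>')   # placeholder id; a Task object also works
running.force_set_is_cached(True)            # mark it as cached so callers will not re-launch it
print(running.task.status, running.task_started)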
|
||||
|
||||
|
||||
class TrainsJob(ClearmlJob):
|
||||
"""
|
||||
Deprecated, use ClearmlJob
|
||||
|
@ -461,6 +461,7 @@ class CreateFromFunction(object):
|
||||
kwargs_section = 'kwargs'
|
||||
input_artifact_section = 'kwargs_artifacts'
|
||||
task_template = """from clearml import Task
|
||||
from clearml.automation.controller import PipelineDecorator
|
||||
|
||||
|
||||
{function_source}
|
||||
@ -504,8 +505,10 @@ if __name__ == '__main__':
|
||||
docker_args=None, # type: Optional[str]
|
||||
docker_bash_setup_script=None, # type: Optional[str]
|
||||
output_uri=None, # type: Optional[str]
|
||||
helper_functions=None, # type: Optional[Sequence[Callable]]
|
||||
dry_run=False, # type: bool
|
||||
_sanitize_function=None, # type: Optional[Callable[[str], str]]
|
||||
_sanitize_helper_functions=None, # type: Optional[Callable[[str], str]]
|
||||
):
|
||||
# type: (...) -> Optional[Dict, Task]
|
||||
"""
|
||||
@ -558,14 +561,25 @@ if __name__ == '__main__':
|
||||
inside the docker before setting up the Task's environment
|
||||
:param output_uri: Optional, set the Task's output_uri (Storage destination).
|
||||
examples: 's3://bucket/folder', 'https://server/' , 'gs://bucket/folder', 'azure://bucket', '/folder/'
|
||||
:param helper_functions: Optional, a list of helper functions to make available
|
||||
for the standalone function Task.
|
||||
:param dry_run: If True do not create the Task, but return a dict of the Task's definitions
|
||||
:param _sanitize_function: Sanitization function for the function string.
|
||||
:param _sanitize_helper_functions: Sanitization function for the helper function string.
|
||||
:return: Newly created Task object
|
||||
"""
|
||||
function_name = str(a_function.__name__)
|
||||
function_source = inspect.getsource(a_function)
|
||||
if _sanitize_function:
|
||||
function_source = _sanitize_function(function_source)
|
||||
function_source = cls.__sanitize_remove_type_hints(function_source)
|
||||
|
||||
# add helper functions on top.
|
||||
for f in (helper_functions or []):
|
||||
f_source = inspect.getsource(f)
|
||||
if _sanitize_helper_functions:
|
||||
f_source = _sanitize_helper_functions(f_source)
|
||||
function_source = cls.__sanitize_remove_type_hints(f_source) + '\n\n' + function_source
|
||||
|
||||
function_input_artifacts = function_input_artifacts or dict()
|
||||
# verify artifact kwargs:
|
||||
@ -660,3 +674,36 @@ if __name__ == '__main__':
|
||||
task.set_parameters(hyper_parameters)
|
||||
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def __sanitize_remove_type_hints(function_source):
|
||||
# type: (str) -> str
|
||||
try:
|
||||
import ast
|
||||
from ...utilities.lowlevel.astor_unparse import unparse
|
||||
except ImportError:
|
||||
return function_source
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
class TypeHintRemover(ast.NodeTransformer):
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
# remove the return type definition
|
||||
node.returns = None
|
||||
# remove all argument annotations
|
||||
if node.args.args:
|
||||
for arg in node.args.args:
|
||||
arg.annotation = None
|
||||
return node
|
||||
|
||||
# parse the source code into an AST
|
||||
parsed_source = ast.parse(function_source)
|
||||
# remove all type annotations, function return type definitions
|
||||
# and import statements from 'typing'
|
||||
transformed = TypeHintRemover().visit(parsed_source)
|
||||
# convert the AST back to source code
|
||||
return unparse(transformed).lstrip('\n')
|
||||
except Exception:
|
||||
# just in case we failed parsing.
|
||||
return function_source
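A self-contained sketch of the same idea, to illustrate what the transformer above removes (uses only the standard ast module):

import ast

source = "def add(a: int, b: int = 0) -> int:\n    return a + b\n"
tree = ast.parse(source)
for node in ast.walk(tree):
    if isinstance(node, ast.FunctionDef):
        node.returns = None              # drop '-> int'
        for arg in node.args.args:
            arg.annotation = None        # drop ': int' annotations
# the vendored unparse() below (astor_unparse.py) turns the modified tree back
# into source, yielding roughly: def add(a, b=0): return a + b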
|
||||
|
@ -130,7 +130,7 @@ class Artifact(object):
|
||||
self._mode = artifact_api_object.mode
|
||||
self._url = artifact_api_object.uri
|
||||
self._hash = artifact_api_object.hash
|
||||
self._timestamp = datetime.fromtimestamp(artifact_api_object.timestamp)
|
||||
self._timestamp = datetime.fromtimestamp(artifact_api_object.timestamp or 0)
|
||||
self._metadata = dict(artifact_api_object.display_data) if artifact_api_object.display_data else {}
|
||||
self._preview = artifact_api_object.type_data.preview if artifact_api_object.type_data else None
|
||||
self._object = None
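The `or 0` change above matters because a missing timestamp would otherwise raise; a minimal illustration:

from datetime import datetime

timestamp = None                                  # artifact_api_object.timestamp may be unset
# datetime.fromtimestamp(timestamp)               # would raise TypeError
print(datetime.fromtimestamp(timestamp or 0))     # falls back to the epoch instead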
|
||||
|
900
clearml/utilities/lowlevel/astor_unparse.py
Normal file
@ -0,0 +1,900 @@
|
||||
from __future__ import print_function, unicode_literals
|
||||
|
||||
import ast
|
||||
import sys
|
||||
import six
|
||||
from six import StringIO
|
||||
|
||||
# Large float and imaginary literals get turned into infinities in the AST.
|
||||
# We unparse those infinities to INFSTR.
|
||||
INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
|
||||
|
||||
|
||||
def interleave(inter, f, seq):
|
||||
"""Call f on each item in seq, calling inter() in between.
|
||||
"""
|
||||
seq = iter(seq)
|
||||
try:
|
||||
f(next(seq))
|
||||
except StopIteration:
|
||||
pass
|
||||
else:
|
||||
for x in seq:
|
||||
inter()
|
||||
f(x)
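interleave() is the small writer combinator used throughout Unparser to emit comma-separated lists, e.g.:

out = []
interleave(lambda: out.append(", "), out.append, ["a", "b", "c"])
print("".join(out))   # -> a, b, c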
|
||||
|
||||
|
||||
class Unparser:
|
||||
"""Methods in this class recursively traverse an AST and
|
||||
output source code for the abstract syntax; original formatting
|
||||
is disregarded. """
|
||||
|
||||
def __init__(self, tree, file=sys.stdout):
|
||||
"""Unparser(tree, file=sys.stdout) -> None.
|
||||
Print the source for tree to file."""
|
||||
self.f = file
|
||||
self.future_imports = []
|
||||
self._indent = 0
|
||||
self.dispatch(tree)
|
||||
print("", file=self.f)
|
||||
self.f.flush()
|
||||
|
||||
def fill(self, text=""):
|
||||
"Indent a piece of text, according to the current indentation level"
|
||||
self.f.write("\n" + " " * self._indent + text)
|
||||
|
||||
def write(self, text):
|
||||
"Append a piece of text to the current line."
|
||||
self.f.write(six.text_type(text))
|
||||
|
||||
def enter(self):
|
||||
"Print ':', and increase the indentation."
|
||||
self.write(":")
|
||||
self._indent += 1
|
||||
|
||||
def leave(self):
|
||||
"Decrease the indentation level."
|
||||
self._indent -= 1
|
||||
|
||||
def dispatch(self, tree):
|
||||
"Dispatcher function, dispatching tree type T to method _T."
|
||||
if isinstance(tree, list):
|
||||
for t in tree:
|
||||
self.dispatch(t)
|
||||
return
|
||||
meth = getattr(self, "_" + tree.__class__.__name__)
|
||||
meth(tree)
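# e.g. an ast.BinOp node is routed to self._BinOp, an ast.FunctionDef to
# self._FunctionDef; an unsupported node type surfaces here as an AttributeError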
|
||||
|
||||
############### Unparsing methods ######################
|
||||
# There should be one method per concrete grammar type #
|
||||
# Constructors should be grouped by sum type. Ideally, #
|
||||
# this would follow the order in the grammar, but #
|
||||
# currently doesn't. #
|
||||
########################################################
|
||||
|
||||
def _Module(self, tree):
|
||||
for stmt in tree.body:
|
||||
self.dispatch(stmt)
|
||||
|
||||
def _Interactive(self, tree):
|
||||
for stmt in tree.body:
|
||||
self.dispatch(stmt)
|
||||
|
||||
def _Expression(self, tree):
|
||||
self.dispatch(tree.body)
|
||||
|
||||
# stmt
|
||||
def _Expr(self, tree):
|
||||
self.fill()
|
||||
self.dispatch(tree.value)
|
||||
|
||||
def _NamedExpr(self, tree):
|
||||
self.write("(")
|
||||
self.dispatch(tree.target)
|
||||
self.write(" := ")
|
||||
self.dispatch(tree.value)
|
||||
self.write(")")
|
||||
|
||||
def _Import(self, t):
|
||||
self.fill("import ")
|
||||
interleave(lambda: self.write(", "), self.dispatch, t.names)
|
||||
|
||||
def _ImportFrom(self, t):
|
||||
# A from __future__ import may affect unparsing, so record it.
|
||||
if t.module and t.module == '__future__':
|
||||
self.future_imports.extend(n.name for n in t.names)
|
||||
|
||||
self.fill("from ")
|
||||
self.write("." * t.level)
|
||||
if t.module:
|
||||
self.write(t.module)
|
||||
self.write(" import ")
|
||||
interleave(lambda: self.write(", "), self.dispatch, t.names)
|
||||
|
||||
def _Assign(self, t):
|
||||
self.fill()
|
||||
for target in t.targets:
|
||||
self.dispatch(target)
|
||||
self.write(" = ")
|
||||
self.dispatch(t.value)
|
||||
|
||||
def _AugAssign(self, t):
|
||||
self.fill()
|
||||
self.dispatch(t.target)
|
||||
self.write(" " + self.binop[t.op.__class__.__name__] + "= ")
|
||||
self.dispatch(t.value)
|
||||
|
||||
def _AnnAssign(self, t):
|
||||
self.fill()
|
||||
if not t.simple and isinstance(t.target, ast.Name):
|
||||
self.write('(')
|
||||
self.dispatch(t.target)
|
||||
if not t.simple and isinstance(t.target, ast.Name):
|
||||
self.write(')')
|
||||
self.write(": ")
|
||||
self.dispatch(t.annotation)
|
||||
if t.value:
|
||||
self.write(" = ")
|
||||
self.dispatch(t.value)
|
||||
|
||||
def _Return(self, t):
|
||||
self.fill("return")
|
||||
if t.value:
|
||||
self.write(" ")
|
||||
self.dispatch(t.value)
|
||||
|
||||
def _Pass(self, t):
|
||||
self.fill("pass")
|
||||
|
||||
def _Break(self, t):
|
||||
self.fill("break")
|
||||
|
||||
def _Continue(self, t):
|
||||
self.fill("continue")
|
||||
|
||||
def _Delete(self, t):
|
||||
self.fill("del ")
|
||||
interleave(lambda: self.write(", "), self.dispatch, t.targets)
|
||||
|
||||
def _Assert(self, t):
|
||||
self.fill("assert ")
|
||||
self.dispatch(t.test)
|
||||
if t.msg:
|
||||
self.write(", ")
|
||||
self.dispatch(t.msg)
|
||||
|
||||
def _Exec(self, t):
|
||||
self.fill("exec ")
|
||||
self.dispatch(t.body)
|
||||
if t.globals:
|
||||
self.write(" in ")
|
||||
self.dispatch(t.globals)
|
||||
if t.locals:
|
||||
self.write(", ")
|
||||
self.dispatch(t.locals)
|
||||
|
||||
def _Print(self, t):
|
||||
self.fill("print ")
|
||||
do_comma = False
|
||||
if t.dest:
|
||||
self.write(">>")
|
||||
self.dispatch(t.dest)
|
||||
do_comma = True
|
||||
for e in t.values:
|
||||
if do_comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
do_comma = True
|
||||
self.dispatch(e)
|
||||
if not t.nl:
|
||||
self.write(",")
|
||||
|
||||
def _Global(self, t):
|
||||
self.fill("global ")
|
||||
interleave(lambda: self.write(", "), self.write, t.names)
|
||||
|
||||
def _Nonlocal(self, t):
|
||||
self.fill("nonlocal ")
|
||||
interleave(lambda: self.write(", "), self.write, t.names)
|
||||
|
||||
def _Await(self, t):
|
||||
self.write("(")
|
||||
self.write("await")
|
||||
if t.value:
|
||||
self.write(" ")
|
||||
self.dispatch(t.value)
|
||||
self.write(")")
|
||||
|
||||
def _Yield(self, t):
|
||||
self.write("(")
|
||||
self.write("yield")
|
||||
if t.value:
|
||||
self.write(" ")
|
||||
self.dispatch(t.value)
|
||||
self.write(")")
|
||||
|
||||
def _YieldFrom(self, t):
|
||||
self.write("(")
|
||||
self.write("yield from")
|
||||
if t.value:
|
||||
self.write(" ")
|
||||
self.dispatch(t.value)
|
||||
self.write(")")
|
||||
|
||||
def _Raise(self, t):
|
||||
self.fill("raise")
|
||||
if six.PY3:
|
||||
if not t.exc:
|
||||
assert not t.cause
|
||||
return
|
||||
self.write(" ")
|
||||
self.dispatch(t.exc)
|
||||
if t.cause:
|
||||
self.write(" from ")
|
||||
self.dispatch(t.cause)
|
||||
else:
|
||||
self.write(" ")
|
||||
if t.type:
|
||||
self.dispatch(t.type)
|
||||
if t.inst:
|
||||
self.write(", ")
|
||||
self.dispatch(t.inst)
|
||||
if t.tback:
|
||||
self.write(", ")
|
||||
self.dispatch(t.tback)
|
||||
|
||||
def _Try(self, t):
|
||||
self.fill("try")
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
for ex in t.handlers:
|
||||
self.dispatch(ex)
|
||||
if t.orelse:
|
||||
self.fill("else")
|
||||
self.enter()
|
||||
self.dispatch(t.orelse)
|
||||
self.leave()
|
||||
if t.finalbody:
|
||||
self.fill("finally")
|
||||
self.enter()
|
||||
self.dispatch(t.finalbody)
|
||||
self.leave()
|
||||
|
||||
def _TryExcept(self, t):
|
||||
self.fill("try")
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
|
||||
for ex in t.handlers:
|
||||
self.dispatch(ex)
|
||||
if t.orelse:
|
||||
self.fill("else")
|
||||
self.enter()
|
||||
self.dispatch(t.orelse)
|
||||
self.leave()
|
||||
|
||||
def _TryFinally(self, t):
|
||||
if len(t.body) == 1 and isinstance(t.body[0], ast.TryExcept):
|
||||
# try-except-finally
|
||||
self.dispatch(t.body)
|
||||
else:
|
||||
self.fill("try")
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
|
||||
self.fill("finally")
|
||||
self.enter()
|
||||
self.dispatch(t.finalbody)
|
||||
self.leave()
|
||||
|
||||
def _ExceptHandler(self, t):
|
||||
self.fill("except")
|
||||
if t.type:
|
||||
self.write(" ")
|
||||
self.dispatch(t.type)
|
||||
if t.name:
|
||||
self.write(" as ")
|
||||
if six.PY3:
|
||||
self.write(t.name)
|
||||
else:
|
||||
self.dispatch(t.name)
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
|
||||
def _ClassDef(self, t):
|
||||
self.write("\n")
|
||||
for deco in t.decorator_list:
|
||||
self.fill("@")
|
||||
self.dispatch(deco)
|
||||
self.fill("class " + t.name)
|
||||
if six.PY3:
|
||||
self.write("(")
|
||||
comma = False
|
||||
for e in t.bases:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.dispatch(e)
|
||||
for e in t.keywords:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.dispatch(e)
|
||||
if sys.version_info[:2] < (3, 5):
|
||||
if t.starargs:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.write("*")
|
||||
self.dispatch(t.starargs)
|
||||
if t.kwargs:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.write("**")
|
||||
self.dispatch(t.kwargs)
|
||||
self.write(")")
|
||||
elif t.bases:
|
||||
self.write("(")
|
||||
for a in t.bases:
|
||||
self.dispatch(a)
|
||||
self.write(", ")
|
||||
self.write(")")
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
|
||||
def _FunctionDef(self, t):
|
||||
self.__FunctionDef_helper(t, "def")
|
||||
|
||||
def _AsyncFunctionDef(self, t):
|
||||
self.__FunctionDef_helper(t, "async def")
|
||||
|
||||
def __FunctionDef_helper(self, t, fill_suffix):
|
||||
self.write("\n")
|
||||
for deco in t.decorator_list:
|
||||
self.fill("@")
|
||||
self.dispatch(deco)
|
||||
def_str = fill_suffix + " " + t.name + "("
|
||||
self.fill(def_str)
|
||||
self.dispatch(t.args)
|
||||
self.write(")")
|
||||
if getattr(t, "returns", False):
|
||||
self.write(" -> ")
|
||||
self.dispatch(t.returns)
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
|
||||
def _For(self, t):
|
||||
self.__For_helper("for ", t)
|
||||
|
||||
def _AsyncFor(self, t):
|
||||
self.__For_helper("async for ", t)
|
||||
|
||||
def __For_helper(self, fill, t):
|
||||
self.fill(fill)
|
||||
self.dispatch(t.target)
|
||||
self.write(" in ")
|
||||
self.dispatch(t.iter)
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
if t.orelse:
|
||||
self.fill("else")
|
||||
self.enter()
|
||||
self.dispatch(t.orelse)
|
||||
self.leave()
|
||||
|
||||
def _If(self, t):
|
||||
self.fill("if ")
|
||||
self.dispatch(t.test)
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
# collapse nested ifs into equivalent elifs.
|
||||
while (t.orelse and len(t.orelse) == 1 and
|
||||
isinstance(t.orelse[0], ast.If)):
|
||||
t = t.orelse[0]
|
||||
self.fill("elif ")
|
||||
self.dispatch(t.test)
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
# final else
|
||||
if t.orelse:
|
||||
self.fill("else")
|
||||
self.enter()
|
||||
self.dispatch(t.orelse)
|
||||
self.leave()
|
||||
|
||||
def _While(self, t):
|
||||
self.fill("while ")
|
||||
self.dispatch(t.test)
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
if t.orelse:
|
||||
self.fill("else")
|
||||
self.enter()
|
||||
self.dispatch(t.orelse)
|
||||
self.leave()
|
||||
|
||||
def _generic_With(self, t, async_=False):
|
||||
self.fill("async with " if async_ else "with ")
|
||||
if hasattr(t, 'items'):
|
||||
interleave(lambda: self.write(", "), self.dispatch, t.items)
|
||||
else:
|
||||
self.dispatch(t.context_expr)
|
||||
if t.optional_vars:
|
||||
self.write(" as ")
|
||||
self.dispatch(t.optional_vars)
|
||||
self.enter()
|
||||
self.dispatch(t.body)
|
||||
self.leave()
|
||||
|
||||
def _With(self, t):
|
||||
self._generic_With(t)
|
||||
|
||||
def _AsyncWith(self, t):
|
||||
self._generic_With(t, async_=True)
|
||||
|
||||
# expr
|
||||
def _Bytes(self, t):
|
||||
self.write(repr(t.s))
|
||||
|
||||
def _Str(self, tree):
|
||||
if six.PY3:
|
||||
self.write(repr(tree.s))
|
||||
else:
|
||||
# if from __future__ import unicode_literals is in effect,
|
||||
# then we want to output string literals using a 'b' prefix
|
||||
# and unicode literals with no prefix.
|
||||
if "unicode_literals" not in self.future_imports:
|
||||
self.write(repr(tree.s))
|
||||
elif isinstance(tree.s, str):
|
||||
self.write("b" + repr(tree.s))
|
||||
elif isinstance(tree.s, unicode): # noqa
|
||||
self.write(repr(tree.s).lstrip("u"))
|
||||
else:
|
||||
assert False, "shouldn't get here"
|
||||
|
||||
def _JoinedStr(self, t):
|
||||
# JoinedStr(expr* values)
|
||||
self.write("f")
|
||||
string = StringIO()
|
||||
self._fstring_JoinedStr(t, string.write)
|
||||
# Deviation from `unparse.py`: Try to find an unused quote.
|
||||
# This change is made to handle _very_ complex f-strings.
|
||||
v = string.getvalue()
|
||||
if '\n' in v or '\r' in v:
|
||||
quote_types = ["'''", '"""']
|
||||
else:
|
||||
quote_types = ["'", '"', '"""', "'''"]
|
||||
for quote_type in quote_types:
|
||||
if quote_type not in v:
|
||||
v = "{quote_type}{v}{quote_type}".format(quote_type=quote_type, v=v)
|
||||
break
|
||||
else:
|
||||
v = repr(v)
|
||||
self.write(v)
|
||||
|
||||
def _FormattedValue(self, t):
|
||||
# FormattedValue(expr value, int? conversion, expr? format_spec)
|
||||
self.write("f")
|
||||
string = StringIO()
|
||||
self._fstring_JoinedStr(t, string.write)
|
||||
self.write(repr(string.getvalue()))
|
||||
|
||||
def _fstring_JoinedStr(self, t, write):
|
||||
for value in t.values:
|
||||
meth = getattr(self, "_fstring_" + type(value).__name__)
|
||||
meth(value, write)
|
||||
|
||||
def _fstring_Str(self, t, write):
|
||||
value = t.s.replace("{", "{{").replace("}", "}}")
|
||||
write(value)
|
||||
|
||||
def _fstring_Constant(self, t, write):
|
||||
assert isinstance(t.value, str)
|
||||
value = t.value.replace("{", "{{").replace("}", "}}")
|
||||
write(value)
|
||||
|
||||
def _fstring_FormattedValue(self, t, write):
|
||||
write("{")
|
||||
expr = StringIO()
|
||||
Unparser(t.value, expr)
|
||||
expr = expr.getvalue().rstrip("\n")
|
||||
if expr.startswith("{"):
|
||||
write(" ") # Separate pair of opening brackets as "{ {"
|
||||
write(expr)
|
||||
if t.conversion != -1:
|
||||
conversion = chr(t.conversion)
|
||||
assert conversion in "sra"
|
||||
write("!{conversion}".format(conversion=conversion))
|
||||
if t.format_spec:
|
||||
write(":")
|
||||
meth = getattr(self, "_fstring_" + type(t.format_spec).__name__)
|
||||
meth(t.format_spec, write)
|
||||
write("}")
|
||||
|
||||
def _Name(self, t):
|
||||
self.write(t.id)
|
||||
|
||||
def _NameConstant(self, t):
|
||||
self.write(repr(t.value))
|
||||
|
||||
def _Repr(self, t):
|
||||
self.write("`")
|
||||
self.dispatch(t.value)
|
||||
self.write("`")
|
||||
|
||||
def _write_constant(self, value):
|
||||
if isinstance(value, (float, complex)):
|
||||
# Substitute overflowing decimal literal for AST infinities.
|
||||
self.write(repr(value).replace("inf", INFSTR))
|
||||
else:
|
||||
self.write(repr(value))
|
||||
|
||||
def _Constant(self, t):
|
||||
value = t.value
|
||||
if isinstance(value, tuple):
|
||||
self.write("(")
|
||||
if len(value) == 1:
|
||||
self._write_constant(value[0])
|
||||
self.write(",")
|
||||
else:
|
||||
interleave(lambda: self.write(", "), self._write_constant, value)
|
||||
self.write(")")
|
||||
elif value is Ellipsis: # instead of `...` for Py2 compatibility
|
||||
self.write("...")
|
||||
else:
|
||||
if t.kind == "u":
|
||||
self.write("u")
|
||||
self._write_constant(t.value)
|
||||
|
||||
def _Num(self, t):
|
||||
repr_n = repr(t.n)
|
||||
if six.PY3:
|
||||
self.write(repr_n.replace("inf", INFSTR))
|
||||
else:
|
||||
# Parenthesize negative numbers, to avoid turning (-1)**2 into -1**2.
|
||||
if repr_n.startswith("-"):
|
||||
self.write("(")
|
||||
if "inf" in repr_n and repr_n.endswith("*j"):
|
||||
repr_n = repr_n.replace("*j", "j")
|
||||
# Substitute overflowing decimal literal for AST infinities.
|
||||
self.write(repr_n.replace("inf", INFSTR))
|
||||
if repr_n.startswith("-"):
|
||||
self.write(")")
|
||||
|
||||
def _List(self, t):
|
||||
self.write("[")
|
||||
interleave(lambda: self.write(", "), self.dispatch, t.elts)
|
||||
self.write("]")
|
||||
|
||||
def _ListComp(self, t):
|
||||
self.write("[")
|
||||
self.dispatch(t.elt)
|
||||
for gen in t.generators:
|
||||
self.dispatch(gen)
|
||||
self.write("]")
|
||||
|
||||
def _GeneratorExp(self, t):
|
||||
self.write("(")
|
||||
self.dispatch(t.elt)
|
||||
for gen in t.generators:
|
||||
self.dispatch(gen)
|
||||
self.write(")")
|
||||
|
||||
def _SetComp(self, t):
|
||||
self.write("{")
|
||||
self.dispatch(t.elt)
|
||||
for gen in t.generators:
|
||||
self.dispatch(gen)
|
||||
self.write("}")
|
||||
|
||||
def _DictComp(self, t):
|
||||
self.write("{")
|
||||
self.dispatch(t.key)
|
||||
self.write(": ")
|
||||
self.dispatch(t.value)
|
||||
for gen in t.generators:
|
||||
self.dispatch(gen)
|
||||
self.write("}")
|
||||
|
||||
def _comprehension(self, t):
|
||||
if getattr(t, 'is_async', False):
|
||||
self.write(" async for ")
|
||||
else:
|
||||
self.write(" for ")
|
||||
self.dispatch(t.target)
|
||||
self.write(" in ")
|
||||
self.dispatch(t.iter)
|
||||
for if_clause in t.ifs:
|
||||
self.write(" if ")
|
||||
self.dispatch(if_clause)
|
||||
|
||||
def _IfExp(self, t):
|
||||
self.write("(")
|
||||
self.dispatch(t.body)
|
||||
self.write(" if ")
|
||||
self.dispatch(t.test)
|
||||
self.write(" else ")
|
||||
self.dispatch(t.orelse)
|
||||
self.write(")")
|
||||
|
||||
def _Set(self, t):
|
||||
assert (t.elts) # should be at least one element
|
||||
self.write("{")
|
||||
interleave(lambda: self.write(", "), self.dispatch, t.elts)
|
||||
self.write("}")
|
||||
|
||||
def _Dict(self, t):
|
||||
self.write("{")
|
||||
|
||||
def write_key_value_pair(k, v):
|
||||
self.dispatch(k)
|
||||
self.write(": ")
|
||||
self.dispatch(v)
|
||||
|
||||
def write_item(item):
|
||||
k, v = item
|
||||
if k is None:
|
||||
# for dictionary unpacking operator in dicts {**{'y': 2}}
|
||||
# see PEP 448 for details
|
||||
self.write("**")
|
||||
self.dispatch(v)
|
||||
else:
|
||||
write_key_value_pair(k, v)
|
||||
|
||||
interleave(lambda: self.write(", "), write_item, zip(t.keys, t.values))
|
||||
self.write("}")
|
||||
|
||||
def _Tuple(self, t):
|
||||
self.write("(")
|
||||
if len(t.elts) == 1:
|
||||
elt = t.elts[0]
|
||||
self.dispatch(elt)
|
||||
self.write(",")
|
||||
else:
|
||||
interleave(lambda: self.write(", "), self.dispatch, t.elts)
|
||||
self.write(")")
|
||||
|
||||
unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
|
||||
|
||||
def _UnaryOp(self, t):
|
||||
self.write("(")
|
||||
self.write(self.unop[t.op.__class__.__name__])
|
||||
self.write(" ")
|
||||
if six.PY2 and isinstance(t.op, ast.USub) and isinstance(t.operand, ast.Num):
|
||||
# If we're applying unary minus to a number, parenthesize the number.
|
||||
# This is necessary: -2147483648 is different from -(2147483648) on
|
||||
# a 32-bit machine (the first is an int, the second a long), and
|
||||
# -7j is different from -(7j). (The first has real part 0.0, the second
|
||||
# has real part -0.0.)
|
||||
self.write("(")
|
||||
self.dispatch(t.operand)
|
||||
self.write(")")
|
||||
else:
|
||||
self.dispatch(t.operand)
|
||||
self.write(")")
|
||||
|
||||
binop = {"Add": "+", "Sub": "-", "Mult": "*", "MatMult": "@", "Div": "/", "Mod": "%",
|
||||
"LShift": "<<", "RShift": ">>", "BitOr": "|", "BitXor": "^", "BitAnd": "&",
|
||||
"FloorDiv": "//", "Pow": "**"}
|
||||
|
||||
def _BinOp(self, t):
|
||||
self.write("(")
|
||||
self.dispatch(t.left)
|
||||
self.write(" " + self.binop[t.op.__class__.__name__] + " ")
|
||||
self.dispatch(t.right)
|
||||
self.write(")")
|
||||
|
||||
cmpops = {"Eq": "==", "NotEq": "!=", "Lt": "<", "LtE": "<=", "Gt": ">", "GtE": ">=",
|
||||
"Is": "is", "IsNot": "is not", "In": "in", "NotIn": "not in"}
|
||||
|
||||
def _Compare(self, t):
|
||||
self.write("(")
|
||||
self.dispatch(t.left)
|
||||
for o, e in zip(t.ops, t.comparators):
|
||||
self.write(" " + self.cmpops[o.__class__.__name__] + " ")
|
||||
self.dispatch(e)
|
||||
self.write(")")
|
||||
|
||||
boolops = {ast.And: 'and', ast.Or: 'or'}
|
||||
|
||||
def _BoolOp(self, t):
|
||||
self.write("(")
|
||||
s = " %s " % self.boolops[t.op.__class__]
|
||||
interleave(lambda: self.write(s), self.dispatch, t.values)
|
||||
self.write(")")
|
||||
|
||||
def _Attribute(self, t):
|
||||
self.dispatch(t.value)
|
||||
# Special case: 3.__abs__() is a syntax error, so if t.value
|
||||
# is an integer literal then we need to either parenthesize
|
||||
# it or add an extra space to get 3 .__abs__().
|
||||
if isinstance(t.value, getattr(ast, 'Constant', getattr(ast, 'Num', None))) and isinstance(t.value.n, int):
|
||||
self.write(" ")
|
||||
self.write(".")
|
||||
self.write(t.attr)
|
||||
|
||||
def _Call(self, t):
|
||||
self.dispatch(t.func)
|
||||
self.write("(")
|
||||
comma = False
|
||||
for e in t.args:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.dispatch(e)
|
||||
for e in t.keywords:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.dispatch(e)
|
||||
if sys.version_info[:2] < (3, 5):
|
||||
if t.starargs:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.write("*")
|
||||
self.dispatch(t.starargs)
|
||||
if t.kwargs:
|
||||
if comma:
|
||||
self.write(", ")
|
||||
else:
|
||||
comma = True
|
||||
self.write("**")
|
||||
self.dispatch(t.kwargs)
|
||||
self.write(")")
|
||||
|
||||
def _Subscript(self, t):
|
||||
self.dispatch(t.value)
|
||||
self.write("[")
|
||||
self.dispatch(t.slice)
|
||||
self.write("]")
|
||||
|
||||
def _Starred(self, t):
|
||||
self.write("*")
|
||||
self.dispatch(t.value)
|
||||
|
||||
# slice
|
||||
def _Ellipsis(self, t):
|
||||
self.write("...")
|
||||
|
||||
def _Index(self, t):
|
||||
self.dispatch(t.value)
|
||||
|
||||
def _Slice(self, t):
|
||||
if t.lower:
|
||||
self.dispatch(t.lower)
|
||||
self.write(":")
|
||||
if t.upper:
|
||||
self.dispatch(t.upper)
|
||||
if t.step:
|
||||
self.write(":")
|
||||
self.dispatch(t.step)
|
||||
|
||||
def _ExtSlice(self, t):
|
||||
interleave(lambda: self.write(', '), self.dispatch, t.dims)
|
||||
|
||||
# argument
|
||||
def _arg(self, t):
|
||||
self.write(t.arg)
|
||||
if t.annotation:
|
||||
self.write(": ")
|
||||
self.dispatch(t.annotation)
|
||||
|
||||
# others
|
||||
def _arguments(self, t):
|
||||
first = True
|
||||
# normal arguments
|
||||
all_args = getattr(t, 'posonlyargs', []) + t.args
|
||||
defaults = [None] * (len(all_args) - len(t.defaults)) + t.defaults
|
||||
for index, elements in enumerate(zip(all_args, defaults), 1):
|
||||
a, d = elements
|
||||
if first:
|
||||
first = False
|
||||
else:
|
||||
self.write(", ")
|
||||
self.dispatch(a)
|
||||
if d:
|
||||
self.write("=")
|
||||
self.dispatch(d)
|
||||
if index == len(getattr(t, 'posonlyargs', ())):
|
||||
self.write(", /")
|
||||
|
||||
# varargs, or bare '*' if no varargs but keyword-only arguments present
|
||||
if t.vararg or getattr(t, "kwonlyargs", False):
|
||||
if first:
|
||||
first = False
|
||||
else:
|
||||
self.write(", ")
|
||||
self.write("*")
|
||||
if t.vararg:
|
||||
if hasattr(t.vararg, 'arg'):
|
||||
self.write(t.vararg.arg)
|
||||
if t.vararg.annotation:
|
||||
self.write(": ")
|
||||
self.dispatch(t.vararg.annotation)
|
||||
else:
|
||||
self.write(t.vararg)
|
||||
if getattr(t, 'varargannotation', None):
|
||||
self.write(": ")
|
||||
self.dispatch(t.varargannotation)
|
||||
|
||||
# keyword-only arguments
|
||||
if getattr(t, "kwonlyargs", False):
|
||||
for a, d in zip(t.kwonlyargs, t.kw_defaults):
|
||||
if first:
|
||||
first = False
|
||||
else:
|
||||
self.write(", ")
|
||||
self.dispatch(a),
|
||||
if d:
|
||||
self.write("=")
|
||||
self.dispatch(d)
|
||||
|
||||
# kwargs
|
||||
if t.kwarg:
|
||||
if first:
|
||||
first = False
|
||||
else:
|
||||
self.write(", ")
|
||||
if hasattr(t.kwarg, 'arg'):
|
||||
self.write("**" + t.kwarg.arg)
|
||||
if t.kwarg.annotation:
|
||||
self.write(": ")
|
||||
self.dispatch(t.kwarg.annotation)
|
||||
else:
|
||||
self.write("**" + t.kwarg)
|
||||
if getattr(t, 'kwargannotation', None):
|
||||
self.write(": ")
|
||||
self.dispatch(t.kwargannotation)
|
||||
|
||||
def _keyword(self, t):
|
||||
if t.arg is None:
|
||||
# starting from Python 3.5 this denotes a kwargs part of the invocation
|
||||
self.write("**")
|
||||
else:
|
||||
self.write(t.arg)
|
||||
self.write("=")
|
||||
self.dispatch(t.value)
|
||||
|
||||
def _Lambda(self, t):
|
||||
self.write("(")
|
||||
self.write("lambda ")
|
||||
self.dispatch(t.args)
|
||||
self.write(": ")
|
||||
self.dispatch(t.body)
|
||||
self.write(")")
|
||||
|
||||
def _alias(self, t):
|
||||
self.write(t.name)
|
||||
if t.asname:
|
||||
self.write(" as " + t.asname)
|
||||
|
||||
def _withitem(self, t):
|
||||
self.dispatch(t.context_expr)
|
||||
if t.optional_vars:
|
||||
self.write(" as ")
|
||||
self.dispatch(t.optional_vars)
|
||||
|
||||
|
||||
def unparse(tree):
|
||||
v = six.moves.cStringIO()
|
||||
Unparser(tree, file=v)
|
||||
return v.getvalue()
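A quick round-trip check of the vendored unparser (illustrative only):

import ast

src = "x = [i ** 2 for i in range(10) if i % 2 == 0]"
print(unparse(ast.parse(src)))   # prints an equivalent, reformatted statement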
|