import re
from copy import copy
from datetime import datetime
from logging import getLogger
from threading import Thread, Event
from time import time
from typing import Sequence, Optional, Mapping, Callable, Any

from attr import Factory, attrib, attrs
from plotly import graph_objects as go
from plotly.subplots import make_subplots

from trains import Task
from trains.automation import TrainsJob
from trains.model import BaseModel


class PipelineController(object):
    """
    Pipeline controller.

    A pipeline is a DAG of steps, each step is based on an existing (base) Task.
    When a step is launched a job is created from its base Task (with the step parameter overrides applied),
    sent for execution and monitored until it ends.
    The pipeline process (task) itself can be executed manually or by the trains-agent services queue.
    Notice: The pipeline controller lives as long as the pipeline itself is being executed.
    """
    # tag added to the controller Task
    _tag = 'pipeline'
    # regular expression matching a step reference of the form ${...} inside a parameter value
    _step_pattern = r"\${[^}]*}"
    # name of the Task configuration / parameter section used to store the pipeline state
    _config_section = 'Pipeline'

    @attrs
    class Node(object):
        """
        A single step (node) of the pipeline DAG.
        """
        # unique step name, also used as the key of the step in the pipeline DAG
        name = attrib(type=str)
        # ID of the base Task used to create the step job
        base_task_id = attrib(type=str)
        # execution queue of the step (None: use the controller default execution queue)
        queue = attrib(type=str, default=None)
        # names of the steps that must be executed before this one.
        # Factory creates a new list per Node, a plain `[]` default is one object shared by all Nodes
        parents = attrib(type=list, default=Factory(list))
        # step execution time limit (None: no time limit)
        timeout = attrib(type=float, default=None)
        # parameter overrides of the step, values may contain ${step} references.
        # Factory creates a new dict per Node for the same reason as `parents`
        parameters = attrib(type=dict, default=Factory(dict))
        # Task ID of the executed step, None while not executed (the controller sets False on a failed step)
        executed = attrib(type=str, default=None)
        # the job created when the step is launched
        job = attrib(type=TrainsJob, default=None)
    def __init__(
            self,
            pool_frequency=0.2,  # type: float
            default_execution_queue=None,  # type: Optional[str]
            pipeline_time_limit=None,  # type: Optional[float]
            auto_connect_task=True,  # type: bool
            always_create_task=False,  # type: bool
            add_pipeline_tags=False,  # type: bool
    ):
        # type: (...) -> ()
        """
        Create a new pipeline controller. The newly created object will launch and monitor the new experiments.

        :param float pool_frequency: The polling frequency (in minutes) for monitoring experiments / states.
        :param str default_execution_queue: The execution queue to use if no execution queue is provided
        :param float pipeline_time_limit: The maximum time (minutes) for the entire pipeline process. The
            default is ``None``, indicating no time limit.
        :param bool auto_connect_task: Stored on the controller.
            NOTE(review): the flag is not read anywhere in the controller code shown here, confirm its intended use.
        :param bool always_create_task: Always create a new Task
            - ``True`` - If no current Task is initialized, create a new controller Task named
              ``Pipeline <current date/time>`` in the ``Pipelines`` project.

            - ``False`` - Use the :py:meth:`task.Task.current_task` (if exists) to report statistics.
        :param bool add_pipeline_tags: (default: False) if True, add `pipe: <pipeline_task_id>` tag to all
            steps (Tasks) created by this pipeline.
        """
        self._nodes = {}
        self._running_nodes = []
        self._start_time = None
        # time limits are given in minutes and stored in seconds
        self._pipeline_time_limit = pipeline_time_limit * 60. if pipeline_time_limit else None
        self._default_execution_queue = default_execution_queue
        self._pool_frequency = pool_frequency * 60.
        self._thread = None
        self._stop_event = None
        self._experiment_created_cb = None
        self._add_pipeline_tags = add_pipeline_tags
        # always assigned, so the attribute exists even when there is no controller Task
        self._auto_connect_task = auto_connect_task
        self._task = Task.current_task()
        self._step_ref_pattern = re.compile(self._step_pattern)
        if not self._task and always_create_task:
            self._task = Task.init(
                project_name='Pipelines',
                task_name='Pipeline {}'.format(datetime.now()),
                task_type=Task.TaskTypes.controller,
            )

        # mark the controller Task with the pipeline tag
        if self._task:
            self._task.add_tags([self._tag])

    def add_step(
            self,
            name,  # type: str
            base_task_id=None,  # type: Optional[str]
            parents=None,  # type: Optional[Sequence[str]]
            parameter_override=None,  # type: Optional[Mapping[str, Any]]
            execution_queue=None,  # type: Optional[str]
            time_limit=None,  # type: Optional[float]
            base_task_project=None,  # type: Optional[str]
            base_task_name=None,  # type: Optional[str]
    ):
        # type: (...) -> bool
        """
        Add a step to the pipeline execution DAG.
        Each step must have a unique name (this name will later be used to address the step)

        :param str name: Unique of the step. For example `stage1`
        :param str base_task_id: The Task ID to use for the step. Each time the step is executed,
            the base Task is cloned, then the cloned task will be sent for execution.
        :param list parents: Optional list of parent nodes in the DAG.
            The current step in the pipeline will be sent for execution only after all the parent nodes
            have been executed successfully.
        :param dict parameter_override: Optional parameter overriding dictionary.
            The dict values can reference a previously executed step using the following form '${step_name}'
            Examples:
                Artifact access
                    parameter_override={'Args/input_file': '${stage1.artifacts.mydata.url}' }
                Model access (last model used)
                    parameter_override={'Args/input_file': '${stage1.models.output.-1.url}' }
                Parameter access
                    parameter_override={'Args/input_file': '${stage3.parameters.Args/input_file}' }
                Task ID
                    parameter_override={'Args/input_file': '${stage3.id}' }
        :param str execution_queue: Optional, the queue to use for executing this specific step.
            If not provided, the task will be sent to the default execution queue, as defined on the class
        :param float time_limit: Default None, no time limit.
            Step execution time limit, if exceeded the Task is aborted and the pipeline is stopped and marked failed.
        :param str base_task_project: If base_task_id is not given,
            use the base_task_project and base_task_name combination to retrieve the base_task_id to use for the step.
        :param str base_task_name: If base_task_id is not given,
            use the base_task_project and base_task_name combination to retrieve the base_task_id to use for the step.
        :return: True if successful
        """
        # when running remotely do nothing, we will deserialize ourselves when we start
        if self._task and not self._task.running_locally() and self._task.is_main_task():
            return True

        if name in self._nodes:
            raise ValueError('Node named \'{}\' already exists in the pipeline dag'.format(name))

        if not base_task_id:
            if not base_task_project or not base_task_name:
                raise ValueError('Either base_task_id or base_task_project/base_task_name must be provided')
            base_task = Task.get_task(project_name=base_task_project, task_name=base_task_name)
            if not base_task:
                raise ValueError('Could not find base_task_project={} base_task_name={}'.format(
                    base_task_project, base_task_name))
            base_task_id = base_task.id

        self._nodes[name] = self.Node(
            name=name, base_task_id=base_task_id, parents=parents or [],
            queue=execution_queue, timeout=time_limit,
            parameters=parameter_override or {})
        return True

    def start(self, run_remotely=False, step_task_created_callback=None):
        # type: (bool, Optional[Callable[[PipelineController.Node, dict], None]]) -> bool
        """
        Start the pipeline controller.
        If the calling process is stopped, then the controller stops as well.

        :param bool run_remotely: (default False), If True stop the current process and continue execution
            on a remote machine. This is done by calling the Task.execute_remotely with the queue name 'services'
        :param Callable step_task_created_callback: Callback function, called when a step (Task) is created
            and before it is sent for execution.

            .. code-block:: py

                def step_created_callback(
                    node,                 # type: PipelineController.Node,
                    parameters,           # type: dict
                ):
                    pass

        :return: True, if the controller started. False, if the controller did not start.

        """
        if self._thread:
            return True

        # serialize pipeline state
        pipeline_dag = self._serialize()
        self._task.connect_configuration(pipeline_dag, name=self._config_section)
        params = {'continue_pipeline': False,
                  'default_queue': self._default_execution_queue,
                  'add_pipeline_tags': self._add_pipeline_tags,
                  }
        self._task.connect(params, name=self._config_section)
        # deserialize back pipeline state
        if not params['continue_pipeline']:
            for k in pipeline_dag:
                pipeline_dag[k]['executed'] = None
        self._default_execution_queue = params['default_queue']
        self._add_pipeline_tags = params['add_pipeline_tags']
        self._deserialize(pipeline_dag)

        if not self._verify():
            raise ValueError("Failed verifying pipeline execution graph, "
                             "it has either inaccessible nodes, or contains cycles")

        self._update_execution_plot()

        if run_remotely:
            self._task.execute_remotely(queue_name='services')
            # we will not get here if we are not running remotely

        self._start_time = time()
        self._stop_event = Event()
        self._experiment_created_cb = step_task_created_callback
        self._thread = Thread(target=self._daemon)
        self._thread.daemon = True
        self._thread.start()
        return True

    def stop(self, timeout=None):
        # type: (Optional[float]) -> ()
        """
        Stop the pipeline controller and the optimization thread.

        :param float timeout: Wait timeout for the optimization thread to exit (minutes).
            The default is ``None``, indicating do not wait terminate immediately.
        """
        pass

    def wait(self, timeout=None):
        # type: (Optional[float]) -> bool
        """
        Wait for the pipeline to finish.

        .. note::
            This method does not stop the pipeline. Call :meth:`stop` to terminate the pipeline.

        :param float timeout: The timeout to wait for the pipeline to complete (minutes).
            If ``None``, then wait until we reached the timeout, or pipeline completed.

        :return: True, if the pipeline finished. False, if the pipeline timed out.

        """
        if not self.is_running():
            return True

        if timeout is not None:
            timeout *= 60.

        _thread = self._thread

        _thread.join(timeout=timeout)
        if _thread.is_alive():
            return False

        return True

    def is_running(self):
        # type: () -> bool
        """
        return True if the pipeline controller is running.

        :return: A boolean indicating whether the pipeline controller is active (still running) or stopped.
        """
        return self._thread is not None

    def elapsed(self):
        # type: () -> float
        """
        Return minutes elapsed from controller stating time stamp.

        :return: The minutes from controller start time. A negative value means the process has not started yet.
        """
        if self._start_time is None:
            return -1.0
        return (time() - self._start_time) / 60.

    def get_pipeline_dag(self):
        # type: () -> Mapping[str, PipelineController.Node]
        """
        Return the pipeline execution graph, each node in the DAG is PipelineController.Node object.
        Graph itself is a dictionary of Nodes (key based on the Node name),
        each node holds links to its parent Nodes (identified by their unique names)

        :return: execution tree, as a nested dictionary
        Example:
            {
                'stage1' : Node() {
                    name: 'stage1'
                    job: TrainsJob
                    ...
                },
            }
        """
        return self._nodes

    def get_processed_nodes(self):
        # type: () -> Sequence[PipelineController.Node]
        """
        Return the a list of the processed pipeline nodes, each entry in the list is PipelineController.Node object.

        :return: executed (excluding currently executing) nodes list
        """
        return {k: n for k, n in self._nodes.items() if n.executed}

    def get_running_nodes(self):
        # type: () -> Sequence[PipelineController.Node]
        """
        Return the a list of the currently running pipeline nodes,
        each entry in the list is PipelineController.Node object.

        :return: Currently running nodes list
        """
        return {k: n for k, n in self._nodes.items() if k in self._running_nodes}

    def _serialize(self):
        # type: () -> dict
        """
        Store the definition of the pipeline DAG into a dictionary.
        This dictionary will be used to store the DAG as a configuration on the Task
        :return:
        """
        dag = {name: dict((k, v) for k, v in node.__dict__.items() if k not in ('job', 'name'))
               for name, node in self._nodes.items()}

        return dag

    def _deserialize(self, dag_dict):
        # type: (dict) -> ()
        """
        Restore the DAG from a dictionary.
        This will be used to create the DAG from the dict stored on the Task, when running remotely.
        :return:
        """
        self._nodes = {k: self.Node(name=k, **v) for k, v in dag_dict.items()}

    def _verify(self):
        # type: () -> bool
        """
        Verify the DAG, (i.e. no cycles and no missing parents)
        On error raise ValueError with verification details

        :return: return True iff DAG has no errors
        """
        # verify nodes
        for node in self._nodes.values():
            # raise value error if not verified
            self._verify_node(node)

        # check the dag itself
        if not self._verify_dag():
            return False

        return True

    def _verify_node(self, node):
        # type: (Node) -> bool
        """
        Raise ValueError on verification errors

        :return: Return True iff the specific node is verified
        """
        if not node.base_task_id:
            raise ValueError("Node '{}', base_task_id is empty".format(node.name))

        if not self._default_execution_queue and not node.queue:
            raise ValueError("Node '{}' missing execution queue, "
                             "no default queue defined and no specific node queue defined".format(node.name))

        task = Task.get_task(task_id=node.base_task_id)
        if not task:
            raise ValueError("Node '{}', base_task_id={} is invalid".format(node.name, node.base_task_id))

        pattern = self._step_ref_pattern

        for v in node.parameters.values():
            if isinstance(v, str):
                for g in pattern.findall(v):
                    self.__verify_step_reference(node, g)

        return True

    def _verify_dag(self):
        # type: () -> bool
        """
        :return: True iff the pipeline dag is fully accessible and contains no cycles
        """
        visited = set()
        prev_visited = None
        while prev_visited != visited:
            prev_visited = copy(visited)
            for k, node in self._nodes.items():
                if k in visited:
                    continue
                if not all(p in visited for p in node.parents or []):
                    continue
                visited.add(k)
        # return False if we did not cover all the nodes
        return not bool(set(self._nodes.keys()) - visited)

    def _launch_node(self, node):
        # type: (Node) -> ()
        """
        Launch a single node (create and enqueue a TrainsJob)

        :param node: Node to launch
        :return: Return True if a new job was launched
        """
        if node.job or node.executed:
            return False

        updated_hyper_parameters = {}
        for k, v in node.parameters.items():
            updated_hyper_parameters[k] = self._parse_step_ref(v)

        node.job = TrainsJob(
            base_task_id=node.base_task_id, parameter_override=updated_hyper_parameters,
            tags=['pipe: {}'.format(self._task.id)] if self._add_pipeline_tags and self._task else None,
            parent=self._task.id if self._task else None)
        if self._experiment_created_cb:
            self._experiment_created_cb(node, updated_hyper_parameters)
        node.job.launch(queue_name=node.queue or self._default_execution_queue)
        return True

    def _update_execution_plot(self):
        # type: () -> ()
        """
        Update the pipeline execution plot (sankey diagram + steps table) reported on the controller Task.

        Steps are laid out parents first. Step colors:
        ``lightsteelblue`` not launched yet, ``green`` launched (job created), ``blue`` executed,
        ``red`` failed.
        Nothing is reported if the controller has no Task.
        """
        if not self._task:
            return

        sankey_node = dict(
            label=[],
            color=[],
            hovertemplate='%{label}<extra></extra>',
        )
        sankey_link = dict(
            source=[],
            target=[],
            value=[],
            hovertemplate='%{target.label}<extra></extra>',
        )
        # step names in plot order, and the matching name -> plot index lookup
        visited = []
        visited_index = {}
        node_params = []
        nodes = list(self._nodes.values())
        while nodes:
            next_nodes = []
            for node in nodes:
                parent_names = node.parents or []
                # a step is added only after all its parents are in the plot
                if not all(p in visited_index for p in parent_names):
                    next_nodes.append(node)
                    continue
                idx = len(visited)
                visited.append(node.name)
                visited_index[node.name] = idx
                node_params.append((node.job.task_parameter_override if node.job else node.parameters) or {})
                sankey_node['label'].append(
                    '{}<br />'.format(node.name) +
                    '<br />'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items()))

                if node.executed:
                    color = "red" if node.job and node.job.is_failed() else "blue"
                elif node.job:
                    # the daemon marks a failed step with executed=False, a running step keeps None
                    color = "red" if node.executed is False else "green"
                else:
                    color = "lightsteelblue"
                sankey_node['color'].append(color)

                for p in parent_names:
                    sankey_link['source'].append(visited_index[p])
                    sankey_link['target'].append(idx)
                    sankey_link['value'].append(1)

            if len(next_nodes) == len(nodes):
                # no step could be added in this pass (cycle / unknown parent),
                # stop here instead of looping forever, the remaining steps are left out of the plot
                break
            nodes = next_nodes

        # make sure we have no independent (unconnected) nodes
        linked = set(sankey_link['source']) | set(sankey_link['target'])
        for i in range(len(visited)):
            if i in linked:
                continue
            sankey_link['source'].append(i)
            sankey_link['target'].append(i)
            sankey_link['value'].append(0.1)

        fig = make_subplots(
            rows=2, cols=1,
            shared_xaxes=True,
            vertical_spacing=0.03,
            # listed in grid order, matching the traces added below: sankey in row 1, table in row 2
            specs=[[{"type": "sankey"}],
                   [{"type": "table"}], ]
        )
        # noinspection PyUnresolvedReferences
        fig.add_trace(
            go.Sankey(
                node=sankey_node, link=sankey_link, textfont=dict(color='rgba(0,0,0,0)', size=1)
            ),
            row=1, col=1
        )
        # noinspection PyUnresolvedReferences
        fig.add_trace(
            go.Table(
                header=dict(
                    values=["Pipeline Step", "Task ID", "Parameters"],
                    align="left",
                ),
                cells=dict(
                    values=[visited,
                            [self._nodes[v].executed or (self._nodes[v].job.task_id() if self._nodes[v].job else '')
                             for v in visited],
                            [str(p) for p in node_params]],
                    align="left")
            ),
            row=2, col=1
        )

        self._task.get_logger().report_plotly(
            title='Pipeline', series='execution flow', iteration=0, figure=fig)

    def _force_task_configuration_update(self):
        pipeline_dag = self._serialize()
        # noinspection PyProtectedMember
        self._task._set_configuration(
            name=self._config_section, config_type='dictionary', config_dict=pipeline_dag)

    def _daemon(self):
        # type: () -> ()
        """
        The main pipeline execution loop. This loop is executed on its own dedicated thread.
        :return:
        """
        pooling_counter = 0

        while self._stop_event:
            # stop request
            if pooling_counter and self._stop_event.wait(self._pool_frequency):
                break

            pooling_counter += 1

            # check the pipeline time limit
            if self._pipeline_time_limit and (time() - self._start_time) > self._pipeline_time_limit:
                break

            # check the state of all current jobs
            # if no a job ended, continue
            completed_jobs = []
            for j in self._running_nodes:
                node = self._nodes[j]
                if not node.job:
                    continue
                if node.job.is_stopped():
                    completed_jobs.append(j)
                    node.executed = node.job.task_id() if not node.job.is_failed() else False
                elif node.timeout:
                    started = node.job.task.data.started
                    if (datetime.now().astimezone(started.tzinfo) - started).total_seconds() > node.timeout:
                        node.job.abort()
                        completed_jobs.append(j)
                        node.executed = node.job.task_id()

            # update running jobs
            self._running_nodes = [j for j in self._running_nodes if j not in completed_jobs]

            # nothing changed, we can sleep
            if not completed_jobs and self._running_nodes:
                continue

            # Pull the next jobs in the pipeline, based on the completed list
            next_nodes = []
            for node in self._nodes.values():
                # check if already processed.
                if node.job or node.executed:
                    continue
                completed_parents = [bool(p in self._nodes and self._nodes[p].executed) for p in node.parents or []]
                if all(completed_parents):
                    next_nodes.append(node.name)

            # update the execution graph
            for name in next_nodes:
                if self._launch_node(self._nodes[name]):
                    print('Launching step: {}'.format(name))
                    print('Parameters:\n{}'.format(self._nodes[name].job.task_parameter_override))
                    self._running_nodes.append(name)
                else:
                    getLogger('trains.automation.controller').error(
                        'ERROR: Failed launching step \'{}\': {}'.format(name, self._nodes[name]))

            # update current state (in configuration, so that we could later continue an aborted pipeline)
            self._force_task_configuration_update()

            # visualize pipeline state (plot)
            self._update_execution_plot()

            # quit if all pipelines nodes are fully executed.
            if not next_nodes and not self._running_nodes:
                break

        # stop all currently running jobs:
        for node in self._nodes.values():
            if node.job and not node.executed:
                node.job.abort()

        if self._stop_event:
            # noinspection PyBroadException
            try:
                self._stop_event.set()
            except Exception:
                pass

    def __verify_step_reference(self, node, step_ref_string):
        # type: (Node, str) -> bool
        """
        Verify the step reference. For example "${step1.parameters.Args/param}"
        :param Node node: calling reference node (used for logging)
        :param str step_ref_string: For example "${step1.parameters.Args/param}"
        :return: True if valid reference
        """
        parts = step_ref_string[2:-1].split('.')
        v = step_ref_string
        if len(parts) < 2:
            raise ValueError("Node '{}', parameter '{}' is invalid".format(node.name, v))
        prev_step = parts[0]
        input_type = parts[1]
        if prev_step not in self._nodes:
            raise ValueError("Node '{}', parameter '{}', step name '{}' is invalid".format(node.name, v, prev_step))
        if input_type not in ('artifacts', 'parameters', 'models', 'id'):
            raise ValueError(
                "Node {}, parameter '{}', input type '{}' is invalid".format(node.name, v, input_type))

        if input_type != 'id' and len(parts) < 3:
            raise ValueError("Node '{}', parameter '{}' is invalid".format(node.name, v))

        if input_type == 'models':
            try:
                model_type = parts[2].lower()
            except Exception:
                raise ValueError(
                    "Node '{}', parameter '{}', input type '{}', model_type is missing {}".format(
                        node.name, v, input_type, parts))
            if model_type not in ('input', 'output'):
                raise ValueError(
                    "Node '{}', parameter '{}', input type '{}', "
                    "model_type is invalid (input/output) found {}".format(
                        node.name, v, input_type, model_type))

            if len(parts) < 4:
                raise ValueError(
                    "Node '{}', parameter '{}', input type '{}', model index is missing".format(
                        node.name, v, input_type))

            # check casting
            try:
                int(parts[3])
            except Exception:
                raise ValueError(
                    "Node '{}', parameter '{}', input type '{}', model index is missing {}".format(
                        node.name, v, input_type, parts))

            if len(parts) < 5:
                raise ValueError(
                    "Node '{}', parameter '{}', input type '{}', model property is missing".format(
                        node.name, v, input_type))

            if not hasattr(BaseModel, parts[4]):
                raise ValueError(
                    "Node '{}', parameter '{}', input type '{}', model property is invalid {}".format(
                        node.name, v, input_type, parts[4]))
        return True

    def __parse_step_reference(self, step_ref_string):
        """
        return the adjusted value for "${step...}"
        :param step_ref_string: reference string of the form ${step_name.type.value}"
        :return: str with value
        """
        parts = step_ref_string[2:-1].split('.')
        if len(parts) < 2:
            raise ValueError("Could not parse reference '{}'".format(step_ref_string))
        prev_step = parts[0]
        input_type = parts[1].lower()
        if prev_step not in self._nodes or not self._nodes[prev_step].job:
            raise ValueError("Could not parse reference '{}', step {} could not be found".format(
                step_ref_string, prev_step))
        if input_type not in ('artifacts', 'parameters', 'models', 'id'):
            raise ValueError("Could not parse reference '{}', type {} not valid".format(step_ref_string, input_type))
        if input_type != 'id' and len(parts) < 3:
            raise ValueError("Could not parse reference '{}', missing fields in {}".format(step_ref_string, parts))

        task = self._nodes[prev_step].job.task if self._nodes[prev_step].job \
            else Task.get_task(task_id=self._nodes[prev_step].executed)
        task.reload()
        if input_type == 'artifacts':
            # fix \. to use . in artifacts
            artifact_path = ('.'.join(parts[2:])).replace('\\.', '\\_dot_\\')
            artifact_path = artifact_path.split('.')

            obj = task.artifacts
            for p in artifact_path:
                p = p.replace('\\_dot_\\', '.')
                if isinstance(obj, dict):
                    obj = obj.get(p)
                elif hasattr(obj, p):
                    obj = getattr(obj, p)
                else:
                    raise ValueError("Could not locate artifact {} on previous step {}".format(
                        '.'.join(parts[1:]), prev_step))
            return str(obj)
        elif input_type == 'parameters':
            step_params = task.get_parameters()
            param_name = '.'.join(parts[2:])
            if param_name not in step_params:
                raise ValueError("Could not locate parameter {} on previous step {}".format(
                    '.'.join(parts[1:]), prev_step))
            return step_params.get(param_name)
        elif input_type == 'models':
            model_type = parts[2].lower()
            if model_type not in ('input', 'output'):
                raise ValueError("Could not locate model {} on previous step {}".format(
                    '.'.join(parts[1:]), prev_step))
            try:
                model_idx = int(parts[3])
                model = task.models[model_type][model_idx]
            except Exception:
                raise ValueError("Could not locate model {} on previous step {}, index {} is invalid".format(
                    '.'.join(parts[1:]), prev_step, parts[3]))

            return str(getattr(model, parts[4]))

        elif input_type == 'id':
            return task.id
        return None

    def _parse_step_ref(self, value):
        # type: (Any) -> Optional[str]
        """
        Return the step reference. For example "${step1.parameters.Args/param}"
        :param value: string
        :return:
        """
        # look for all the step references
        pattern = self._step_ref_pattern
        updated_value = value
        if isinstance(value, str):
            for g in pattern.findall(value):
                # update with actual value
                new_val = self.__parse_step_reference(g)
                updated_value = updated_value.replace(g, new_val, 1)
        return updated_value