# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create sample PR curve summary data.
We have 3 classes: R, G, and B. We generate colors within RGB space from 3
normal distributions (1 at each corner of the color triangle: [255, 0, 0],
[0, 255, 0], and [0, 0, 255]).
The true label of each random color is associated with the normal distribution
that generated it.
Using 3 other normal distributions (over the distance each color is from a
corner of the color triangle - RGB), we then compute the probability that each
color belongs to the class. We use those probabilities to generate PR curves.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path

from tempfile import gettempdir
from absl import app
from absl import flags
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorboard.plugins.pr_curve import summary
from clearml import Task


# Connecting ClearML with the current process.
# From here on, everything reported to TensorBoard is logged automatically.
task = Task.init(project_name='examples', task_name='tensorboard pr_curve')

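# This demo uses TF1-style graph execution (placeholders, sessions, summary
# ops), so we disable TF2 eager behavior up front.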
tf.compat.v1.disable_v2_behavior()
FLAGS = flags.FLAGS

flags.DEFINE_string(
    "logdir",
    os.path.join(gettempdir(), "pr_curve_demo"),
    "Directory into which to write TensorBoard data.",
)

flags.DEFINE_integer(
    "steps", 10, "Number of steps to generate for each PR curve."
)
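
# Both flags can be overridden on the command line; the script name below is
# illustrative:
#   python pr_curve_demo.py --logdir=/tmp/my_logdir --steps=20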


def start_runs(
    logdir, steps, run_name, thresholds, mask_every_other_prediction=False
):
    """Generate a PR curve with precision and recall evenly weighted.
    Arguments:
      logdir: The directory into which to store all the runs' data.
      steps: The number of steps to run for.
      run_name: The name of the run.
      thresholds: The number of thresholds to use for PR curves.
      mask_every_other_prediction: Whether to mask every other prediction by
        alternating weights between 0 and 1.
    """
    tf.compat.v1.reset_default_graph()
    tf.compat.v1.set_random_seed(42)

    # Create a normal distribution layer used to generate true color labels.
    distribution = tf.compat.v1.distributions.Normal(loc=0.0, scale=142.0)

    # Sample the distribution to generate colors. Let's generate different
    # numbers of each color. The first dimension is the count of examples.

    # The calls to sample() are given fixed random seed values that are "magic"
    # in that they correspond to the default seeds for those ops when the PR
    # curve test (which depends on this code) was written. We've pinned these
    # instead of continuing to use the defaults since the defaults are based on
    # node IDs from the sequence of nodes added to the graph, which can silently
    # change when this code or any TF op implementations it uses are modified.

    # TODO(nickfelt): redo the PR curve test to avoid reliance on random seeds.
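
    # Each class below is built the same way: the channel matching the class
    # is pushed toward 255, the other two toward 0, and values are clipped to
    # the valid [0, 255] range.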

    # Generate reds.
    number_of_reds = 100
    true_reds = tf.clip_by_value(
        tf.concat(
            [
                255 - tf.abs(distribution.sample([number_of_reds, 1], seed=11)),
                tf.abs(distribution.sample([number_of_reds, 2], seed=34)),
            ],
            axis=1,
        ),
        0,
        255,
    )

    # Generate greens.
    number_of_greens = 200
    true_greens = tf.clip_by_value(
        tf.concat(
            [
                tf.abs(distribution.sample([number_of_greens, 1], seed=61)),
                255
                - tf.abs(distribution.sample([number_of_greens, 1], seed=82)),
                tf.abs(distribution.sample([number_of_greens, 1], seed=105)),
            ],
            axis=1,
        ),
        0,
        255,
    )

    # Generate blues.
    number_of_blues = 150
    true_blues = tf.clip_by_value(
        tf.concat(
            [
                tf.abs(distribution.sample([number_of_blues, 2], seed=132)),
                255
                - tf.abs(distribution.sample([number_of_blues, 1], seed=153)),
            ],
            axis=1,
        ),
        0,
        255,
    )

    # Assign each color a vector of 3 booleans based on its true label.
    labels = tf.concat(
        [
            tf.tile(tf.constant([[True, False, False]]), (number_of_reds, 1)),
            tf.tile(tf.constant([[False, True, False]]), (number_of_greens, 1)),
            tf.tile(tf.constant([[False, False, True]]), (number_of_blues, 1)),
        ],
        axis=0,
    )

    # We introduce 3 normal distributions. They are used to predict whether a
    # color falls under a certain class (based on distances from corners of
    # the color triangle). The distributions vary per color and narrow over
    # time.
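    # For example, with the default of 10 steps, the red predictor's scale
    # starts at 158 + 10 = 168 and narrows to 159 by the final step.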
    initial_standard_deviations = [v + steps for v in (158, 200, 242)]
    iteration = tf.compat.v1.placeholder(tf.int32, shape=[])
    red_predictor = tf.compat.v1.distributions.Normal(
        loc=0.0,
        scale=tf.cast(
            initial_standard_deviations[0] - iteration, dtype=tf.float32
        ),
    )
    green_predictor = tf.compat.v1.distributions.Normal(
        loc=0.0,
        scale=tf.cast(
            initial_standard_deviations[1] - iteration, dtype=tf.float32
        ),
    )
    blue_predictor = tf.compat.v1.distributions.Normal(
        loc=0.0,
        scale=tf.cast(
            initial_standard_deviations[2] - iteration, dtype=tf.float32
        ),
    )

    # Make predictions (assign 3 probabilities to each color based on each color's
    # distance to each of the 3 corners). We seek double the area in the right
    # tail of the normal distribution.
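    # Since distances are non-negative, 2 * (1 - cdf(d)) equals 1 when a color
    # sits exactly on its corner (d = 0, cdf = 0.5) and falls toward 0 as the
    # distance grows.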
    examples = tf.concat([true_reds, true_greens, true_blues], axis=0)
    probabilities_colors_are_red = (
        1
        - red_predictor.cdf(
            tf.norm(tensor=examples - tf.constant([255.0, 0, 0]), axis=1)
        )
    ) * 2
    probabilities_colors_are_green = (
        1
        - green_predictor.cdf(
            tf.norm(tensor=examples - tf.constant([0, 255.0, 0]), axis=1)
        )
    ) * 2
    probabilities_colors_are_blue = (
        1
        - blue_predictor.cdf(
            tf.norm(tensor=examples - tf.constant([0, 0, 255.0]), axis=1)
        )
    ) * 2

    predictions = (
        probabilities_colors_are_red,
        probabilities_colors_are_green,
        probabilities_colors_are_blue,
    )
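    # predictions[i] lines up with labels[:, i]: one probability per example
    # for class i.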

    # This is the crucial piece. We write data required for generating PR curves.
    # We create 1 summary per class because we create 1 PR curve per class.
    for i, color in enumerate(("red", "green", "blue")):
        description = (
            "The probabilities used to create this PR curve are "
            "generated from a normal distribution. Its standard "
            "deviation is initially %0.0f and decreases over time."
            % initial_standard_deviations[i]
        )

        weights = None
        if mask_every_other_prediction:
            # Assign a weight of 0 to every even-indexed prediction. Odd-indexed
            # predictions are assigned a default weight of 1.
            consecutive_indices = tf.reshape(
                tf.range(tf.size(input=predictions[i])),
                tf.shape(input=predictions[i]),
            )
            weights = tf.cast(consecutive_indices % 2, dtype=tf.float32)
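            # For example, 5 predictions yield weights [0., 1., 0., 1., 0.].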

        summary.op(
            name=color,
            labels=labels[:, i],
            predictions=predictions[i],
            num_thresholds=thresholds,
            weights=weights,
            display_name="classifying %s" % color,
            description=description,
        )
    merged_summary_op = tf.compat.v1.summary.merge_all()
    events_directory = os.path.join(logdir, run_name)
    sess = tf.compat.v1.Session()
    writer = tf.compat.v1.summary.FileWriter(events_directory, sess.graph)
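    # Each run gets its own subdirectory under the log directory; TensorBoard
    # treats these subdirectories as separate runs.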

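    # Feed each step number in as the iteration; larger values narrow the
    # predictor distributions, so the PR curves change over time.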
    for step in xrange(steps):
        feed_dict = {
            iteration: step,
        }
        merged_summary = sess.run(merged_summary_op, feed_dict=feed_dict)
        writer.add_summary(merged_summary, step)

    writer.close()
    sess.close()


def run_all(logdir, steps, thresholds, verbose=False):
    """Generate PR curve summaries.
    Arguments:
      logdir: The directory into which to store all the runs' data.
      steps: The number of steps to run for.
      verbose: Whether to print the names of runs into stdout during execution.
      thresholds: The number of thresholds to use for PR curves.
    """
    # First, we generate data for a PR curve that weights all predictions
    # equally.
    run_name = "colors"
    if verbose:
        print("--- Running: %s" % run_name)
    start_runs(
        logdir=logdir, steps=steps, run_name=run_name, thresholds=thresholds
    )

    # Next, we generate data for a PR curve that masks every other prediction
    # by assigning it a weight of 0.
    run_name = "mask_every_other_prediction"
    if verbose:
        print("--- Running: %s" % run_name)
    start_runs(
        logdir=logdir,
        steps=steps,
        run_name=run_name,
        thresholds=thresholds,
        mask_every_other_prediction=True,
    )


def main(unused_argv):
    print("Saving output to %s." % FLAGS.logdir)
    run_all(FLAGS.logdir, FLAGS.steps, 50, verbose=True)
    print("Done. Output saved to %s." % FLAGS.logdir)


if __name__ == "__main__":
    app.run(main)