from copy import deepcopy
from typing import Sequence, Optional

from packaging.version import parse

from apiserver.tests.automated import TestService


class TestTaskModels(TestService):
    """Tests for the ``models`` field (named input/output model lists) of tasks.

    By default the suite runs against API version 2.13 (see ``setUp``).
    """

    def setUp(self, version="2.13"):
        super().setUp(version=version)

    def test_new_apis(self):
        """Exercise task model lists through create, add_or_update_model,
        edit and delete_models, verifying the stored lists after each step."""
        # no models
        empty_task = self.new_task()
        self.assertModels(empty_task, [], [])

        id1, id2 = self.new_model("model1"), self.new_model("model2")
        input_models = [
            {"name": "input1", "model": id1},
            {"name": "input2", "model": id2},
        ]
        # plain string ids that do not refer to models created in this test
        output_models = [
            {"name": "output1", "model": "id3"},
            {"name": "output2", "model": "id4"},
        ]

        # task creation with models
        task = self.new_task(models={"input": input_models, "output": output_models})
        self.assertModels(task, input_models, output_models)

        # add_or_update an existing model entry: its model id is replaced in place
        res = self.api.tasks.add_or_update_model(
            task=task, name="input1", type="input", model="Test"
        )
        self.assertEqual(res.updated, 1)
        modified_input = deepcopy(input_models)
        modified_input[0]["model"] = "Test"
        self.assertModels(task, modified_input, output_models)

        # add_or_update a new model entry: it is appended to the end of the list
        res = self.api.tasks.add_or_update_model(
            task=task, name="output3", type="output", model="TestOutput"
        )
        self.assertEqual(res.updated, 1)
        modified_output = deepcopy(output_models)
        modified_output.append({"name": "output3", "model": "TestOutput"})
        self.assertModels(task, modified_input, modified_output)

        # task editing: both lists are replaced wholesale (output3 is dropped)
        self.api.tasks.edit(
            task=task, models={"input": input_models, "output": output_models}
        )
        self.assertModels(task, input_models, output_models)

        # delete models; the "not_existing" entry is expected to be ignored.
        # `updated` is 1 even though three entries are removed, so it appears
        # to count modified tasks rather than removed entries.
        res = self.api.tasks.delete_models(
            task=task,
            models=[
                {"name": "input1", "type": "input"},
                {"name": "input2", "type": "input"},
                {"name": "output1", "type": "output"},
                {"name": "not_existing", "type": "output"},
            ]
        )
        self.assertEqual(res.updated, 1)
        self.assertModels(task, [], output_models[1:])

    def assertModels(
        self, task_id: str, input_models: Sequence[dict], output_models: Sequence[dict],
    ):
        """Assert the task's model lists match the expected entries, in order.

        The task is fetched through ``get_all_ex``, ``get_all`` and
        ``get_by_id``, and each response is checked. Failure messages name
        the endpoint and the list that did not match.

        :param task_id: id of the task to inspect.
        :param input_models: expected input entries as ``{"name", "model"}`` dicts.
        :param output_models: expected output entries as ``{"name", "model"}`` dicts.
        """

        def get_model_id(model: dict) -> Optional[str]:
            """Extract the model id from a model entry.

            The ``model`` value may be a plain id string, or a dict
            carrying the id under ``"id"``. The dict form is presumably
            returned when the reference is expanded; confirm against the API.
            """
            if not model:
                return None
            id_ = model.get("model")
            if isinstance(id_, str):
                return id_
            if id_ is None or id_ == {}:
                return None
            return id_.get("id")

        def compare_models(
            actual: Sequence[dict], expected: Sequence[dict], msg: str
        ):
            # order-sensitive comparison of (name, model id) pairs
            self.assertEqual(
                [(m["name"], get_model_id(m)) for m in actual],
                [(m["name"], m["model"]) for m in expected],
                msg=msg,
            )

        for endpoint, task in (
            ("get_all_ex", self.api.tasks.get_all_ex(id=task_id).tasks[0]),
            ("get_all", self.api.tasks.get_all(id=task_id).tasks[0]),
            ("get_by_id", self.api.tasks.get_by_id(task=task_id).task),
        ):
            compare_models(task.models.input, input_models, f"{endpoint}: input models")
            compare_models(task.models.output, output_models, f"{endpoint}: output models")
            # Below 2.13 the legacy single-model fields are expected to mirror
            # the first input model and the last output model.
            if self._version < parse("2.13"):
                self.assertEqual(
                    get_model_id(task.execution),
                    input_models[0]["model"] if input_models else None,
                    msg=f"{endpoint}: execution.model",
                )
                self.assertEqual(
                    get_model_id(task.output),
                    output_models[-1]["model"] if output_models else None,
                    msg=f"{endpoint}: output.model",
                )

    def new_task(self, **kwargs):
        """Create a temporary task and return the value produced by ``create_temp``.

        ``type`` and ``name`` defaults are filled in via ``update_missing``,
        presumably only when the caller did not supply them.
        """
        self.update_missing(
            kwargs, type="testing", name="test task models"
        )
        return self.create_temp("tasks", **kwargs)

    def new_model(self, name: str, **kwargs):
        """Create a temporary model with a fixed uri and empty labels.

        The return value is used as a model id by ``test_new_apis``.
        """
        return self.create_temp(
            "models", uri="file://test", name=name, labels={}, **kwargs
        )