from typing import Optional

from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk, scan
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    ELASTICSEARCH_URL,
    ELASTICSEARCH_CA_CERTS,
    ELASTICSEARCH_API_KEY,
    ELASTICSEARCH_USERNAME,
    ELASTICSEARCH_PASSWORD,
    ELASTICSEARCH_CLOUD_ID,
    ELASTICSEARCH_INDEX_PREFIX,
    SSL_ASSERT_FINGERPRINT,
)


class ElasticsearchClient:
    """
    Important:
    in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
    an index for each file but store it as a text field, while seperating to different index
    baesd on the embedding length.
    """

    def __init__(self):
        self.index_prefix = ELASTICSEARCH_INDEX_PREFIX
        self.client = Elasticsearch(
            hosts=[ELASTICSEARCH_URL],
            ca_certs=ELASTICSEARCH_CA_CERTS,
            api_key=ELASTICSEARCH_API_KEY,
            cloud_id=ELASTICSEARCH_CLOUD_ID,
            basic_auth=(
                (ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
                if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
                else None
            ),
            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
        )

    # Status: works
    def _get_index_name(self, dimension: int) -> str:
        return f"{self.index_prefix}_d{dimension}"

    # Status: works
    def _scan_result_to_get_result(self, result) -> Optional[GetResult]:
        if not result:
            return None
        ids = []
        documents = []
        metadatas = []

        for hit in result:
            ids.append(hit["_id"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))

        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    # Status: works
    def _result_to_get_result(self, result) -> Optional[GetResult]:
        if not result["hits"]["hits"]:
            return None
        ids = []
        documents = []
        metadatas = []

        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))

        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    # Status: works
    def _result_to_search_result(self, result) -> SearchResult:
        ids = []
        distances = []
        documents = []
        metadatas = []

        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
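            # _score is a similarity (higher is better); it is passed through
            # unchanged in the "distances" field.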
            distances.append(hit["_score"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))

        return SearchResult(
            ids=[ids],
            distances=[distances],
            documents=[documents],
            metadatas=[metadatas],
        )

    # Status: works
    def _create_index(self, dimension: int):
        body = {
            "mappings": {
                "dynamic_templates": [
                    {
                        "strings": {
                            "match_mapping_type": "string",
                            "mapping": {"type": "keyword"},
                        }
                    }
                ],
                "properties": {
                    "collection": {"type": "keyword"},
                    "id": {"type": "keyword"},
                    "vector": {
                        "type": "dense_vector",
                        "dims": dimension,  # Adjust based on your vector dimensions
                        "index": True,
                        "similarity": "cosine",
                    },
                    "text": {"type": "text"},
                    "metadata": {"type": "object"},
                },
            }
        }
        self.client.indices.create(index=self._get_index_name(dimension), body=body)

    # Status: works
    def _create_batches(self, items: list[VectorItem], batch_size=100):
        for i in range(0, len(items), batch_size):
            yield items[i : i + batch_size]  # Slicing clamps at the end of the list.

    # Status: works
    def has_collection(self, collection_name: str) -> bool:
        query_body = {"query": {"bool": {"filter": []}}}
        query_body["query"]["bool"]["filter"].append(
            {"term": {"collection": collection_name}}
        )

        try:
            result = self.client.count(index=f"{self.index_prefix}*", body=query_body)

            return result.body["count"] > 0
        except Exception:
            # Counting failed (e.g. no matching index yet); treat as "not present".
            return False

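    # Remove every document in the collection; delete_by_query leaves the
    # per-dimension indices themselves in place.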
    def delete_collection(self, collection_name: str):
        query = {"query": {"term": {"collection": collection_name}}}
        self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)

    # Status: works
    def search(
        self, collection_name: str, vectors: list[list[float]], limit: int
    ) -> Optional[SearchResult]:
        query = {
            "size": limit,
            "_source": ["text", "metadata"],
            "query": {
                "script_score": {
                    "query": {
                        "bool": {"filter": [{"term": {"collection": collection_name}}]}
                    },
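                    # cosineSimilarity ranges over [-1, 1]; the +1.0 offset keeps
                    # the score non-negative, as script_score requires.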
                    "script": {
                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
                        "params": {
                            "vector": vectors[0]
                        },  # Assuming single query vector
                    },
                }
            },
        }

        result = self.client.search(
            index=self._get_index_name(len(vectors[0])), body=query
        )

        return self._result_to_search_result(result)

    # Status: only tested halfway
    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        if not self.has_collection(collection_name):
            return None

        query_body = {
            "query": {"bool": {"filter": []}},
            "_source": ["text", "metadata"],
        }

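        # Each filter entry becomes an exact-match term clause, ANDed together
        # with the collection filter below.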
        for field, value in filter.items():
            query_body["query"]["bool"]["filter"].append({"term": {field: value}})
        query_body["query"]["bool"]["filter"].append(
            {"term": {"collection": collection_name}}
        )
        size = limit if limit else 10

        try:
            result = self.client.search(
                index=f"{self.index_prefix}*",
                body=query_body,
                size=size,
            )

            return self._result_to_get_result(result)

        except Exception:
            return None

    # Status: works
    def _has_index(self, dimension: int):
        return self.client.indices.exists(
            index=self._get_index_name(dimension=dimension)
        )

    def get_or_create_index(self, dimension: int):
        if not self._has_index(dimension=dimension):
            self._create_index(dimension=dimension)

    # Status: works
    def get(self, collection_name: str) -> Optional[GetResult]:
        # Get all the items in the collection.
        query = {
            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
            "_source": ["text", "metadata"],
        }
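        # helpers.scan streams all hits through the scroll API, so there is no
        # 10k result-window cap as with a plain search.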
        results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))

        return self._scan_result_to_get_result(results)

    # Status: works
    def insert(self, collection_name: str, items: list[VectorItem]):
        # All items in one call are assumed to share the same embedding dimension.
        dimension = len(items[0]["vector"])
        if not self._has_index(dimension=dimension):
            self._create_index(dimension=dimension)

        for batch in self._create_batches(items):
            actions = [
                {
                    "_index": self._get_index_name(dimension=dimension),
                    "_id": item["id"],
                    "_source": {
                        "collection": collection_name,
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    },
                }
                for item in batch
            ]
            bulk(self.client, actions)

    # Upsert documents using the update API with doc_as_upsert=True.
    def upsert(self, collection_name: str, items: list[VectorItem]):
        dimension = len(items[0]["vector"])
        if not self._has_index(dimension=dimension):
            self._create_index(dimension=dimension)
        for batch in self._create_batches(items):
            actions = [
                {
                    "_op_type": "update",
                    "_index": self._get_index_name(dimension=dimension),
                    "_id": item["id"],
                    "doc": {
                        "collection": collection_name,
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    },
                    "doc_as_upsert": True,
                }
                for item in batch
            ]
            bulk(self.client, actions)

    # Delete specific documents from a collection by filtering on both collection and document IDs.
    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        query = {
            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
        }
        # Match ChromaDB's semantics: explicit ids take precedence over the metadata filter.
        if ids:
            query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
        elif filter:
            for field, value in filter.items():
                query["query"]["bool"]["filter"].append(
                    {"term": {f"metadata.{field}": value}}
                )

        self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)

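    # Drop every index created by this client (everything matching the index prefix).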
    def reset(self):
        indices = self.client.indices.get(index=f"{self.index_prefix}*")
        for index in indices:
            self.client.indices.delete(index=index)
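
# A minimal end-to-end sketch (not part of the client API): it assumes a
# reachable Elasticsearch instance configured via the settings above; the
# collection name, 3-dim vectors, and document contents are illustrative only.
if __name__ == "__main__":
    client = ElasticsearchClient()
    client.insert(
        "demo-collection",
        [
            {
                "id": "doc-1",
                "text": "hello world",
                "vector": [0.1, 0.2, 0.3],
                "metadata": {"source": "demo.txt"},
            }
        ],
    )
    # Elasticsearch is near-real-time; refresh so the document is searchable now.
    client.client.indices.refresh(index=client._get_index_name(3))
    print(client.search("demo-collection", vectors=[[0.1, 0.2, 0.3]], limit=1))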