"""
title: Cohere Manifold Pipeline
author: justinh-rahb
date: 2024-05-28
version: 1.0
license: MIT
description: A pipeline for generating text using the Cohere API.
requirements: requests
environment_variables: COHERE_API_KEY
"""

import os
import json
from schemas import OpenAIChatMessage
from typing import List, Union, Generator, Iterator
from pydantic import BaseModel
import requests


class Pipeline:
    """Manifold pipeline exposing Cohere chat models.

    Each model listed by Cohere's ``/models`` endpoint is published as an
    entry of ``self.pipelines``; ``pipe`` forwards a chat request for one of
    those models to Cohere's ``/chat`` endpoint (v1 request format).
    """

    def __init__(self):
        self.type = "manifold"
        self.id = "cohere"
        self.name = "cohere/"

        class Valves(BaseModel):
            COHERE_API_BASE_URL: str = "https://api.cohere.com/v1"
            COHERE_API_KEY: str

        # Default to "" so a missing environment variable does not fail
        # validation of the required ``str`` field; get_cohere_models()
        # treats an empty key as "not configured" and returns no models.
        self.valves = Valves(
            **{"COHERE_API_KEY": os.getenv("COHERE_API_KEY", "")}
        )

        self.pipelines = self.get_cohere_models()

    async def on_startup(self):
        print(f"on_startup:{__name__}")

    async def on_shutdown(self):
        print(f"on_shutdown:{__name__}")

    async def on_valves_updated(self):
        # Called when the valves are updated: the key or base URL may have
        # changed, so re-fetch the model list.
        self.pipelines = self.get_cohere_models()

    def _headers(self) -> dict:
        """Return the HTTP headers for an authenticated Cohere API request."""
        return {
            "Authorization": f"Bearer {self.valves.COHERE_API_KEY}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _to_cohere_role(role) -> str:
        """Map an OpenAI-style message role to a Cohere v1 chat_history role.

        ``"user"`` -> ``"USER"``, ``"system"`` -> ``"SYSTEM"``; anything else
        (e.g. ``"assistant"``) -> ``"CHATBOT"``.
        """
        if role == "user":
            return "USER"
        if role == "system":
            return "SYSTEM"
        return "CHATBOT"

    def _build_payload(
        self,
        user_message: str,
        model_id: str,
        messages: List[dict],
        stream: bool = False,
    ) -> dict:
        """Build the JSON body for a Cohere ``/chat`` request.

        All messages except the last become ``chat_history``; the last one is
        assumed to be the current user turn, which is sent separately as
        ``message`` (``user_message``). ``"stream": True`` is only included
        when ``stream`` is true.
        """
        payload = {
            "model": model_id,
            "chat_history": [
                {
                    "role": self._to_cohere_role(message.get("role")),
                    "message": message["content"],
                }
                for message in messages[:-1]
            ],
            "message": user_message,
        }
        if stream:
            payload["stream"] = True
        return payload

    def get_cohere_models(self):
        """Fetch the available models from Cohere.

        Returns:
            ``[]`` when no API key is configured; otherwise a list of
            ``{"id", "name"}`` dicts (both the Cohere model name). On any
            failure a single placeholder entry explaining the problem is
            returned instead of raising.
        """
        if not self.valves.COHERE_API_KEY:
            return []

        try:
            r = requests.get(
                f"{self.valves.COHERE_API_BASE_URL}/models", headers=self._headers()
            )
            # Surface HTTP errors (e.g. an invalid key) as such, rather than as
            # a confusing KeyError on the error body below.
            r.raise_for_status()
            models = r.json()
            return [
                {"id": model["name"], "name": model["name"]}
                for model in models["models"]
            ]
        except Exception as e:
            print(f"Error: {e}")
            return [
                {
                    "id": self.id,
                    "name": "Could not fetch models from Cohere, please update the API Key in the valves.",
                },
            ]

    def pipe(
        self, user_message: str, model_id: str, messages: List[dict], body: dict
    ) -> Union[str, Generator, Iterator]:
        """Run a chat request, streaming when ``body["stream"]`` is truthy.

        Any exception raised while issuing the request (including HTTP errors
        on the streaming path) is returned as an ``"Error: ..."`` string.
        """
        try:
            if body.get("stream", False):
                return self.stream_response(user_message, model_id, messages, body)
            else:
                return self.get_completion(user_message, model_id, messages, body)
        except Exception as e:
            return f"Error: {e}"

    def stream_response(
        self, user_message: str, model_id: str, messages: List[dict], body: dict
    ) -> Generator:
        """Send a streaming chat request and return a generator of text chunks.

        The request is issued eagerly (this is not itself a generator
        function), so connection and HTTP errors raise here, where ``pipe``
        can catch them, instead of on the first ``next()``.
        """
        r = requests.post(
            url=f"{self.valves.COHERE_API_BASE_URL}/chat",
            json=self._build_payload(user_message, model_id, messages, stream=True),
            headers=self._headers(),
            stream=True,
        )
        try:
            r.raise_for_status()
        except requests.HTTPError:
            r.close()
            raise
        return self._iter_stream_text(r)

    @staticmethod
    def _iter_stream_text(response) -> Iterator[str]:
        """Yield the ``text`` of each ``text-generation`` event in a stream.

        Each non-empty line is parsed as JSON. Lines that are not valid JSON,
        are not objects, or are other event types are skipped. The response
        is closed when iteration ends, including when the consumer closes the
        generator early.
        """
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                try:
                    event = json.loads(raw_line)
                except ValueError:
                    # Covers JSONDecodeError and undecodable bytes.
                    continue
                if (
                    isinstance(event, dict)
                    and event.get("event_type") == "text-generation"
                    and "text" in event
                ):
                    # The yield stays outside any try/except, so GeneratorExit
                    # from an early close() propagates correctly.
                    yield event["text"]
        finally:
            response.close()

    def get_completion(
        self, user_message: str, model_id: str, messages: List[dict], body: dict
    ) -> str:
        """Send a non-streaming chat request and return the reply text.

        Returns ``"No response from Cohere."`` when the JSON reply has no
        ``text`` field. HTTP errors raise ``requests.HTTPError``.
        """
        r = requests.post(
            url=f"{self.valves.COHERE_API_BASE_URL}/chat",
            json=self._build_payload(user_message, model_id, messages),
            headers=self._headers(),
        )

        r.raise_for_status()
        data = r.json()

        return data.get("text", "No response from Cohere.")