From a917293db57c6983f8e802da0022cbca53663994 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 28 May 2024 18:04:45 -0700 Subject: [PATCH] feat: cohere manifold example Co-Authored-By: Justin Hayes --- .../examples/cohere_manifold_pipeline.py | 152 ++++++++++++++++++ pipelines/ollama_manifold_pipeline.py | 2 +- 2 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 pipelines/examples/cohere_manifold_pipeline.py diff --git a/pipelines/examples/cohere_manifold_pipeline.py b/pipelines/examples/cohere_manifold_pipeline.py new file mode 100644 index 0000000..7cbd121 --- /dev/null +++ b/pipelines/examples/cohere_manifold_pipeline.py @@ -0,0 +1,152 @@ +""" +title: Anthropic Pipeline +author: justinh-rahb +date: 2024-05-27 +version: 1.0 +license: MIT +description: A pipeline for generating text using the Anthropic API. +dependencies: requests, anthropic +environment_variables: COHERE_API_KEY +""" + +import os +import json +from schemas import OpenAIChatMessage +from typing import List, Union, Generator, Iterator +from pydantic import BaseModel +import requests + + +class Pipeline: + def __init__(self): + self.type = "manifold" + self.id = "cohere_manifold" + self.name = "Cohere/" + + class Valves(BaseModel): + COHERE_API_BASE_URL: str = "https://api.cohere.com/v1" + COHERE_API_KEY: str + + self.valves = Valves(**{"COHERE_API_KEY": os.getenv("COHERE_API_KEY")}) + + self.pipelines = self.get_cohere_models() + + async def on_startup(self): + print(f"on_startup:{__name__}") + pass + + async def on_shutdown(self): + print(f"on_shutdown:{__name__}") + pass + + async def on_valves_update(self): + self.pipelines = self.get_cohere_models() + + pass + + def get_cohere_models(self): + if self.valves.COHERE_API_KEY: + try: + headers = {} + headers["Authorization"] = f"Bearer {self.valves.COHERE_API_KEY}" + headers["Content-Type"] = "application/json" + + r = requests.get( + f"{self.valves.COHERE_API_BASE_URL}/models", headers=headers + ) + + models = r.json() + return [ + { + "id": model["name"], + "name": model["name"] if "name" in model else model["name"], + } + for model in models["models"] + ] + except Exception as e: + + print(f"Error: {e}") + return [ + { + "id": self.id, + "name": "Could not fetch models from Cohere, please update the API Key in the valves.", + }, + ] + else: + return [] + + def pipe( + self, user_message: str, model_id: str, messages: List[dict], body: dict + ) -> Union[str, Generator, Iterator]: + try: + if body.get("stream", False): + return self.stream_response(user_message, model_id, messages, body) + else: + return self.get_completion(user_message, model_id, messages, body) + except Exception as e: + return f"Error: {e}" + + def stream_response( + self, user_message: str, model_id: str, messages: List[dict], body: dict + ) -> Generator: + + headers = {} + headers["Authorization"] = f"Bearer {self.valves.COHERE_API_KEY}" + headers["Content-Type"] = "application/json" + + r = requests.post( + url=f"{self.valves.COHERE_API_BASE_URL}/chat", + json={ + "model": model_id, + "chat_history": [ + { + "role": "USER" if message["role"] == "user" else "CHATBOT", + "message": message["content"], + } + for message in messages[:-1] + ], + "message": user_message, + "stream": True, + }, + headers=headers, + stream=True, + ) + + r.raise_for_status() + + for line in r.iter_lines(): + if line: + try: + line = json.loads(line) + if line["event_type"] == "text-generation": + yield line["text"] + except: + pass + + def get_completion( + self, user_message: str, model_id: str, messages: List[dict], body: dict + ) -> str: + headers = {} + headers["Authorization"] = f"Bearer {self.valves.COHERE_API_KEY}" + headers["Content-Type"] = "application/json" + + r = requests.post( + url=f"{self.valves.COHERE_API_BASE_URL}/chat", + json={ + "model": model_id, + "chat_history": [ + { + "role": "USER" if message["role"] == "user" else "CHATBOT", + "message": message["content"], + } + for message in messages[:-1] + ], + "message": user_message, + }, + headers=headers, + ) + + r.raise_for_status() + data = r.json() + + return data["text"] if "text" in data else "No response from Cohere." diff --git a/pipelines/ollama_manifold_pipeline.py b/pipelines/ollama_manifold_pipeline.py index 1621eee..0c58cbb 100644 --- a/pipelines/ollama_manifold_pipeline.py +++ b/pipelines/ollama_manifold_pipeline.py @@ -23,7 +23,7 @@ class Pipeline: class Valves(BaseModel): OLLAMA_BASE_URL: str - self.valves = Valves(**{"OLLAMA_BASE_URL": "http://localhost:11435"}) + self.valves = Valves(**{"OLLAMA_BASE_URL": "http://localhost:11434"}) pass async def on_startup(self):