From dd5e986a52be61143dbf4655143cff7901aa3cf4 Mon Sep 17 00:00:00 2001 From: Justin Hayes Date: Tue, 30 Jul 2024 09:01:13 -0400 Subject: [PATCH] refac: anthropic manifold --- .../providers/anthropic_manifold_pipeline.py | 69 ++++++++++++------- 1 file changed, 46 insertions(+), 23 deletions(-) diff --git a/examples/pipelines/providers/anthropic_manifold_pipeline.py b/examples/pipelines/providers/anthropic_manifold_pipeline.py index 563a6d4..81fa910 100644 --- a/examples/pipelines/providers/anthropic_manifold_pipeline.py +++ b/examples/pipelines/providers/anthropic_manifold_pipeline.py @@ -1,21 +1,20 @@ """ title: Anthropic Manifold Pipeline -author: justinh-rahb +author: justinh-rahb, sriparashiva date: 2024-06-20 -version: 1.3 +version: 1.4 license: MIT description: A pipeline for generating text and processing images using the Anthropic API. -requirements: requests, anthropic +requirements: requests, sseclient-py environment_variables: ANTHROPIC_API_KEY """ import os -from anthropic import Anthropic, RateLimitError, APIStatusError, APIConnectionError - -from schemas import OpenAIChatMessage +import requests +import json from typing import List, Union, Generator, Iterator from pydantic import BaseModel -import requests +import sseclient from utils.pipelines.main import pop_system_message @@ -32,7 +31,15 @@ class Pipeline: self.valves = self.Valves( **{"ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY", "your-api-key-here")} ) - self.client = Anthropic(api_key=self.valves.ANTHROPIC_API_KEY) + self.url = 'https://api.anthropic.com/v1/messages' + self.update_headers() + + def update_headers(self): + self.headers = { + 'anthropic-version': '2023-06-01', + 'content-type': 'application/json', + 'x-api-key': self.valves.ANTHROPIC_API_KEY + } def get_anthropic_models(self): return [ @@ -51,8 +58,7 @@ class Pipeline: pass async def on_valves_updated(self): - self.client = Anthropic(api_key=self.valves.ANTHROPIC_API_KEY) - pass + self.update_headers() def pipelines(self) -> List[dict]: return self.get_anthropic_models() @@ -131,21 +137,38 @@ class Pipeline: } if body.get("stream", False): - return self.stream_response(model_id, payload) + return self.stream_response(payload) else: - return self.get_completion(model_id, payload) - except (RateLimitError, APIStatusError, APIConnectionError) as e: + return self.get_completion(payload) + except Exception as e: return f"Error: {e}" - def stream_response(self, model_id: str, payload: dict) -> Generator: - stream = self.client.messages.create(**payload) + def stream_response(self, payload: dict) -> Generator: + response = requests.post(self.url, headers=self.headers, json=payload, stream=True) - for chunk in stream: - if chunk.type == "content_block_start": - yield chunk.content_block.text - elif chunk.type == "content_block_delta": - yield chunk.delta.text + if response.status_code == 200: + client = sseclient.SSEClient(response) + for event in client.events(): + try: + data = json.loads(event.data) + if data["type"] == "content_block_start": + yield data["content_block"]["text"] + elif data["type"] == "content_block_delta": + yield data["delta"]["text"] + elif data["type"] == "message_stop": + break + except json.JSONDecodeError: + print(f"Failed to parse JSON: {event.data}") + except KeyError as e: + print(f"Unexpected data structure: {e}") + print(f"Full data: {data}") + else: + raise Exception(f"Error: {response.status_code} - {response.text}") - def get_completion(self, model_id: str, payload: dict) -> str: - response = self.client.messages.create(**payload) - return response.content[0].text \ No newline at end of file + def get_completion(self, payload: dict) -> str: + response = requests.post(self.url, headers=self.headers, json=payload) + if response.status_code == 200: + res = response.json() + return res["content"][0]["text"] if "content" in res and res["content"] else "" + else: + raise Exception(f"Error: {response.status_code} - {response.text}")