Merge pull request #179 from justinh-rahb/anthropic-fix

Refactor Anthropic Manifold for improved reliability
This commit is contained in:
Justin Hayes 2024-07-31 03:54:09 -04:00 committed by GitHub
commit c76d24b032
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,21 +1,20 @@
"""
title: Anthropic Manifold Pipeline
author: justinh-rahb
author: justinh-rahb, sriparashiva
date: 2024-06-20
version: 1.3
version: 1.4
license: MIT
description: A pipeline for generating text and processing images using the Anthropic API.
requirements: requests, anthropic
requirements: requests, sseclient-py
environment_variables: ANTHROPIC_API_KEY
"""
import os
from anthropic import Anthropic, RateLimitError, APIStatusError, APIConnectionError
from schemas import OpenAIChatMessage
import requests
import json
from typing import List, Union, Generator, Iterator
from pydantic import BaseModel
import requests
import sseclient
from utils.pipelines.main import pop_system_message
@ -32,7 +31,15 @@ class Pipeline:
self.valves = self.Valves(
**{"ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY", "your-api-key-here")}
)
self.client = Anthropic(api_key=self.valves.ANTHROPIC_API_KEY)
self.url = 'https://api.anthropic.com/v1/messages'
self.update_headers()
def update_headers(self):
self.headers = {
'anthropic-version': '2023-06-01',
'content-type': 'application/json',
'x-api-key': self.valves.ANTHROPIC_API_KEY
}
def get_anthropic_models(self):
return [
@ -51,8 +58,7 @@ class Pipeline:
pass
async def on_valves_updated(self):
self.client = Anthropic(api_key=self.valves.ANTHROPIC_API_KEY)
pass
self.update_headers()
def pipelines(self) -> List[dict]:
return self.get_anthropic_models()
@ -131,21 +137,38 @@ class Pipeline:
}
if body.get("stream", False):
return self.stream_response(model_id, payload)
return self.stream_response(payload)
else:
return self.get_completion(model_id, payload)
except (RateLimitError, APIStatusError, APIConnectionError) as e:
return self.get_completion(payload)
except Exception as e:
return f"Error: {e}"
def stream_response(self, model_id: str, payload: dict) -> Generator:
stream = self.client.messages.create(**payload)
def stream_response(self, payload: dict) -> Generator:
response = requests.post(self.url, headers=self.headers, json=payload, stream=True)
for chunk in stream:
if chunk.type == "content_block_start":
yield chunk.content_block.text
elif chunk.type == "content_block_delta":
yield chunk.delta.text
if response.status_code == 200:
client = sseclient.SSEClient(response)
for event in client.events():
try:
data = json.loads(event.data)
if data["type"] == "content_block_start":
yield data["content_block"]["text"]
elif data["type"] == "content_block_delta":
yield data["delta"]["text"]
elif data["type"] == "message_stop":
break
except json.JSONDecodeError:
print(f"Failed to parse JSON: {event.data}")
except KeyError as e:
print(f"Unexpected data structure: {e}")
print(f"Full data: {data}")
else:
raise Exception(f"Error: {response.status_code} - {response.text}")
def get_completion(self, model_id: str, payload: dict) -> str:
response = self.client.messages.create(**payload)
return response.content[0].text
def get_completion(self, payload: dict) -> str:
response = requests.post(self.url, headers=self.headers, json=payload)
if response.status_code == 200:
res = response.json()
return res["content"][0]["text"] if "content" in res and res["content"] else ""
else:
raise Exception(f"Error: {response.status_code} - {response.text}")