"""
title: Anthropic Manifold Pipeline
author: justinh-rahb
date: 2024-06-20
version: 1.1
license: MIT
description: A pipeline for generating text using the Anthropic API.
requirements: requests, anthropic
environment_variables: ANTHROPIC_API_KEY
"""
import os
from typing import List, Union, Generator, Iterator

import requests
from anthropic import Anthropic, RateLimitError, APIStatusError, APIConnectionError
from pydantic import BaseModel

from schemas import OpenAIChatMessage
from utils.pipelines.main import pop_system_message

class Pipeline:
    class Valves(BaseModel):
        ANTHROPIC_API_KEY: str = ""

    def __init__(self):
        self.type = "manifold"
        self.id = "anthropic"
        self.name = "anthropic/"
        self.valves = self.Valves(
            **{"ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY", "your-api-key-here")}
        )
        self.client = Anthropic(api_key=self.valves.ANTHROPIC_API_KEY)

    def get_anthropic_models(self):
        # In the future, this could fetch models dynamically from Anthropic;
        # a sketch of that follows this method.
        return [
            {"id": "claude-3-haiku-20240307", "name": "claude-3-haiku"},
            {"id": "claude-3-opus-20240229", "name": "claude-3-opus"},
            {"id": "claude-3-sonnet-20240229", "name": "claude-3-sonnet"},
            {"id": "claude-3-5-sonnet-20240620", "name": "claude-3.5-sonnet"},
            # Add other Anthropic models here as they become available
        ]
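
    # A minimal sketch of dynamic model discovery, assuming Anthropic exposes
    # a GET /v1/models endpoint that returns {"data": [{"id": ...,
    # "display_name": ...}, ...]}; verify against the current API reference
    # before relying on it. Falls back to the static list on any failure.
    def fetch_anthropic_models(self) -> List[dict]:
        try:
            r = requests.get(
                "https://api.anthropic.com/v1/models",
                headers={
                    "x-api-key": self.valves.ANTHROPIC_API_KEY,
                    "anthropic-version": "2023-06-01",
                },
                timeout=10,
            )
            r.raise_for_status()
            return [
                {"id": m["id"], "name": m.get("display_name", m["id"])}
                for m in r.json().get("data", [])
            ]
        except requests.RequestException:
            return self.get_anthropic_models()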

    async def on_startup(self):
        print(f"on_startup:{__name__}")

    async def on_shutdown(self):
        print(f"on_shutdown:{__name__}")

    async def on_valves_updated(self):
        # Called when the valves are updated; recreate the client so a new
        # API key takes effect immediately.
        self.client = Anthropic(api_key=self.valves.ANTHROPIC_API_KEY)

    # Pipelines are the models that are available in the manifold.
    # This can be a list or a function that returns a list.
    def pipelines(self) -> List[dict]:
        return self.get_anthropic_models()

    def pipe(
        self, user_message: str, model_id: str, messages: List[dict], body: dict
    ) -> Union[str, Generator, Iterator]:
        try:
            # Strip Open WebUI bookkeeping fields that the Anthropic API
            # does not accept.
            for key in ("user", "chat_id", "title"):
                body.pop(key, None)

            if body.get("stream", False):
                return self.stream_response(model_id, messages, body)
            else:
                return self.get_completion(model_id, messages, body)
        except (RateLimitError, APIStatusError, APIConnectionError) as e:
            return f"Error: {e}"

    def stream_response(
        self, model_id: str, messages: List[dict], body: dict
    ) -> Generator:
        system_message, messages = pop_system_message(messages)

        max_tokens = (
            body.get("max_tokens") if body.get("max_tokens") is not None else 4096
        )
        temperature = (
            body.get("temperature") if body.get("temperature") is not None else 0.8
        )
        top_k = body.get("top_k") if body.get("top_k") is not None else 40
        top_p = body.get("top_p") if body.get("top_p") is not None else 0.9
        stop_sequences = body.get("stop") if body.get("stop") is not None else []

        stream = self.client.messages.create(
            **{
                "model": model_id,
                # Add the system message if one was supplied (optional).
                **({"system": system_message} if system_message else {}),
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p,
                "stop_sequences": stop_sequences,
                "stream": True,
            }
        )

        # Yield text as it arrives; other event types (message_start,
        # message_delta, message_stop) carry no text and are ignored.
        for chunk in stream:
            if chunk.type == "content_block_start":
                yield chunk.content_block.text
            elif chunk.type == "content_block_delta":
                yield chunk.delta.text
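
    # Consuming the generator (illustrative):
    #
    #   for token in pipeline.stream_response(model_id, messages, body):
    #       print(token, end="", flush=True)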

    def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str:
        system_message, messages = pop_system_message(messages)

        max_tokens = (
            body.get("max_tokens") if body.get("max_tokens") is not None else 4096
        )
        temperature = (
            body.get("temperature") if body.get("temperature") is not None else 0.8
        )
        top_k = body.get("top_k") if body.get("top_k") is not None else 40
        top_p = body.get("top_p") if body.get("top_p") is not None else 0.9
        stop_sequences = body.get("stop") if body.get("stop") is not None else []

        response = self.client.messages.create(
            **{
                "model": model_id,
                # Add the system message if one was supplied (optional).
                **({"system": system_message} if system_message else {}),
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p,
                "stop_sequences": stop_sequences,
            }
        )
        return response.content[0].text
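

# A minimal local smoke test (illustrative; assumes ANTHROPIC_API_KEY is set
# and that the schemas/utils modules from the pipelines repo are importable):
if __name__ == "__main__":
    pipeline = Pipeline()
    reply = pipeline.pipe(
        user_message="Say hello in one word.",
        model_id="claude-3-haiku-20240307",
        messages=[{"role": "user", "content": "Say hello in one word."}],
        body={"stream": False, "max_tokens": 32},
    )
    print(reply)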