""" title: Anthropic Manifold Pipeline author: justinh-rahb date: 2024-05-27 version: 1.0 license: MIT description: A pipeline for generating text using the Anthropic API. requirements: requests, anthropic environment_variables: ANTHROPIC_API_KEY """ import os from anthropic import Anthropic, RateLimitError, APIStatusError, APIConnectionError from schemas import OpenAIChatMessage from typing import List, Union, Generator, Iterator from pydantic import BaseModel import requests class Pipeline: def __init__(self): self.type = "manifold" self.id = "anthropic" self.name = "anthropic/" class Valves(BaseModel): ANTHROPIC_API_KEY: str self.valves = Valves(**{"ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY")}) self.client = Anthropic(api_key=self.valves.ANTHROPIC_API_KEY) def get_anthropic_models(self): # In the future, this could fetch models dynamically from Anthropic return [ {"id": "claude-3-haiku-20240307", "name": "claude-3-haiku"}, {"id": "claude-3-opus-20240229", "name": "claude-3-opus"}, {"id": "claude-3-sonnet-20240229", "name": "claude-3-sonnet"}, # Add other Anthropic models here as they become available ] async def on_startup(self): print(f"on_startup:{__name__}") pass async def on_shutdown(self): print(f"on_shutdown:{__name__}") pass async def on_valves_updated(self): # This function is called when the valves are updated. self.client = Anthropic(api_key=self.valves.ANTHROPIC_API_KEY) pass # Pipelines are the models that are available in the manifold. # It can be a list or a function that returns a list. def pipelines(self) -> List[dict]: return self.get_anthropic_models() def pipe( self, user_message: str, model_id: str, messages: List[dict], body: dict ) -> Union[str, Generator, Iterator]: try: if body.get("stream", False): return self.stream_response(model_id, messages, body) else: return self.get_completion(model_id, messages, body) except (RateLimitError, APIStatusError, APIConnectionError) as e: return f"Error: {e}" def stream_response( self, model_id: str, messages: List[dict], body: dict ) -> Generator: max_tokens = ( body.get("max_tokens") if body.get("max_tokens") is not None else 4096 ) temperature = ( body.get("temperature") if body.get("temperature") is not None else 0.8 ) top_k = body.get("top_k") if body.get("top_k") is not None else 40 top_p = body.get("top_p") if body.get("top_p") is not None else 0.9 stop_sequences = body.get("stop") if body.get("stop") is not None else [] stream = self.client.messages.create( model=model_id, messages=messages, max_tokens=max_tokens, temperature=temperature, top_k=top_k, top_p=top_p, stop_sequences=stop_sequences, stream=True, ) for chunk in stream: if chunk.type == "content_block_start": yield chunk.content_block.text elif chunk.type == "content_block_delta": yield chunk.delta.text def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str: max_tokens = ( body.get("max_tokens") if body.get("max_tokens") is not None else 4096 ) temperature = ( body.get("temperature") if body.get("temperature") is not None else 0.8 ) top_k = body.get("top_k") if body.get("top_k") is not None else 40 top_p = body.get("top_p") if body.get("top_p") is not None else 0.9 stop_sequences = body.get("stop") if body.get("stop") is not None else [] response = self.client.messages.create( model=model_id, messages=messages, max_tokens=max_tokens, temperature=temperature, top_k=top_k, top_p=top_p, stop_sequences=stop_sequences, ) return response.content[0].text