Merge pull request #17 from justinh-rahb/patch-1

This commit is contained in:
Timothy Jaeryang Baek 2024-05-29 10:00:04 -07:00 committed by GitHub
commit da3e00a1b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 33 additions and 19 deletions

View File

@ -1,5 +1,5 @@
""" """
title: Anthropic Pipeline title: Anthropic Manifold Pipeline
author: justinh-rahb author: justinh-rahb
date: 2024-05-27 date: 2024-05-27
version: 1.0 version: 1.0
@ -22,7 +22,7 @@ class Pipeline:
def __init__(self): def __init__(self):
self.type = "manifold" self.type = "manifold"
self.id = "anthropic" self.id = "anthropic"
self.name = "Anthropic/" self.name = "anthropic/"
class Valves(BaseModel): class Valves(BaseModel):
ANTHROPIC_API_KEY: str ANTHROPIC_API_KEY: str
@ -71,14 +71,20 @@ class Pipeline:
def stream_response( def stream_response(
self, model_id: str, messages: List[dict], body: dict self, model_id: str, messages: List[dict], body: dict
) -> Generator: ) -> Generator:
max_tokens = body.get("max_tokens") if body.get("max_tokens") is not None else 4096
temperature = body.get("temperature") if body.get("temperature") is not None else 0.8
top_k = body.get("top_k") if body.get("top_k") is not None else 40
top_p = body.get("top_p") if body.get("top_p") is not None else 0.9
stop_sequences = body.get("stop") if body.get("stop") is not None else []
stream = self.client.messages.create( stream = self.client.messages.create(
model=model_id, model=model_id,
messages=messages, messages=messages,
max_tokens=body.get("max_tokens", 1024), max_tokens=max_tokens,
temperature=body.get("temperature", 1.0), temperature=temperature,
top_k=body.get("top_k", 40), top_k=top_k,
top_p=body.get("top_p", 0.9), top_p=top_p,
stop_sequences=body.get("stop", []), stop_sequences=stop_sequences,
stream=True, stream=True,
) )
@ -89,13 +95,19 @@ class Pipeline:
yield chunk.delta.text yield chunk.delta.text
def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str: def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str:
max_tokens = body.get("max_tokens") if body.get("max_tokens") is not None else 4096
temperature = body.get("temperature") if body.get("temperature") is not None else 0.8
top_k = body.get("top_k") if body.get("top_k") is not None else 40
top_p = body.get("top_p") if body.get("top_p") is not None else 0.9
stop_sequences = body.get("stop") if body.get("stop") is not None else []
response = self.client.messages.create( response = self.client.messages.create(
model=model_id, model=model_id,
messages=messages, messages=messages,
max_tokens=body.get("max_tokens", 1024), max_tokens=max_tokens,
temperature=body.get("temperature", 1.0), temperature=temperature,
top_k=body.get("top_k", 40), top_k=top_k,
top_p=body.get("top_p", 0.9), top_p=top_p,
stop_sequences=body.get("stop", []), stop_sequences=stop_sequences,
) )
return response.content[0].text return response.content[0].text

View File

@ -1,11 +1,11 @@
""" """
title: Anthropic Pipeline title: Cohere Manifold Pipeline
author: justinh-rahb author: justinh-rahb
date: 2024-05-27 date: 2024-05-28
version: 1.0 version: 1.0
license: MIT license: MIT
description: A pipeline for generating text using the Anthropic API. description: A pipeline for generating text using the Anthropic API.
dependencies: requests, anthropic dependencies: requests
environment_variables: COHERE_API_KEY environment_variables: COHERE_API_KEY
""" """
@ -20,8 +20,8 @@ import requests
class Pipeline: class Pipeline:
def __init__(self): def __init__(self):
self.type = "manifold" self.type = "manifold"
self.id = "cohere_manifold" self.id = "cohere"
self.name = "Cohere/" self.name = "cohere/"
class Valves(BaseModel): class Valves(BaseModel):
COHERE_API_BASE_URL: str = "https://api.cohere.com/v1" COHERE_API_BASE_URL: str = "https://api.cohere.com/v1"

View File

@ -74,9 +74,9 @@ class Pipeline:
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
# Extract and validate parameters from the request body # Extract and validate parameters from the request body
max_tokens = body.get("max_tokens", 1024) max_tokens = body.get("max_tokens", 4096)
if not isinstance(max_tokens, int) or max_tokens < 0: if not isinstance(max_tokens, int) or max_tokens < 0:
max_tokens = 1024 # Default to 1024 if invalid max_tokens = 4096 # Default to 4096 if invalid
temperature = body.get("temperature", 0.8) temperature = body.get("temperature", 0.8)
if not isinstance(temperature, (int, float)) or temperature < 0: if not isinstance(temperature, (int, float)) or temperature < 0:

View File

@ -4,3 +4,5 @@ pydantic==2.7.1
python-multipart==0.0.9 python-multipart==0.0.9
requests==2.32.2 requests==2.32.2
aiohttp==3.9.5 aiohttp==3.9.5
passlib==1.7.4
jwt==1.3.1