mirror of
https://github.com/open-webui/pipelines
synced 2025-05-10 15:40:55 +00:00
Merge pull request #443 from mikeyobrien/main
Add thinking support for claude-3-7-sonnet
This commit is contained in:
commit
f89ab37f53
@ -6,7 +6,7 @@ version: 1.4
|
||||
license: MIT
|
||||
description: A pipeline for generating text and processing images using the Anthropic API.
|
||||
requirements: requests, sseclient-py
|
||||
environment_variables: ANTHROPIC_API_KEY
|
||||
environment_variables: ANTHROPIC_API_KEY, ANTHROPIC_THINKING_BUDGET_TOKENS, ANTHROPIC_ENABLE_THINKING
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -18,6 +18,17 @@ import sseclient
|
||||
|
||||
from utils.pipelines.main import pop_system_message
|
||||
|
||||
REASONING_EFFORT_BUDGET_TOKEN_MAP = {
|
||||
"none": None,
|
||||
"low": 1024,
|
||||
"medium": 4096,
|
||||
"high": 16384,
|
||||
"max": 32768,
|
||||
}
|
||||
|
||||
# Maximum combined token limit for Claude 3.7
|
||||
MAX_COMBINED_TOKENS = 64000
|
||||
|
||||
|
||||
class Pipeline:
|
||||
class Valves(BaseModel):
|
||||
@ -29,16 +40,20 @@ class Pipeline:
|
||||
self.name = "anthropic/"
|
||||
|
||||
self.valves = self.Valves(
|
||||
**{"ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY", "your-api-key-here")}
|
||||
**{
|
||||
"ANTHROPIC_API_KEY": os.getenv(
|
||||
"ANTHROPIC_API_KEY", "your-api-key-here"
|
||||
),
|
||||
}
|
||||
)
|
||||
self.url = 'https://api.anthropic.com/v1/messages'
|
||||
self.url = "https://api.anthropic.com/v1/messages"
|
||||
self.update_headers()
|
||||
|
||||
def update_headers(self):
|
||||
self.headers = {
|
||||
'anthropic-version': '2023-06-01',
|
||||
'content-type': 'application/json',
|
||||
'x-api-key': self.valves.ANTHROPIC_API_KEY
|
||||
"anthropic-version": "2023-06-01",
|
||||
"content-type": "application/json",
|
||||
"x-api-key": self.valves.ANTHROPIC_API_KEY,
|
||||
}
|
||||
|
||||
def get_anthropic_models(self):
|
||||
@ -88,7 +103,7 @@ class Pipeline:
|
||||
) -> Union[str, Generator, Iterator]:
|
||||
try:
|
||||
# Remove unnecessary keys
|
||||
for key in ['user', 'chat_id', 'title']:
|
||||
for key in ["user", "chat_id", "title"]:
|
||||
body.pop(key, None)
|
||||
|
||||
system_message, messages = pop_system_message(messages)
|
||||
@ -102,28 +117,40 @@ class Pipeline:
|
||||
if isinstance(message.get("content"), list):
|
||||
for item in message["content"]:
|
||||
if item["type"] == "text":
|
||||
processed_content.append({"type": "text", "text": item["text"]})
|
||||
processed_content.append(
|
||||
{"type": "text", "text": item["text"]}
|
||||
)
|
||||
elif item["type"] == "image_url":
|
||||
if image_count >= 5:
|
||||
raise ValueError("Maximum of 5 images per API call exceeded")
|
||||
raise ValueError(
|
||||
"Maximum of 5 images per API call exceeded"
|
||||
)
|
||||
|
||||
processed_image = self.process_image(item["image_url"])
|
||||
processed_content.append(processed_image)
|
||||
|
||||
if processed_image["source"]["type"] == "base64":
|
||||
image_size = len(processed_image["source"]["data"]) * 3 / 4
|
||||
image_size = (
|
||||
len(processed_image["source"]["data"]) * 3 / 4
|
||||
)
|
||||
else:
|
||||
image_size = 0
|
||||
|
||||
total_image_size += image_size
|
||||
if total_image_size > 100 * 1024 * 1024:
|
||||
raise ValueError("Total size of images exceeds 100 MB limit")
|
||||
raise ValueError(
|
||||
"Total size of images exceeds 100 MB limit"
|
||||
)
|
||||
|
||||
image_count += 1
|
||||
else:
|
||||
processed_content = [{"type": "text", "text": message.get("content", "")}]
|
||||
processed_content = [
|
||||
{"type": "text", "text": message.get("content", "")}
|
||||
]
|
||||
|
||||
processed_messages.append({"role": message["role"], "content": processed_content})
|
||||
processed_messages.append(
|
||||
{"role": message["role"], "content": processed_content}
|
||||
)
|
||||
|
||||
# Prepare the payload
|
||||
payload = {
|
||||
@ -139,6 +166,42 @@ class Pipeline:
|
||||
}
|
||||
|
||||
if body.get("stream", False):
|
||||
supports_thinking = "claude-3-7" in model_id
|
||||
reasoning_effort = body.get("reasoning_effort", "none")
|
||||
budget_tokens = REASONING_EFFORT_BUDGET_TOKEN_MAP.get(reasoning_effort)
|
||||
|
||||
# Allow users to input an integer value representing budget tokens
|
||||
if (
|
||||
not budget_tokens
|
||||
and reasoning_effort not in REASONING_EFFORT_BUDGET_TOKEN_MAP.keys()
|
||||
):
|
||||
try:
|
||||
budget_tokens = int(reasoning_effort)
|
||||
except ValueError as e:
|
||||
print("Failed to convert reasoning effort to int", e)
|
||||
budget_tokens = None
|
||||
|
||||
if supports_thinking and budget_tokens:
|
||||
# Check if the combined tokens (budget_tokens + max_tokens) exceeds the limit
|
||||
max_tokens = payload.get("max_tokens", 4096)
|
||||
combined_tokens = budget_tokens + max_tokens
|
||||
|
||||
if combined_tokens > MAX_COMBINED_TOKENS:
|
||||
error_message = f"Error: Combined tokens (budget_tokens {budget_tokens} + max_tokens {max_tokens} = {combined_tokens}) exceeds the maximum limit of {MAX_COMBINED_TOKENS}"
|
||||
print(error_message)
|
||||
return error_message
|
||||
|
||||
payload["max_tokens"] = combined_tokens
|
||||
payload["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": budget_tokens,
|
||||
}
|
||||
# Thinking requires temperature 1.0 and does not support top_p, top_k
|
||||
payload["temperature"] = 1.0
|
||||
if "top_k" in payload:
|
||||
del payload["top_k"]
|
||||
if "top_p" in payload:
|
||||
del payload["top_p"]
|
||||
return self.stream_response(payload)
|
||||
else:
|
||||
return self.get_completion(payload)
|
||||
@ -146,7 +209,12 @@ class Pipeline:
|
||||
return f"Error: {e}"
|
||||
|
||||
def stream_response(self, payload: dict) -> Generator:
|
||||
response = requests.post(self.url, headers=self.headers, json=payload, stream=True)
|
||||
"""Used for title and tag generation"""
|
||||
try:
|
||||
response = requests.post(
|
||||
self.url, headers=self.headers, json=payload, stream=True
|
||||
)
|
||||
print(f"{response} for {payload}")
|
||||
|
||||
if response.status_code == 200:
|
||||
client = sseclient.SSEClient(response)
|
||||
@ -154,23 +222,51 @@ class Pipeline:
|
||||
try:
|
||||
data = json.loads(event.data)
|
||||
if data["type"] == "content_block_start":
|
||||
if data["content_block"]["type"] == "thinking":
|
||||
yield "<think>"
|
||||
else:
|
||||
yield data["content_block"]["text"]
|
||||
elif data["type"] == "content_block_delta":
|
||||
if data["delta"]["type"] == "thinking_delta":
|
||||
yield data["delta"]["thinking"]
|
||||
elif data["delta"]["type"] == "signature_delta":
|
||||
yield "\n </think> \n\n"
|
||||
else:
|
||||
yield data["delta"]["text"]
|
||||
elif data["type"] == "message_stop":
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
print(f"Failed to parse JSON: {event.data}")
|
||||
yield f"Error: Failed to parse JSON response"
|
||||
except KeyError as e:
|
||||
print(f"Unexpected data structure: {e}")
|
||||
print(f"Unexpected data structure: {e} for payload {payload}")
|
||||
print(f"Full data: {data}")
|
||||
yield f"Error: Unexpected data structure: {e}"
|
||||
else:
|
||||
raise Exception(f"Error: {response.status_code} - {response.text}")
|
||||
error_message = f"Error: {response.status_code} - {response.text}"
|
||||
print(error_message)
|
||||
yield error_message
|
||||
except Exception as e:
|
||||
error_message = f"Error: {str(e)}"
|
||||
print(error_message)
|
||||
yield error_message
|
||||
|
||||
def get_completion(self, payload: dict) -> str:
|
||||
try:
|
||||
response = requests.post(self.url, headers=self.headers, json=payload)
|
||||
print(response, payload)
|
||||
if response.status_code == 200:
|
||||
res = response.json()
|
||||
return res["content"][0]["text"] if "content" in res and res["content"] else ""
|
||||
for content in res["content"]:
|
||||
if not content.get("text"):
|
||||
continue
|
||||
return content["text"]
|
||||
return ""
|
||||
else:
|
||||
raise Exception(f"Error: {response.status_code} - {response.text}")
|
||||
error_message = f"Error: {response.status_code} - {response.text}"
|
||||
print(error_message)
|
||||
return error_message
|
||||
except Exception as e:
|
||||
error_message = f"Error: {str(e)}"
|
||||
print(error_message)
|
||||
return error_message
|
||||
|
Loading…
Reference in New Issue
Block a user