mirror of https://github.com/open-webui/pipelines (synced 2025-05-10 15:40:55 +00:00)
fix: anthropic system
parent a6daafe1f2
commit 0f610ca7eb
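Anthropic's Messages API rejects system-role entries inside messages; a system prompt has to travel in the top-level system parameter instead. This commit adds a pop_system_message helper to utils/pipelines/main.py, uses it in both stream_response and get_completion to strip the system turn out of the OpenAI-style message list, and forwards it via the dedicated system field, including that key only when a system message is actually present.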
@@ -17,6 +17,8 @@ from typing import List, Union, Generator, Iterator
 from pydantic import BaseModel
 import requests
 
+from utils.pipelines.main import pop_system_message
+
 
 class Pipeline:
     class Valves(BaseModel):
@@ -80,6 +82,8 @@ class Pipeline:
     def stream_response(
         self, model_id: str, messages: List[dict], body: dict
     ) -> Generator:
+        system_message, messages = pop_system_message(messages)
+
         max_tokens = (
             body.get("max_tokens") if body.get("max_tokens") is not None else 4096
         )
@@ -91,14 +95,19 @@ class Pipeline:
         stop_sequences = body.get("stop") if body.get("stop") is not None else []
 
         stream = self.client.messages.create(
-            model=model_id,
-            messages=messages,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            stop_sequences=stop_sequences,
-            stream=True,
+            **{
+                "model": model_id,
+                **(
+                    {"system": system_message} if system_message else {}
+                ),  # Add system message if it exists (optional)
+                "messages": messages,
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "top_k": top_k,
+                "top_p": top_p,
+                "stop_sequences": stop_sequences,
+                "stream": True,
+            }
         )
 
         for chunk in stream:
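The switch from plain keyword arguments to a **{...} splat is what makes the fix work: the "system" key is merged into the request only when a system message actually exists, so system-less chats omit the key entirely instead of sending an explicit system=None. A minimal, self-contained sketch of the pattern (the function name and values here are illustrative, not from the commit):

def build_request_kwargs(system_message, messages):
    # The inner dict is merged in only when system_message is truthy,
    # so the "system" key is absent (not None) for system-less chats.
    return {
        **({"system": system_message} if system_message else {}),
        "messages": messages,
    }

print(build_request_kwargs(None, []))         # -> {'messages': []}
print(build_request_kwargs("Be terse.", []))  # -> {'system': 'Be terse.', 'messages': []}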
@@ -108,6 +117,8 @@ class Pipeline:
             yield chunk.delta.text
 
     def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str:
+        system_message, messages = pop_system_message(messages)
+
         max_tokens = (
             body.get("max_tokens") if body.get("max_tokens") is not None else 4096
         )
@@ -119,12 +130,17 @@ class Pipeline:
         stop_sequences = body.get("stop") if body.get("stop") is not None else []
 
         response = self.client.messages.create(
-            model=model_id,
-            messages=messages,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            stop_sequences=stop_sequences,
+            **{
+                "model": model_id,
+                **(
+                    {"system": system_message} if system_message else {}
+                ),  # Add system message if it exists (optional)
+                "messages": messages,
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "top_k": top_k,
+                "top_p": top_p,
+                "stop_sequences": stop_sequences,
+            }
         )
         return response.content[0].text
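The non-streaming path gets the same treatment. For orientation, a hypothetical driver for the two methods (pipeline, the model ID, and the body values are illustrative assumptions, not part of this commit; the pipeline's client is assumed to be an authenticated anthropic.Anthropic()):

# Hypothetical usage sketch; names and values are assumptions.
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Name one prime number."},
]
body = {"max_tokens": 128, "temperature": 0.0}

# Non-streaming path: returns the completion text in one piece.
print(pipeline.get_completion("claude-3-haiku-20240307", messages, body))

# Streaming path: yields text deltas as they arrive.
for delta in pipeline.stream_response("claude-3-haiku-20240307", messages, body):
    print(delta, end="", flush=True)

The remaining hunks are in utils/pipelines/main.py, where pop_system_message and its companions are defined.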
@@ -5,7 +5,7 @@ from typing import List
 from schemas import OpenAIChatMessage
 
 import inspect
-from typing import get_type_hints, Literal
+from typing import get_type_hints, Literal, Tuple
 
 
 def stream_message_template(model: str, message: str):
@@ -47,6 +47,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
     return None
 
 
+def get_system_message(messages: List[dict]) -> dict:
+    for message in messages:
+        if message["role"] == "system":
+            return message
+    return None
+
+
+def remove_system_message(messages: List[dict]) -> List[dict]:
+    return [message for message in messages if message["role"] != "system"]
+
+
+def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
+    return get_system_message(messages), remove_system_message(messages)
+
+
 def add_or_update_system_message(content: str, messages: List[dict]):
     """
     Adds a new system message at the beginning of the messages list
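The helpers split a message list into its system turn and the rest. Note that pop_system_message returns the whole message dict (role and content), not just the content string. A quick check of the behavior exactly as defined above:

messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi."},
]

system_message, rest = pop_system_message(messages)
print(system_message)  # {'role': 'system', 'content': 'You are terse.'}
print(rest)            # [{'role': 'user', 'content': 'Hi.'}]

# With no system turn present, the first element is simply None.
print(pop_system_message([{"role": "user", "content": "Hi."}]))  # (None, [...])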