mirror of
https://github.com/open-webui/pipelines
synced 2025-05-12 00:20:48 +00:00
fix: anthropic system
This commit is contained in:
parent
a6daafe1f2
commit
0f610ca7eb
@ -17,6 +17,8 @@ from typing import List, Union, Generator, Iterator
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from utils.pipelines.main import pop_system_message
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
class Valves(BaseModel):
|
class Valves(BaseModel):
|
||||||
@ -80,6 +82,8 @@ class Pipeline:
|
|||||||
def stream_response(
|
def stream_response(
|
||||||
self, model_id: str, messages: List[dict], body: dict
|
self, model_id: str, messages: List[dict], body: dict
|
||||||
) -> Generator:
|
) -> Generator:
|
||||||
|
system_message, messages = pop_system_message(messages)
|
||||||
|
|
||||||
max_tokens = (
|
max_tokens = (
|
||||||
body.get("max_tokens") if body.get("max_tokens") is not None else 4096
|
body.get("max_tokens") if body.get("max_tokens") is not None else 4096
|
||||||
)
|
)
|
||||||
@ -91,14 +95,19 @@ class Pipeline:
|
|||||||
stop_sequences = body.get("stop") if body.get("stop") is not None else []
|
stop_sequences = body.get("stop") if body.get("stop") is not None else []
|
||||||
|
|
||||||
stream = self.client.messages.create(
|
stream = self.client.messages.create(
|
||||||
model=model_id,
|
**{
|
||||||
messages=messages,
|
"model": model_id,
|
||||||
max_tokens=max_tokens,
|
**(
|
||||||
temperature=temperature,
|
{"system": system_message} if system_message else {}
|
||||||
top_k=top_k,
|
), # Add system message if it exists (optional
|
||||||
top_p=top_p,
|
"messages": messages,
|
||||||
stop_sequences=stop_sequences,
|
"max_tokens": max_tokens,
|
||||||
stream=True,
|
"temperature": temperature,
|
||||||
|
"top_k": top_k,
|
||||||
|
"top_p": top_p,
|
||||||
|
"stop_sequences": stop_sequences,
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
for chunk in stream:
|
for chunk in stream:
|
||||||
@ -108,6 +117,8 @@ class Pipeline:
|
|||||||
yield chunk.delta.text
|
yield chunk.delta.text
|
||||||
|
|
||||||
def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str:
|
def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str:
|
||||||
|
system_message, messages = pop_system_message(messages)
|
||||||
|
|
||||||
max_tokens = (
|
max_tokens = (
|
||||||
body.get("max_tokens") if body.get("max_tokens") is not None else 4096
|
body.get("max_tokens") if body.get("max_tokens") is not None else 4096
|
||||||
)
|
)
|
||||||
@ -119,12 +130,17 @@ class Pipeline:
|
|||||||
stop_sequences = body.get("stop") if body.get("stop") is not None else []
|
stop_sequences = body.get("stop") if body.get("stop") is not None else []
|
||||||
|
|
||||||
response = self.client.messages.create(
|
response = self.client.messages.create(
|
||||||
model=model_id,
|
**{
|
||||||
messages=messages,
|
"model": model_id,
|
||||||
max_tokens=max_tokens,
|
**(
|
||||||
temperature=temperature,
|
{"system": system_message} if system_message else {}
|
||||||
top_k=top_k,
|
), # Add system message if it exists (optional
|
||||||
top_p=top_p,
|
"messages": messages,
|
||||||
stop_sequences=stop_sequences,
|
"max_tokens": max_tokens,
|
||||||
|
"temperature": temperature,
|
||||||
|
"top_k": top_k,
|
||||||
|
"top_p": top_p,
|
||||||
|
"stop_sequences": stop_sequences,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
return response.content[0].text
|
return response.content[0].text
|
||||||
|
@ -5,7 +5,7 @@ from typing import List
|
|||||||
from schemas import OpenAIChatMessage
|
from schemas import OpenAIChatMessage
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import get_type_hints, Literal
|
from typing import get_type_hints, Literal, Tuple
|
||||||
|
|
||||||
|
|
||||||
def stream_message_template(model: str, message: str):
|
def stream_message_template(model: str, message: str):
|
||||||
@ -47,6 +47,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_message(messages: List[dict]) -> dict:
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] == "system":
|
||||||
|
return message
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def remove_system_message(messages: List[dict]) -> List[dict]:
|
||||||
|
return [message for message in messages if message["role"] != "system"]
|
||||||
|
|
||||||
|
|
||||||
|
def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
|
||||||
|
return get_system_message(messages), remove_system_message(messages)
|
||||||
|
|
||||||
|
|
||||||
def add_or_update_system_message(content: str, messages: List[dict]):
|
def add_or_update_system_message(content: str, messages: List[dict]):
|
||||||
"""
|
"""
|
||||||
Adds a new system message at the beginning of the messages list
|
Adds a new system message at the beginning of the messages list
|
||||||
|
Loading…
Reference in New Issue
Block a user