fix: anthropic system

This commit is contained in:
Timothy J. Baek 2024-06-17 10:40:27 -07:00
parent a6daafe1f2
commit 0f610ca7eb
2 changed files with 47 additions and 16 deletions

View File

@ -17,6 +17,8 @@ from typing import List, Union, Generator, Iterator
from pydantic import BaseModel
import requests
from utils.pipelines.main import pop_system_message
class Pipeline:
class Valves(BaseModel):
@ -80,6 +82,8 @@ class Pipeline:
def stream_response(
self, model_id: str, messages: List[dict], body: dict
) -> Generator:
system_message, messages = pop_system_message(messages)
max_tokens = (
body.get("max_tokens") if body.get("max_tokens") is not None else 4096
)
@ -91,14 +95,19 @@ class Pipeline:
stop_sequences = body.get("stop") if body.get("stop") is not None else []
stream = self.client.messages.create(
model=model_id,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_sequences=stop_sequences,
stream=True,
**{
"model": model_id,
**(
{"system": system_message} if system_message else {}
), # Add system message if it exists (optional)
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"stop_sequences": stop_sequences,
"stream": True,
}
)
for chunk in stream:
@ -108,6 +117,8 @@ class Pipeline:
yield chunk.delta.text
def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str:
system_message, messages = pop_system_message(messages)
max_tokens = (
body.get("max_tokens") if body.get("max_tokens") is not None else 4096
)
@ -119,12 +130,17 @@ class Pipeline:
stop_sequences = body.get("stop") if body.get("stop") is not None else []
response = self.client.messages.create(
model=model_id,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_sequences=stop_sequences,
**{
"model": model_id,
**(
{"system": system_message} if system_message else {}
), # Add system message if it exists (optional)
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"stop_sequences": stop_sequences,
}
)
return response.content[0].text

View File

@ -5,7 +5,7 @@ from typing import List
from schemas import OpenAIChatMessage
import inspect
from typing import get_type_hints, Literal
from typing import get_type_hints, Literal, Tuple
def stream_message_template(model: str, message: str):
@ -47,6 +47,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
return None
def get_system_message(messages: List[dict]) -> dict:
    """Return the first message whose role is "system", or None if there is none."""
    return next((message for message in messages if message["role"] == "system"), None)
def remove_system_message(messages: List[dict]) -> List[dict]:
    """Return a new list containing every message except those with the "system" role."""
    non_system = []
    for message in messages:
        if message["role"] != "system":
            non_system.append(message)
    return non_system
def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
    """Split *messages* into (first system message or None, all non-system messages).

    Single pass: remembers the first "system"-role message encountered and
    collects every other-role message in order.
    """
    system = None
    remaining = []
    for message in messages:
        if message["role"] == "system":
            if system is None:
                system = message
        else:
            remaining.append(message)
    return system, remaining
def add_or_update_system_message(content: str, messages: List[dict]):
"""
Adds a new system message at the beginning of the messages list