fix: anthropic system

This commit is contained in:
Timothy J. Baek 2024-06-17 10:40:27 -07:00
parent a6daafe1f2
commit 0f610ca7eb
2 changed files with 47 additions and 16 deletions

View File

@ -17,6 +17,8 @@ from typing import List, Union, Generator, Iterator
from pydantic import BaseModel from pydantic import BaseModel
import requests import requests
from utils.pipelines.main import pop_system_message
class Pipeline: class Pipeline:
class Valves(BaseModel): class Valves(BaseModel):
@ -80,6 +82,8 @@ class Pipeline:
def stream_response( def stream_response(
self, model_id: str, messages: List[dict], body: dict self, model_id: str, messages: List[dict], body: dict
) -> Generator: ) -> Generator:
system_message, messages = pop_system_message(messages)
max_tokens = ( max_tokens = (
body.get("max_tokens") if body.get("max_tokens") is not None else 4096 body.get("max_tokens") if body.get("max_tokens") is not None else 4096
) )
@ -91,14 +95,19 @@ class Pipeline:
stop_sequences = body.get("stop") if body.get("stop") is not None else [] stop_sequences = body.get("stop") if body.get("stop") is not None else []
stream = self.client.messages.create( stream = self.client.messages.create(
model=model_id, **{
messages=messages, "model": model_id,
max_tokens=max_tokens, **(
temperature=temperature, {"system": system_message} if system_message else {}
top_k=top_k, ), # Add system message if it exists (optional)
top_p=top_p, "messages": messages,
stop_sequences=stop_sequences, "max_tokens": max_tokens,
stream=True, "temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"stop_sequences": stop_sequences,
"stream": True,
}
) )
for chunk in stream: for chunk in stream:
@ -108,6 +117,8 @@ class Pipeline:
yield chunk.delta.text yield chunk.delta.text
def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str: def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str:
system_message, messages = pop_system_message(messages)
max_tokens = ( max_tokens = (
body.get("max_tokens") if body.get("max_tokens") is not None else 4096 body.get("max_tokens") if body.get("max_tokens") is not None else 4096
) )
@ -119,12 +130,17 @@ class Pipeline:
stop_sequences = body.get("stop") if body.get("stop") is not None else [] stop_sequences = body.get("stop") if body.get("stop") is not None else []
response = self.client.messages.create( response = self.client.messages.create(
model=model_id, **{
messages=messages, "model": model_id,
max_tokens=max_tokens, **(
temperature=temperature, {"system": system_message} if system_message else {}
top_k=top_k, ), # Add system message if it exists (optional)
top_p=top_p, "messages": messages,
stop_sequences=stop_sequences, "max_tokens": max_tokens,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"stop_sequences": stop_sequences,
}
) )
return response.content[0].text return response.content[0].text

View File

@ -5,7 +5,7 @@ from typing import List
from schemas import OpenAIChatMessage from schemas import OpenAIChatMessage
import inspect import inspect
from typing import get_type_hints, Literal from typing import get_type_hints, Literal, Tuple
def stream_message_template(model: str, message: str): def stream_message_template(model: str, message: str):
@ -47,6 +47,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
return None return None
def get_system_message(messages: List[dict]) -> dict:
    """Return the first message whose role is 'system', or None if absent."""
    # next() with a default avoids the manual scan-and-return loop.
    return next((m for m in messages if m["role"] == "system"), None)
def remove_system_message(messages: List[dict]) -> List[dict]:
    """Return a new list with every system-role message dropped."""
    kept = []
    for entry in messages:
        if entry["role"] != "system":
            kept.append(entry)
    return kept
def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
    """Split off the system message.

    Returns a pair: (first system-role message or None, the remaining
    messages with all system-role entries removed).
    """
    # Single pass instead of delegating to the two helper functions.
    system_msg = None
    remaining = []
    for entry in messages:
        if entry["role"] == "system":
            if system_msg is None:
                system_msg = entry
            continue
        remaining.append(entry)
    return system_msg, remaining
def add_or_update_system_message(content: str, messages: List[dict]): def add_or_update_system_message(content: str, messages: List[dict]):
""" """
Adds a new system message at the beginning of the messages list Adds a new system message at the beginning of the messages list