fix: anthropic system

Timothy J. Baek
2024-06-17 10:40:27 -07:00
parent a6daafe1f2
commit 0f610ca7eb
2 changed files with 47 additions and 16 deletions


@@ -17,6 +17,8 @@ from typing import List, Union, Generator, Iterator
 from pydantic import BaseModel
 import requests
 
+from utils.pipelines.main import pop_system_message
+
 
 class Pipeline:
     class Valves(BaseModel):
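
The new import does the heavy lifting of this fix: Anthropic's Messages API does not accept a "system" role inside the messages list, so the system prompt has to be separated out before the request is built. The helper itself is not part of this diff; below is a minimal sketch of the behavior the call sites rely on (the name pop_system_message is from the import above, the body is an assumption):

# Hypothetical sketch of utils.pipelines.main.pop_system_message --
# the real implementation is not shown in this diff. The call sites
# below suggest it returns the system content (or None) plus the
# remaining non-system messages.
from typing import List, Optional, Tuple

def pop_system_message(messages: List[dict]) -> Tuple[Optional[str], List[dict]]:
    system = next(
        (m.get("content") for m in messages if m.get("role") == "system"), None
    )
    remaining = [m for m in messages if m.get("role") != "system"]
    return system, remaining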
@@ -80,6 +82,8 @@ class Pipeline:
     def stream_response(
         self, model_id: str, messages: List[dict], body: dict
     ) -> Generator:
+        system_message, messages = pop_system_message(messages)
+
         max_tokens = (
             body.get("max_tokens") if body.get("max_tokens") is not None else 4096
         )
@@ -91,14 +95,19 @@ class Pipeline:
         stop_sequences = body.get("stop") if body.get("stop") is not None else []
 
         stream = self.client.messages.create(
-            model=model_id,
-            messages=messages,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            stop_sequences=stop_sequences,
-            stream=True,
+            **{
+                "model": model_id,
+                **(
+                    {"system": system_message} if system_message else {}
+                ),  # Add the system message only if it exists (optional)
+                "messages": messages,
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "top_k": top_k,
+                "top_p": top_p,
+                "stop_sequences": stop_sequences,
+                "stream": True,
+            }
         )
 
         for chunk in stream:
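
The kwargs are now built as a dict and splatted into messages.create so that the "system" key is present only when a system message was actually popped; passing system=None would likely be rejected, since the anthropic SDK expects the parameter to be a string or omitted entirely. An equivalent, more conventional spelling of the same idea (a sketch, not the committed code):

# Build the request kwargs and attach "system" only when it exists.
# Illustrative rewrite of the dict-splat above, not part of the commit.
kwargs = {
    "model": model_id,
    "messages": messages,
    "max_tokens": max_tokens,
    "temperature": temperature,
    "top_k": top_k,
    "top_p": top_p,
    "stop_sequences": stop_sequences,
    "stream": True,
}
if system_message:
    kwargs["system"] = system_message

stream = self.client.messages.create(**kwargs)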
@@ -108,6 +117,8 @@ class Pipeline:
                 yield chunk.delta.text
 
     def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str:
+        system_message, messages = pop_system_message(messages)
+
         max_tokens = (
             body.get("max_tokens") if body.get("max_tokens") is not None else 4096
         )
@@ -119,12 +130,17 @@ class Pipeline:
         stop_sequences = body.get("stop") if body.get("stop") is not None else []
 
         response = self.client.messages.create(
-            model=model_id,
-            messages=messages,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            stop_sequences=stop_sequences,
+            **{
+                "model": model_id,
+                **(
+                    {"system": system_message} if system_message else {}
+                ),  # Add the system message only if it exists (optional)
+                "messages": messages,
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "top_k": top_k,
+                "top_p": top_p,
+                "stop_sequences": stop_sequences,
+            }
         )
         return response.content[0].text
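
End to end, the fix means a chat history that still carries a system entry is split before it reaches Anthropic: the system text travels in the top-level system parameter and only user/assistant turns remain in messages. A hypothetical walk-through (values invented for illustration):

# Hypothetical input: a history that still contains a system entry.
history = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Hello!"},
]

system_message, messages = pop_system_message(history)
# Expected split, assuming the sketch of pop_system_message above:
#   system_message == "You are a terse assistant."
#   messages == [{"role": "user", "content": "Hello!"}]
# messages.create(...) then receives system="You are a terse assistant."
# only because system_message is truthy.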