fix: model system prompt variable support

This commit is contained in:
Timothy J. Baek 2024-06-17 13:47:48 -07:00
parent 84bd4994cd
commit 686c5081e6
2 changed files with 56 additions and 15 deletions

View File

@@ -40,6 +40,7 @@ from utils.utils import (
get_verified_user,
get_admin_user,
)
from utils.task import prompt_template
from config import (
@@ -814,22 +815,35 @@ async def generate_chat_completion(
"num_thread", None
)
if model_info.params.get("system", None):
system = model_info.params.get("system", None)
if system:
# Check if the payload already has a system message
# If not, add a system message to the payload
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None) + message["content"]
)
message["content"] = system + message["content"]
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
"content": system,
},
)
@@ -910,22 +924,35 @@ async def generate_openai_chat_completion(
else None
)
if model_info.params.get("system", None):
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None) + message["content"]
)
message["content"] = system + message["content"]
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
"content": system,
},
)

View File

@@ -20,6 +20,8 @@ from utils.utils import (
get_verified_user,
get_admin_user,
)
from utils.task import prompt_template
from config import (
SRC_LOG_LEVELS,
ENABLE_OPENAI_API,
@@ -392,22 +394,34 @@ async def generate_chat_completion(
else None
)
if model_info.params.get("system", None):
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None) + message["content"]
)
message["content"] = system + message["content"]
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
"content": system,
},
)