refac: use add_or_update_system_message

This commit is contained in:
Michael Poluektov 2024-07-31 17:16:07 +01:00
parent 006fc3495e
commit baf58ef396

View File

@ -19,7 +19,11 @@ from apps.webui.models.functions import Functions
from apps.webui.models.models import Models
from apps.webui.utils import load_function_module_by_id
from utils.misc import stream_message_template, whole_message_template
from utils.misc import (
stream_message_template,
whole_message_template,
add_or_update_system_message,
)
from utils.task import prompt_template
@ -47,8 +51,6 @@ from config import (
from apps.socket.main import get_event_call, get_event_emitter
import inspect
import uuid
import time
import json
from typing import Iterator, Generator, AsyncGenerator
@ -287,6 +289,7 @@ def get_extra_params(metadata: dict):
}
# inplace function: form_data is modified
def add_model_params(params: dict, form_data: dict) -> dict:
if not params:
return form_data
@ -307,44 +310,40 @@ def add_model_params(params: dict, form_data: dict) -> dict:
return form_data
# inplace function: form_data is modified
def populate_system_message(params: dict, form_data: dict, user) -> dict:
system = params.get("system", None)
if not system:
return form_data
if user:
template_params = {
"user_name": user.name,
"user_location": user.info.get("location") if user.info else None,
}
else:
template_params = {}
system = prompt_template(system, **template_params)
form_data["messages"] = add_or_update_system_message(
system, form_data.get("messages", [])
)
return form_data
async def generate_function_chat_completion(form_data, user):
print("entry point")
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", None)
extra_params = get_extra_params(metadata)
# Add extra params such as __event_emitter__
extra_params = get_extra_params(metadata)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
system = params.get("system", None)
form_data = add_model_params(params, form_data)
if system:
if user:
template_params = {
"user_name": user.name,
"user_location": user.info.get("location") if user.info else None,
}
else:
template_params = {}
system = prompt_template(system, **template_params)
# Check if the payload already has a system message
# If not, add a system message to the payload
for message in form_data.get("messages", []):
if message.get("role") == "system":
message["content"] = system + message["content"]
break
else:
if form_data.get("messages"):
form_data["messages"].insert(
0, {"role": "system", "content": system}
)
form_data = populate_system_message(params, form_data, user)
async def job():
pipe_id = get_pipe_id(form_data)