refac: use add_or_update_system_message

This commit is contained in:
Michael Poluektov 2024-07-31 17:16:07 +01:00
parent 006fc3495e
commit baf58ef396

View File

@ -19,7 +19,11 @@ from apps.webui.models.functions import Functions
from apps.webui.models.models import Models from apps.webui.models.models import Models
from apps.webui.utils import load_function_module_by_id from apps.webui.utils import load_function_module_by_id
from utils.misc import stream_message_template, whole_message_template from utils.misc import (
stream_message_template,
whole_message_template,
add_or_update_system_message,
)
from utils.task import prompt_template from utils.task import prompt_template
@ -47,8 +51,6 @@ from config import (
from apps.socket.main import get_event_call, get_event_emitter from apps.socket.main import get_event_call, get_event_emitter
import inspect import inspect
import uuid
import time
import json import json
from typing import Iterator, Generator, AsyncGenerator from typing import Iterator, Generator, AsyncGenerator
@ -287,6 +289,7 @@ def get_extra_params(metadata: dict):
} }
# inplace function: form_data is modified
def add_model_params(params: dict, form_data: dict) -> dict: def add_model_params(params: dict, form_data: dict) -> dict:
if not params: if not params:
return form_data return form_data
@ -307,23 +310,12 @@ def add_model_params(params: dict, form_data: dict) -> dict:
return form_data return form_data
async def generate_function_chat_completion(form_data, user): # inplace function: form_data is modified
print("entry point") def populate_system_message(params: dict, form_data: dict, user) -> dict:
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", None)
extra_params = get_extra_params(metadata)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
system = params.get("system", None) system = params.get("system", None)
form_data = add_model_params(params, form_data) if not system:
return form_data
if system:
if user: if user:
template_params = { template_params = {
"user_name": user.name, "user_name": user.name,
@ -331,20 +323,27 @@ async def generate_function_chat_completion(form_data, user):
} }
else: else:
template_params = {} template_params = {}
system = prompt_template(system, **template_params) system = prompt_template(system, **template_params)
form_data["messages"] = add_or_update_system_message(
# Check if the payload already has a system message system, form_data.get("messages", [])
# If not, add a system message to the payload
for message in form_data.get("messages", []):
if message.get("role") == "system":
message["content"] = system + message["content"]
break
else:
if form_data.get("messages"):
form_data["messages"].insert(
0, {"role": "system", "content": system}
) )
return form_data
async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", None)
# Add extra params such as __event_emitter__
extra_params = get_extra_params(metadata)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
form_data = add_model_params(params, form_data)
form_data = populate_system_message(params, form_data, user)
async def job(): async def job():
pipe_id = get_pipe_id(form_data) pipe_id = get_pipe_id(form_data)