diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index e64995af5..adfb82f2b 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -19,7 +19,11 @@ from apps.webui.models.functions import Functions from apps.webui.models.models import Models from apps.webui.utils import load_function_module_by_id -from utils.misc import stream_message_template, whole_message_template +from utils.misc import ( + stream_message_template, + whole_message_template, + add_or_update_system_message, +) from utils.task import prompt_template @@ -47,8 +51,6 @@ from config import ( from apps.socket.main import get_event_call, get_event_emitter import inspect -import uuid -import time import json from typing import Iterator, Generator, AsyncGenerator @@ -287,6 +289,7 @@ def get_extra_params(metadata: dict): } +# inplace function: form_data is modified def add_model_params(params: dict, form_data: dict) -> dict: if not params: return form_data @@ -307,44 +310,40 @@ def add_model_params(params: dict, form_data: dict) -> dict: return form_data +# inplace function: form_data is modified +def populate_system_message(params: dict, form_data: dict, user) -> dict: + system = params.get("system", None) + if not system: + return form_data + + if user: + template_params = { + "user_name": user.name, + "user_location": user.info.get("location") if user.info else None, + } + else: + template_params = {} + system = prompt_template(system, **template_params) + form_data["messages"] = add_or_update_system_message( + system, form_data.get("messages", []) + ) + return form_data + + async def generate_function_chat_completion(form_data, user): - print("entry point") model_id = form_data.get("model") model_info = Models.get_model_by_id(model_id) - metadata = form_data.pop("metadata", None) - extra_params = get_extra_params(metadata) + # Add extra params such as __event_emitter__ + extra_params = get_extra_params(metadata) if model_info: if model_info.base_model_id: form_data["model"] = model_info.base_model_id params = model_info.params.model_dump() - system = params.get("system", None) form_data = add_model_params(params, form_data) - - if system: - if user: - template_params = { - "user_name": user.name, - "user_location": user.info.get("location") if user.info else None, - } - else: - template_params = {} - - system = prompt_template(system, **template_params) - - # Check if the payload already has a system message - # If not, add a system message to the payload - for message in form_data.get("messages", []): - if message.get("role") == "system": - message["content"] = system + message["content"] - break - else: - if form_data.get("messages"): - form_data["messages"].insert( - 0, {"role": "system", "content": system} - ) + form_data = populate_system_message(params, form_data, user) async def job(): pipe_id = get_pipe_id(form_data)