diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index 44b5667d5..e45ea8897 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -28,11 +28,15 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from pydantic import BaseModel, ConfigDict from starlette.background import BackgroundTask + + from open_webui.utils.misc import ( + calculate_sha256, +) +from open_webui.utils.payload import ( apply_model_params_to_body_ollama, apply_model_params_to_body_openai, apply_model_system_prompt_to_body, - calculate_sha256, ) from open_webui.utils.utils import get_admin_user, get_verified_user diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index 53d1f4534..e8fd81d45 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -26,10 +26,13 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, StreamingResponse from pydantic import BaseModel from starlette.background import BackgroundTask -from open_webui.utils.misc import ( + + +from open_webui.utils.payload import ( apply_model_params_to_body_openai, apply_model_system_prompt_to_body, ) + from open_webui.utils.utils import get_admin_user, get_verified_user log = logging.getLogger(__name__) diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index 45fe3cad9..6c6f197dd 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -51,11 +51,15 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from pydantic import BaseModel from open_webui.utils.misc import ( - apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, openai_chat_chunk_message_template, openai_chat_completion_message_template, ) +from open_webui.utils.payload import ( + apply_model_params_to_body_openai, + apply_model_system_prompt_to_body, +) + + from open_webui.utils.tools import get_tools app = FastAPI() diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index 8b72983f1..d1b340044 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -6,7 +6,14 @@ from datetime import timedelta from pathlib import Path from typing import Callable, Optional -from open_webui.utils.task import prompt_template + +def get_messages_content(messages: list[dict]) -> str: + return "\n".join( + [ + f"{message['role'].upper()}: {get_content_from_message(message)}" + for message in messages + ] + ) def get_last_user_message_item(messages: list[dict]) -> Optional[dict]: @@ -30,7 +37,6 @@ def get_last_user_message(messages: list[dict]) -> Optional[str]: message = get_last_user_message_item(messages) if message is None: return None - return get_content_from_message(message) @@ -114,88 +120,6 @@ def openai_chat_completion_message_template(model: str, message: str) -> dict: return template -# inplace function: form_data is modified -def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict: - system = params.get("system", None) - if not system: - return form_data - - if user: - template_params = { - "user_name": user.name, - "user_location": user.info.get("location") if user.info else None, - } - else: - template_params = {} - system = prompt_template(system, **template_params) - form_data["messages"] = add_or_update_system_message( - system, form_data.get("messages", []) - ) - return form_data - - -# inplace function: form_data is modified -def apply_model_params_to_body( - params: dict, form_data: dict, mappings: dict[str, Callable] -) -> dict: - if not params: - return form_data - - for key, cast_func in mappings.items(): - if (value := params.get(key)) is not None: - form_data[key] = cast_func(value) - - return form_data - - -# inplace function: form_data is modified -def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: - mappings = { - "temperature": float, - "top_p": int, - "max_tokens": int, - "frequency_penalty": int, - "seed": lambda x: x, - "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], - } - return apply_model_params_to_body(params, form_data, mappings) - - -def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: - opts = [ - "temperature", - "top_p", - "seed", - "mirostat", - "mirostat_eta", - "mirostat_tau", - "num_ctx", - "num_batch", - "num_keep", - "repeat_last_n", - "tfs_z", - "top_k", - "min_p", - "use_mmap", - "use_mlock", - "num_thread", - "num_gpu", - ] - mappings = {i: lambda x: x for i in opts} - form_data = apply_model_params_to_body(params, form_data, mappings) - - name_differences = { - "max_tokens": "num_predict", - "frequency_penalty": "repeat_penalty", - } - - for key, value in name_differences.items(): - if (param := params.get(key, None)) is not None: - form_data[value] = param - - return form_data - - def get_gravatar_url(email): # Trim leading and trailing whitespace from # an email address and force all characters diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py new file mode 100644 index 000000000..227cca45f --- /dev/null +++ b/backend/open_webui/utils/payload.py @@ -0,0 +1,88 @@ +from open_webui.utils.task import prompt_template +from open_webui.utils.misc import ( + add_or_update_system_message, +) + +from typing import Callable, Optional + + +# inplace function: form_data is modified +def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict: + system = params.get("system", None) + if not system: + return form_data + + if user: + template_params = { + "user_name": user.name, + "user_location": user.info.get("location") if user.info else None, + } + else: + template_params = {} + system = prompt_template(system, **template_params) + form_data["messages"] = add_or_update_system_message( + system, form_data.get("messages", []) + ) + return form_data + + +# inplace function: form_data is modified +def apply_model_params_to_body( + params: dict, form_data: dict, mappings: dict[str, Callable] +) -> dict: + if not params: + return form_data + + for key, cast_func in mappings.items(): + if (value := params.get(key)) is not None: + form_data[key] = cast_func(value) + + return form_data + + +# inplace function: form_data is modified +def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict: + mappings = { + "temperature": float, + "top_p": int, + "max_tokens": int, + "frequency_penalty": int, + "seed": lambda x: x, + "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], + } + return apply_model_params_to_body(params, form_data, mappings) + + +def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: + opts = [ + "temperature", + "top_p", + "seed", + "mirostat", + "mirostat_eta", + "mirostat_tau", + "num_ctx", + "num_batch", + "num_keep", + "repeat_last_n", + "tfs_z", + "top_k", + "min_p", + "use_mmap", + "use_mlock", + "num_thread", + "num_gpu", + ] + mappings = {i: lambda x: x for i in opts} + form_data = apply_model_params_to_body(params, form_data, mappings) + + name_differences = { + "max_tokens": "num_predict", + "frequency_penalty": "repeat_penalty", + } + + for key, value in name_differences.items(): + if (param := params.get(key, None)) is not None: + form_data[value] = param + + return form_data