diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index f1544c80b..19d914c4b 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -44,7 +44,13 @@ from config import (
     UPLOAD_DIR,
     AppConfig,
 )
-from utils.misc import calculate_sha256, add_or_update_system_message
+from utils.misc import (
+    apply_model_params_to_body_ollama,
+    calculate_sha256,
+    add_or_update_system_message,
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
@@ -699,6 +705,18 @@ class GenerateChatCompletionForm(BaseModel):
     keep_alive: Optional[Union[int, str]] = None
 
 
+def get_ollama_url(url_idx: Optional[int], model: str):
+    if url_idx is None:
+        if model not in app.state.MODELS:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
+            )
+        url_idx = random.choice(app.state.MODELS[model]["urls"])
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    return url
+
+
 @app.post("/api/chat")
 @app.post("/api/chat/{url_idx}")
 async def generate_chat_completion(
@@ -706,17 +724,12 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    log.debug(
-        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
-            form_data.model_dump_json(exclude_none=True).encode()
-        )
-    )
+    log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()=}")
 
     payload = {
         **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
     }
-    if "metadata" in payload:
-        del payload["metadata"]
+    payload.pop("metadata", None)
 
     model_id = form_data.model
     model_info = Models.get_model_by_id(model_id)
@@ -731,148 +744,15 @@ async def generate_chat_completion(
         if payload.get("options") is None:
             payload["options"] = {}
 
-        if (
-            params.get("mirostat", None)
-            and payload["options"].get("mirostat") is None
-        ):
-            payload["options"]["mirostat"] = params.get("mirostat", None)
-
-        if (
-            params.get("mirostat_eta", None)
-            and payload["options"].get("mirostat_eta") is None
-        ):
-            payload["options"]["mirostat_eta"] = params.get("mirostat_eta", None)
-
-        if (
-            params.get("mirostat_tau", None)
-            and payload["options"].get("mirostat_tau") is None
-        ):
-            payload["options"]["mirostat_tau"] = params.get("mirostat_tau", None)
-
-        if (
-            params.get("num_ctx", None)
-            and payload["options"].get("num_ctx") is None
-        ):
-            payload["options"]["num_ctx"] = params.get("num_ctx", None)
-
-        if (
-            params.get("num_batch", None)
-            and payload["options"].get("num_batch") is None
-        ):
-            payload["options"]["num_batch"] = params.get("num_batch", None)
-
-        if (
-            params.get("num_keep", None)
-            and payload["options"].get("num_keep") is None
-        ):
-            payload["options"]["num_keep"] = params.get("num_keep", None)
-
-        if (
-            params.get("repeat_last_n", None)
-            and payload["options"].get("repeat_last_n") is None
-        ):
-            payload["options"]["repeat_last_n"] = params.get("repeat_last_n", None)
-
-        if (
-            params.get("frequency_penalty", None)
-            and payload["options"].get("frequency_penalty") is None
-        ):
-            payload["options"]["repeat_penalty"] = params.get(
-                "frequency_penalty", None
-            )
-
-        if (
-            params.get("temperature", None) is not None
-            and payload["options"].get("temperature") is None
-        ):
-            payload["options"]["temperature"] = params.get("temperature", None)
-
-        if (
-            params.get("seed", None) is not None
-            and payload["options"].get("seed") is None
-        ):
-            payload["options"]["seed"] = params.get("seed", None)
-
-        if params.get("stop", None) and payload["options"].get("stop") is None:
payload["options"].get("stop") is None: - payload["options"]["stop"] = ( - [ - bytes(stop, "utf-8").decode("unicode_escape") - for stop in params["stop"] - ] - if params.get("stop", None) - else None - ) - - if params.get("tfs_z", None) and payload["options"].get("tfs_z") is None: - payload["options"]["tfs_z"] = params.get("tfs_z", None) - - if ( - params.get("max_tokens", None) - and payload["options"].get("max_tokens") is None - ): - payload["options"]["num_predict"] = params.get("max_tokens", None) - - if params.get("top_k", None) and payload["options"].get("top_k") is None: - payload["options"]["top_k"] = params.get("top_k", None) - - if params.get("top_p", None) and payload["options"].get("top_p") is None: - payload["options"]["top_p"] = params.get("top_p", None) - - if params.get("min_p", None) and payload["options"].get("min_p") is None: - payload["options"]["min_p"] = params.get("min_p", None) - - if ( - params.get("use_mmap", None) - and payload["options"].get("use_mmap") is None - ): - payload["options"]["use_mmap"] = params.get("use_mmap", None) - - if ( - params.get("use_mlock", None) - and payload["options"].get("use_mlock") is None - ): - payload["options"]["use_mlock"] = params.get("use_mlock", None) - - if ( - params.get("num_thread", None) - and payload["options"].get("num_thread") is None - ): - payload["options"]["num_thread"] = params.get("num_thread", None) - - system = params.get("system", None) - if system: - system = prompt_template( - system, - **( - { - "user_name": user.name, - "user_location": ( - user.info.get("location") if user.info else None - ), - } - if user - else {} - ), + payload["options"] = apply_model_params_to_body_ollama( + params, payload["options"] ) + payload = apply_model_system_prompt_to_body(params, payload, user) - if payload.get("messages"): - payload["messages"] = add_or_update_system_message( - system, payload["messages"] - ) + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" - if url_idx is None: - if ":" not in payload["model"]: - payload["model"] = f"{payload['model']}:latest" - - if payload["model"] in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"]) - else: - raise HTTPException( - status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), - ) - - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") log.debug(payload) @@ -906,83 +786,27 @@ async def generate_openai_chat_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): - form_data = OpenAIChatCompletionForm(**form_data) - payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])} + completion_form = OpenAIChatCompletionForm(**form_data) + payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])} + payload.pop("metadata") - if "metadata" in payload: - del payload["metadata"] - - model_id = form_data.model + model_id = completion_form.model model_info = Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id - model_info.params = model_info.params.model_dump() + params = model_info.params.model_dump() - if model_info.params: - payload["temperature"] = model_info.params.get("temperature", None) - payload["top_p"] = model_info.params.get("top_p", None) - payload["max_tokens"] = model_info.params.get("max_tokens", None) - payload["frequency_penalty"] = model_info.params.get( - "frequency_penalty", 
-            )
-            payload["seed"] = model_info.params.get("seed", None)
-            payload["stop"] = (
-                [
-                    bytes(stop, "utf-8").decode("unicode_escape")
-                    for stop in model_info.params["stop"]
-                ]
-                if model_info.params.get("stop", None)
-                else None
-            )
+        if params:
+            payload = apply_model_params_to_body_openai(params, payload)
+            payload = apply_model_system_prompt_to_body(params, payload, user)
 
-            system = model_info.params.get("system", None)
+    if ":" not in payload["model"]:
+        payload["model"] = f"{payload['model']}:latest"
 
-        if system:
-            system = prompt_template(
-                system,
-                **(
-                    {
-                        "user_name": user.name,
-                        "user_location": (
-                            user.info.get("location") if user.info else None
-                        ),
-                    }
-                    if user
-                    else {}
-                ),
-            )
-            # Check if the payload already has a system message
-            # If not, add a system message to the payload
-            if payload.get("messages"):
-                for message in payload["messages"]:
-                    if message.get("role") == "system":
-                        message["content"] = system + message["content"]
-                        break
-                else:
-                    payload["messages"].insert(
-                        0,
-                        {
-                            "role": "system",
-                            "content": system,
-                        },
-                    )
-
-    if url_idx is None:
-        if ":" not in payload["model"]:
-            payload["model"] = f"{payload['model']}:latest"
-
-        if payload["model"] in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
 
     return await post_streaming_url(
diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py
index a0d8f3750..1313d2091 100644
--- a/backend/apps/openai/main.py
+++ b/backend/apps/openai/main.py
@@ -17,7 +17,10 @@ from utils.utils import (
     get_verified_user,
     get_admin_user,
 )
-from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
+from utils.misc import (
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)
 
 from config import (
     SRC_LOG_LEVELS,
@@ -366,7 +369,7 @@ async def generate_chat_completion(
             payload["model"] = model_info.base_model_id
 
         params = model_info.params.model_dump()
-        payload = apply_model_params_to_body(params, payload)
+        payload = apply_model_params_to_body_openai(params, payload)
         payload = apply_model_system_prompt_to_body(params, payload, user)
 
         model = app.state.MODELS[payload.get("model")]
diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py
index a0b9f5008..6848fdd4d 100644
--- a/backend/apps/webui/main.py
+++ b/backend/apps/webui/main.py
@@ -22,7 +22,7 @@ from apps.webui.utils import load_function_module_by_id
 from utils.misc import (
     openai_chat_chunk_message_template,
     openai_chat_completion_message_template,
-    apply_model_params_to_body,
+    apply_model_params_to_body_openai,
     apply_model_system_prompt_to_body,
 )
 
@@ -289,7 +289,7 @@ async def generate_function_chat_completion(form_data, user):
         form_data["model"] = model_info.base_model_id
 
         params = model_info.params.model_dump()
-        form_data = apply_model_params_to_body(params, form_data)
+        form_data = apply_model_params_to_body_openai(params, form_data)
         form_data = apply_model_system_prompt_to_body(params, form_data, user)
 
         pipe_id = get_pipe_id(form_data)
diff --git a/backend/utils/misc.py b/backend/utils/misc.py
index 25dd4dd5b..ffe6a6e53 100644
--- a/backend/utils/misc.py
+++ b/backend/utils/misc.py
@@ -2,7 +2,7 @@ from pathlib import Path
 import hashlib
 import re
 from datetime import timedelta
-from typing import Optional, List, Tuple
+from typing import Optional, List, Tuple, Callable
 import uuid
 import time
 
@@ -135,19 +135,12 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
 
 
 # inplace function: form_data is modified
-def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
+def apply_model_params_to_body(
+    params: dict, form_data: dict, mappings: dict[str, Callable]
+) -> dict:
     if not params:
         return form_data
 
-    mappings = {
-        "temperature": float,
-        "top_p": int,
-        "max_tokens": int,
-        "frequency_penalty": int,
-        "seed": lambda x: x,
-        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
-    }
-
    for key, cast_func in mappings.items():
         if (value := params.get(key)) is not None:
             form_data[key] = cast_func(value)
@@ -155,6 +148,52 @@ def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
     return form_data
 
 
+# top_p and frequency_penalty are floats (an int cast would truncate 0.9 to 0)
+OPENAI_MAPPINGS = {
+    "temperature": float,
+    "top_p": float,
+    "max_tokens": int,
+    "frequency_penalty": float,
+    "seed": lambda x: x,
+    "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+}
+
+
+# inplace function: form_data is modified
+def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
+    return apply_model_params_to_body(params, form_data, OPENAI_MAPPINGS)
+
+
+def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
+    # Ollama-only options are passed through unchanged
+    opts = [
+        "mirostat",
+        "mirostat_eta",
+        "mirostat_tau",
+        "num_ctx",
+        "num_batch",
+        "num_keep",
+        "repeat_last_n",
+        "tfs_z",
+        "top_k",
+        "min_p",
+        "use_mmap",
+        "use_mlock",
+        "num_thread",
+    ]
+    mappings = {i: lambda x: x for i in opts}
+    mappings = {**mappings, **OPENAI_MAPPINGS}
+    form_data = apply_model_params_to_body(params, form_data, mappings)
+
+    # Ollama uses different names for these OpenAI-style params
+    if "max_tokens" in form_data:
+        form_data["num_predict"] = form_data.pop("max_tokens")
+    if "frequency_penalty" in form_data:
+        form_data["repeat_penalty"] = form_data.pop("frequency_penalty")
+
+    return form_data
+
+
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # an email address and force all characters
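
A quick sanity check of the consolidated helpers, as a minimal sketch rather than part of the patch: the model name "gpt-4o" and every parameter value below are made up for illustration, and the import assumes backend/ is on sys.path, as it is when the server runs.

from utils.misc import (
    apply_model_params_to_body_ollama,
    apply_model_params_to_body_openai,
)

# illustrative values only
params = {
    "temperature": "0.7",  # strings are coerced by the mapping's cast
    "top_p": 0.9,
    "num_ctx": 4096,       # Ollama-only option, passed through as-is
    "stop": ["\\n\\n"],    # escape sequences are decoded to real newlines
}

# OpenAI-style body: only the OPENAI_MAPPINGS keys are applied
body = apply_model_params_to_body_openai(params, {"model": "gpt-4o"})
assert body["temperature"] == 0.7
assert "num_ctx" not in body

# Ollama options dict: pass-through options plus the shared OpenAI-style keys
options = apply_model_params_to_body_ollama(params, {})
assert options["num_ctx"] == 4096
assert options["stop"] == ["\n\n"]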