mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	refac: re-use utils.misc
This commit is contained in:
		
							parent
							
								
									44c781f414
								
							
						
					
					
						commit
						fc31267a54
					
				@ -44,7 +44,13 @@ from config import (
 | 
			
		||||
    UPLOAD_DIR,
 | 
			
		||||
    AppConfig,
 | 
			
		||||
)
 | 
			
		||||
from utils.misc import calculate_sha256, add_or_update_system_message
 | 
			
		||||
from utils.misc import (
 | 
			
		||||
    apply_model_params_to_body_ollama,
 | 
			
		||||
    calculate_sha256,
 | 
			
		||||
    add_or_update_system_message,
 | 
			
		||||
    apply_model_params_to_body_openai,
 | 
			
		||||
    apply_model_system_prompt_to_body,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
log = logging.getLogger(__name__)
 | 
			
		||||
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
 | 
			
		||||
@ -699,6 +705,18 @@ class GenerateChatCompletionForm(BaseModel):
 | 
			
		||||
    keep_alive: Optional[Union[int, str]] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_ollama_url(url_idx: Optional[int], model: str):
 | 
			
		||||
    if url_idx is None:
 | 
			
		||||
        if model not in app.state.MODELS:
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
                status_code=400,
 | 
			
		||||
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
 | 
			
		||||
            )
 | 
			
		||||
        url_idx = random.choice(app.state.MODELS[model]["urls"])
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
    return url
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/api/chat")
 | 
			
		||||
@app.post("/api/chat/{url_idx}")
 | 
			
		||||
async def generate_chat_completion(
 | 
			
		||||
@ -706,17 +724,12 @@ async def generate_chat_completion(
 | 
			
		||||
    url_idx: Optional[int] = None,
 | 
			
		||||
    user=Depends(get_verified_user),
 | 
			
		||||
):
 | 
			
		||||
    log.debug(
 | 
			
		||||
        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
 | 
			
		||||
            form_data.model_dump_json(exclude_none=True).encode()
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
 | 
			
		||||
 | 
			
		||||
    payload = {
 | 
			
		||||
        **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
 | 
			
		||||
    }
 | 
			
		||||
    if "metadata" in payload:
 | 
			
		||||
        del payload["metadata"]
 | 
			
		||||
    payload.pop("metadata")
 | 
			
		||||
 | 
			
		||||
    model_id = form_data.model
 | 
			
		||||
    model_info = Models.get_model_by_id(model_id)
 | 
			
		||||
@ -731,148 +744,15 @@ async def generate_chat_completion(
 | 
			
		||||
            if payload.get("options") is None:
 | 
			
		||||
                payload["options"] = {}
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                params.get("mirostat", None)
 | 
			
		||||
                and payload["options"].get("mirostat") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["mirostat"] = params.get("mirostat", None)
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                params.get("mirostat_eta", None)
 | 
			
		||||
                and payload["options"].get("mirostat_eta") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["mirostat_eta"] = params.get("mirostat_eta", None)
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                params.get("mirostat_tau", None)
 | 
			
		||||
                and payload["options"].get("mirostat_tau") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["mirostat_tau"] = params.get("mirostat_tau", None)
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                params.get("num_ctx", None)
 | 
			
		||||
                and payload["options"].get("num_ctx") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["num_ctx"] = params.get("num_ctx", None)
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                params.get("num_batch", None)
 | 
			
		||||
                and payload["options"].get("num_batch") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["num_batch"] = params.get("num_batch", None)
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                params.get("num_keep", None)
 | 
			
		||||
                and payload["options"].get("num_keep") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["num_keep"] = params.get("num_keep", None)
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                params.get("repeat_last_n", None)
 | 
			
		||||
                and payload["options"].get("repeat_last_n") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["repeat_last_n"] = params.get("repeat_last_n", None)
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                params.get("frequency_penalty", None)
 | 
			
		||||
                and payload["options"].get("frequency_penalty") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["repeat_penalty"] = params.get(
 | 
			
		||||
                    "frequency_penalty", None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                params.get("temperature", None) is not None
 | 
			
		||||
                and payload["options"].get("temperature") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["temperature"] = params.get("temperature", None)
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                params.get("seed", None) is not None
 | 
			
		||||
                and payload["options"].get("seed") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["seed"] = params.get("seed", None)
 | 
			
		||||
 | 
			
		||||
            if params.get("stop", None) and payload["options"].get("stop") is None:
 | 
			
		||||
                payload["options"]["stop"] = (
 | 
			
		||||
                    [
 | 
			
		||||
                        bytes(stop, "utf-8").decode("unicode_escape")
 | 
			
		||||
                        for stop in params["stop"]
 | 
			
		||||
                    ]
 | 
			
		||||
                    if params.get("stop", None)
 | 
			
		||||
                    else None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if params.get("tfs_z", None) and payload["options"].get("tfs_z") is None:
 | 
			
		||||
                payload["options"]["tfs_z"] = params.get("tfs_z", None)
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                params.get("max_tokens", None)
 | 
			
		||||
                and payload["options"].get("max_tokens") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["num_predict"] = params.get("max_tokens", None)
 | 
			
		||||
 | 
			
		||||
            if params.get("top_k", None) and payload["options"].get("top_k") is None:
 | 
			
		||||
                payload["options"]["top_k"] = params.get("top_k", None)
 | 
			
		||||
 | 
			
		||||
            if params.get("top_p", None) and payload["options"].get("top_p") is None:
 | 
			
		||||
                payload["options"]["top_p"] = params.get("top_p", None)
 | 
			
		||||
 | 
			
		||||
            if params.get("min_p", None) and payload["options"].get("min_p") is None:
 | 
			
		||||
                payload["options"]["min_p"] = params.get("min_p", None)
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                params.get("use_mmap", None)
 | 
			
		||||
                and payload["options"].get("use_mmap") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["use_mmap"] = params.get("use_mmap", None)
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                params.get("use_mlock", None)
 | 
			
		||||
                and payload["options"].get("use_mlock") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["use_mlock"] = params.get("use_mlock", None)
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                params.get("num_thread", None)
 | 
			
		||||
                and payload["options"].get("num_thread") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["num_thread"] = params.get("num_thread", None)
 | 
			
		||||
 | 
			
		||||
        system = params.get("system", None)
 | 
			
		||||
        if system:
 | 
			
		||||
            system = prompt_template(
 | 
			
		||||
                system,
 | 
			
		||||
                **(
 | 
			
		||||
                    {
 | 
			
		||||
                        "user_name": user.name,
 | 
			
		||||
                        "user_location": (
 | 
			
		||||
                            user.info.get("location") if user.info else None
 | 
			
		||||
                        ),
 | 
			
		||||
                    }
 | 
			
		||||
                    if user
 | 
			
		||||
                    else {}
 | 
			
		||||
                ),
 | 
			
		||||
            payload["options"] = apply_model_params_to_body_ollama(
 | 
			
		||||
                params, payload["options"]
 | 
			
		||||
            )
 | 
			
		||||
            payload = apply_model_system_prompt_to_body(params, payload, user)
 | 
			
		||||
 | 
			
		||||
            if payload.get("messages"):
 | 
			
		||||
                payload["messages"] = add_or_update_system_message(
 | 
			
		||||
                    system, payload["messages"]
 | 
			
		||||
                )
 | 
			
		||||
    if ":" not in payload["model"]:
 | 
			
		||||
        payload["model"] = f"{payload['model']}:latest"
 | 
			
		||||
 | 
			
		||||
    if url_idx is None:
 | 
			
		||||
        if ":" not in payload["model"]:
 | 
			
		||||
            payload["model"] = f"{payload['model']}:latest"
 | 
			
		||||
 | 
			
		||||
        if payload["model"] in app.state.MODELS:
 | 
			
		||||
            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
 | 
			
		||||
        else:
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
                status_code=400,
 | 
			
		||||
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
    url = get_ollama_url(url_idx, payload["model"])
 | 
			
		||||
    log.info(f"url: {url}")
 | 
			
		||||
    log.debug(payload)
 | 
			
		||||
 | 
			
		||||
@ -906,83 +786,27 @@ async def generate_openai_chat_completion(
 | 
			
		||||
    url_idx: Optional[int] = None,
 | 
			
		||||
    user=Depends(get_verified_user),
 | 
			
		||||
):
 | 
			
		||||
    form_data = OpenAIChatCompletionForm(**form_data)
 | 
			
		||||
    payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
 | 
			
		||||
    completion_form = OpenAIChatCompletionForm(**form_data)
 | 
			
		||||
    payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
 | 
			
		||||
    payload.pop("metadata")
 | 
			
		||||
 | 
			
		||||
    if "metadata" in payload:
 | 
			
		||||
        del payload["metadata"]
 | 
			
		||||
 | 
			
		||||
    model_id = form_data.model
 | 
			
		||||
    model_id = completion_form.model
 | 
			
		||||
    model_info = Models.get_model_by_id(model_id)
 | 
			
		||||
 | 
			
		||||
    if model_info:
 | 
			
		||||
        if model_info.base_model_id:
 | 
			
		||||
            payload["model"] = model_info.base_model_id
 | 
			
		||||
 | 
			
		||||
        model_info.params = model_info.params.model_dump()
 | 
			
		||||
        params = model_info.params.model_dump()
 | 
			
		||||
 | 
			
		||||
        if model_info.params:
 | 
			
		||||
            payload["temperature"] = model_info.params.get("temperature", None)
 | 
			
		||||
            payload["top_p"] = model_info.params.get("top_p", None)
 | 
			
		||||
            payload["max_tokens"] = model_info.params.get("max_tokens", None)
 | 
			
		||||
            payload["frequency_penalty"] = model_info.params.get(
 | 
			
		||||
                "frequency_penalty", None
 | 
			
		||||
            )
 | 
			
		||||
            payload["seed"] = model_info.params.get("seed", None)
 | 
			
		||||
            payload["stop"] = (
 | 
			
		||||
                [
 | 
			
		||||
                    bytes(stop, "utf-8").decode("unicode_escape")
 | 
			
		||||
                    for stop in model_info.params["stop"]
 | 
			
		||||
                ]
 | 
			
		||||
                if model_info.params.get("stop", None)
 | 
			
		||||
                else None
 | 
			
		||||
            )
 | 
			
		||||
        if params:
 | 
			
		||||
            payload = apply_model_params_to_body_openai(params, payload)
 | 
			
		||||
            payload = apply_model_system_prompt_to_body(params, payload, user)
 | 
			
		||||
 | 
			
		||||
        system = model_info.params.get("system", None)
 | 
			
		||||
    if ":" not in payload["model"]:
 | 
			
		||||
        payload["model"] = f"{payload['model']}:latest"
 | 
			
		||||
 | 
			
		||||
        if system:
 | 
			
		||||
            system = prompt_template(
 | 
			
		||||
                system,
 | 
			
		||||
                **(
 | 
			
		||||
                    {
 | 
			
		||||
                        "user_name": user.name,
 | 
			
		||||
                        "user_location": (
 | 
			
		||||
                            user.info.get("location") if user.info else None
 | 
			
		||||
                        ),
 | 
			
		||||
                    }
 | 
			
		||||
                    if user
 | 
			
		||||
                    else {}
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
            # Check if the payload already has a system message
 | 
			
		||||
            # If not, add a system message to the payload
 | 
			
		||||
            if payload.get("messages"):
 | 
			
		||||
                for message in payload["messages"]:
 | 
			
		||||
                    if message.get("role") == "system":
 | 
			
		||||
                        message["content"] = system + message["content"]
 | 
			
		||||
                        break
 | 
			
		||||
                else:
 | 
			
		||||
                    payload["messages"].insert(
 | 
			
		||||
                        0,
 | 
			
		||||
                        {
 | 
			
		||||
                            "role": "system",
 | 
			
		||||
                            "content": system,
 | 
			
		||||
                        },
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
    if url_idx is None:
 | 
			
		||||
        if ":" not in payload["model"]:
 | 
			
		||||
            payload["model"] = f"{payload['model']}:latest"
 | 
			
		||||
 | 
			
		||||
        if payload["model"] in app.state.MODELS:
 | 
			
		||||
            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
 | 
			
		||||
        else:
 | 
			
		||||
            raise HTTPException(
 | 
			
		||||
                status_code=400,
 | 
			
		||||
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 | 
			
		||||
    url = get_ollama_url(url_idx, payload["model"])
 | 
			
		||||
    log.info(f"url: {url}")
 | 
			
		||||
 | 
			
		||||
    return await post_streaming_url(
 | 
			
		||||
 | 
			
		||||
@ -17,7 +17,10 @@ from utils.utils import (
 | 
			
		||||
    get_verified_user,
 | 
			
		||||
    get_admin_user,
 | 
			
		||||
)
 | 
			
		||||
from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
 | 
			
		||||
from utils.misc import (
 | 
			
		||||
    apply_model_params_to_body_openai,
 | 
			
		||||
    apply_model_system_prompt_to_body,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from config import (
 | 
			
		||||
    SRC_LOG_LEVELS,
 | 
			
		||||
@ -366,7 +369,7 @@ async def generate_chat_completion(
 | 
			
		||||
            payload["model"] = model_info.base_model_id
 | 
			
		||||
 | 
			
		||||
        params = model_info.params.model_dump()
 | 
			
		||||
        payload = apply_model_params_to_body(params, payload)
 | 
			
		||||
        payload = apply_model_params_to_body_openai(params, payload)
 | 
			
		||||
        payload = apply_model_system_prompt_to_body(params, payload, user)
 | 
			
		||||
 | 
			
		||||
    model = app.state.MODELS[payload.get("model")]
 | 
			
		||||
 | 
			
		||||
@ -22,7 +22,7 @@ from apps.webui.utils import load_function_module_by_id
 | 
			
		||||
from utils.misc import (
 | 
			
		||||
    openai_chat_chunk_message_template,
 | 
			
		||||
    openai_chat_completion_message_template,
 | 
			
		||||
    apply_model_params_to_body,
 | 
			
		||||
    apply_model_params_to_body_openai,
 | 
			
		||||
    apply_model_system_prompt_to_body,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -289,7 +289,7 @@ async def generate_function_chat_completion(form_data, user):
 | 
			
		||||
            form_data["model"] = model_info.base_model_id
 | 
			
		||||
 | 
			
		||||
        params = model_info.params.model_dump()
 | 
			
		||||
        form_data = apply_model_params_to_body(params, form_data)
 | 
			
		||||
        form_data = apply_model_params_to_body_openai(params, form_data)
 | 
			
		||||
        form_data = apply_model_system_prompt_to_body(params, form_data, user)
 | 
			
		||||
 | 
			
		||||
    pipe_id = get_pipe_id(form_data)
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,7 @@ from pathlib import Path
 | 
			
		||||
import hashlib
 | 
			
		||||
import re
 | 
			
		||||
from datetime import timedelta
 | 
			
		||||
from typing import Optional, List, Tuple
 | 
			
		||||
from typing import Optional, List, Tuple, Callable
 | 
			
		||||
import uuid
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
@ -135,19 +135,12 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> di
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# inplace function: form_data is modified
 | 
			
		||||
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
 | 
			
		||||
def apply_model_params_to_body(
 | 
			
		||||
    params: dict, form_data: dict, mappings: dict[str, Callable]
 | 
			
		||||
) -> dict:
 | 
			
		||||
    if not params:
 | 
			
		||||
        return form_data
 | 
			
		||||
 | 
			
		||||
    mappings = {
 | 
			
		||||
        "temperature": float,
 | 
			
		||||
        "top_p": int,
 | 
			
		||||
        "max_tokens": int,
 | 
			
		||||
        "frequency_penalty": int,
 | 
			
		||||
        "seed": lambda x: x,
 | 
			
		||||
        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for key, cast_func in mappings.items():
 | 
			
		||||
        if (value := params.get(key)) is not None:
 | 
			
		||||
            form_data[key] = cast_func(value)
 | 
			
		||||
@ -155,6 +148,42 @@ def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
 | 
			
		||||
    return form_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
OPENAI_MAPPINGS = {
 | 
			
		||||
    "temperature": float,
 | 
			
		||||
    "top_p": int,
 | 
			
		||||
    "max_tokens": int,
 | 
			
		||||
    "frequency_penalty": int,
 | 
			
		||||
    "seed": lambda x: x,
 | 
			
		||||
    "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# inplace function: form_data is modified
 | 
			
		||||
def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
 | 
			
		||||
    return apply_model_params_to_body(params, form_data, OPENAI_MAPPINGS)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
 | 
			
		||||
    opts = [
 | 
			
		||||
        "mirostat",
 | 
			
		||||
        "mirostat_eta",
 | 
			
		||||
        "mirostat_tau",
 | 
			
		||||
        "num_ctx",
 | 
			
		||||
        "num_batch",
 | 
			
		||||
        "num_keep",
 | 
			
		||||
        "repeat_last_n",
 | 
			
		||||
        "tfs_z",
 | 
			
		||||
        "top_k",
 | 
			
		||||
        "min_p",
 | 
			
		||||
        "use_mmap",
 | 
			
		||||
        "use_mlock",
 | 
			
		||||
        "num_thread",
 | 
			
		||||
    ]
 | 
			
		||||
    mappings = {i: lambda x: x for i in opts}
 | 
			
		||||
    mappings = {**mappings, **OPENAI_MAPPINGS}
 | 
			
		||||
    return apply_model_params_to_body(params, form_data, mappings)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_gravatar_url(email):
 | 
			
		||||
    # Trim leading and trailing whitespace from
 | 
			
		||||
    # an email address and force all characters
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user