mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	refac
This commit is contained in:
		
							parent
							
								
									ca5f1c1efb
								
							
						
					
					
						commit
						3d0f457306
					
				@ -743,52 +743,85 @@ async def generate_chat_completion(
 | 
			
		||||
        model_info.params = model_info.params.model_dump()
 | 
			
		||||
 | 
			
		||||
        if model_info.params:
 | 
			
		||||
            payload["options"] = {}
 | 
			
		||||
            if payload.get("options") is None:
 | 
			
		||||
                payload["options"] = {}
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("mirostat", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("mirostat", None)
 | 
			
		||||
                and payload["options"].get("mirostat") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("mirostat_eta", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("mirostat_eta", None)
 | 
			
		||||
                and payload["options"].get("mirostat_eta") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["mirostat_eta"] = model_info.params.get(
 | 
			
		||||
                    "mirostat_eta", None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("mirostat_tau", None):
 | 
			
		||||
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("mirostat_tau", None)
 | 
			
		||||
                and payload["options"].get("mirostat_tau") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["mirostat_tau"] = model_info.params.get(
 | 
			
		||||
                    "mirostat_tau", None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("num_ctx", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("num_ctx", None)
 | 
			
		||||
                and payload["options"].get("num_ctx") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("num_batch", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("num_batch", None)
 | 
			
		||||
                and payload["options"].get("num_batch") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["num_batch"] = model_info.params.get(
 | 
			
		||||
                    "num_batch", None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("num_keep", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("num_keep", None)
 | 
			
		||||
                and payload["options"].get("num_keep") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["num_keep"] = model_info.params.get("num_keep", None)
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("repeat_last_n", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("repeat_last_n", None)
 | 
			
		||||
                and payload["options"].get("repeat_last_n") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["repeat_last_n"] = model_info.params.get(
 | 
			
		||||
                    "repeat_last_n", None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("frequency_penalty", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("frequency_penalty", None)
 | 
			
		||||
                and payload["options"].get("frequency_penalty") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["repeat_penalty"] = model_info.params.get(
 | 
			
		||||
                    "frequency_penalty", None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("temperature", None) is not None:
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("temperature", None)
 | 
			
		||||
                and payload["options"].get("temperature") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["temperature"] = model_info.params.get(
 | 
			
		||||
                    "temperature", None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("seed", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("seed", None)
 | 
			
		||||
                and payload["options"].get("seed") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["seed"] = model_info.params.get("seed", None)
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("stop", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("stop", None)
 | 
			
		||||
                and payload["options"].get("stop") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["stop"] = (
 | 
			
		||||
                    [
 | 
			
		||||
                        bytes(stop, "utf-8").decode("unicode_escape")
 | 
			
		||||
@ -798,37 +831,56 @@ async def generate_chat_completion(
 | 
			
		||||
                    else None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("tfs_z", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("tfs_z", None)
 | 
			
		||||
                and payload["options"].get("tfs_z") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("max_tokens", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("max_tokens", None)
 | 
			
		||||
                and payload["options"].get("max_tokens") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["num_predict"] = model_info.params.get(
 | 
			
		||||
                    "max_tokens", None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("top_k", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("top_k", None)
 | 
			
		||||
                and payload["options"].get("top_k") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["top_k"] = model_info.params.get("top_k", None)
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("top_p", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("top_p", None)
 | 
			
		||||
                and payload["options"].get("top_p") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["top_p"] = model_info.params.get("top_p", None)
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("use_mmap", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("use_mmap", None)
 | 
			
		||||
                and payload["options"].get("use_mmap") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("use_mlock", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("use_mlock", None)
 | 
			
		||||
                and payload["options"].get("use_mlock") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["use_mlock"] = model_info.params.get(
 | 
			
		||||
                    "use_mlock", None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("num_thread", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("num_thread", None)
 | 
			
		||||
                and payload["options"].get("num_thread") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["options"]["num_thread"] = model_info.params.get(
 | 
			
		||||
                    "num_thread", None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        system = model_info.params.get("system", None)
 | 
			
		||||
        if system:
 | 
			
		||||
            # Check if the payload already has a system message
 | 
			
		||||
            # If not, add a system message to the payload
 | 
			
		||||
            system = prompt_template(
 | 
			
		||||
                system,
 | 
			
		||||
                **(
 | 
			
		||||
 | 
			
		||||
@ -21,6 +21,7 @@ from utils.utils import (
 | 
			
		||||
    get_admin_user,
 | 
			
		||||
)
 | 
			
		||||
from utils.task import prompt_template
 | 
			
		||||
from utils.misc import add_or_update_system_message
 | 
			
		||||
 | 
			
		||||
from config import (
 | 
			
		||||
    SRC_LOG_LEVELS,
 | 
			
		||||
@ -370,24 +371,33 @@ async def generate_chat_completion(
 | 
			
		||||
        model_info.params = model_info.params.model_dump()
 | 
			
		||||
 | 
			
		||||
        if model_info.params:
 | 
			
		||||
            if model_info.params.get("temperature", None) is not None:
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("temperature", None)
 | 
			
		||||
                and payload.get("temperature") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["temperature"] = float(model_info.params.get("temperature"))
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("top_p", None):
 | 
			
		||||
            if model_info.params.get("top_p", None) and payload.get("top_p") is None:
 | 
			
		||||
                payload["top_p"] = int(model_info.params.get("top_p", None))
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("max_tokens", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("max_tokens", None)
 | 
			
		||||
                and payload.get("max_tokens") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("frequency_penalty", None):
 | 
			
		||||
            if (
 | 
			
		||||
                model_info.params.get("frequency_penalty", None)
 | 
			
		||||
                and payload.get("frequency_penalty") is None
 | 
			
		||||
            ):
 | 
			
		||||
                payload["frequency_penalty"] = int(
 | 
			
		||||
                    model_info.params.get("frequency_penalty", None)
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("seed", None):
 | 
			
		||||
            if model_info.params.get("seed", None) and payload.get("seed") is None:
 | 
			
		||||
                payload["seed"] = model_info.params.get("seed", None)
 | 
			
		||||
 | 
			
		||||
            if model_info.params.get("stop", None):
 | 
			
		||||
            if model_info.params.get("stop", None) and payload.get("stop") is None:
 | 
			
		||||
                payload["stop"] = (
 | 
			
		||||
                    [
 | 
			
		||||
                        bytes(stop, "utf-8").decode("unicode_escape")
 | 
			
		||||
@ -412,21 +422,10 @@ async def generate_chat_completion(
 | 
			
		||||
                    else {}
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
            # Check if the payload already has a system message
 | 
			
		||||
            # If not, add a system message to the payload
 | 
			
		||||
            if payload.get("messages"):
 | 
			
		||||
                for message in payload["messages"]:
 | 
			
		||||
                    if message.get("role") == "system":
 | 
			
		||||
                        message["content"] = system + message["content"]
 | 
			
		||||
                        break
 | 
			
		||||
                else:
 | 
			
		||||
                    payload["messages"].insert(
 | 
			
		||||
                        0,
 | 
			
		||||
                        {
 | 
			
		||||
                            "role": "system",
 | 
			
		||||
                            "content": system,
 | 
			
		||||
                        },
 | 
			
		||||
                    )
 | 
			
		||||
                payload["messages"] = add_or_update_system_message(
 | 
			
		||||
                    system, payload["messages"]
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user