Timothy J. Baek 2024-09-07 03:09:57 +01:00
parent 8d6a424604
commit ff46fe2b4a
5 changed files with 111 additions and 88 deletions

View File

@@ -28,11 +28,15 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, ConfigDict
 from starlette.background import BackgroundTask
 from open_webui.utils.misc import (
+    calculate_sha256,
+)
+from open_webui.utils.payload import (
     apply_model_params_to_body_ollama,
     apply_model_params_to_body_openai,
     apply_model_system_prompt_to_body,
-    calculate_sha256,
 )
 from open_webui.utils.utils import get_admin_user, get_verified_user

View File

@@ -26,10 +26,13 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import FileResponse, StreamingResponse
 from pydantic import BaseModel
 from starlette.background import BackgroundTask
-from open_webui.utils.misc import (
+from open_webui.utils.payload import (
     apply_model_params_to_body_openai,
     apply_model_system_prompt_to_body,
 )
 from open_webui.utils.utils import get_admin_user, get_verified_user
 
 log = logging.getLogger(__name__)

View File

@@ -51,11 +51,15 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from open_webui.utils.misc import (
-    apply_model_params_to_body_openai,
-    apply_model_system_prompt_to_body,
     openai_chat_chunk_message_template,
     openai_chat_completion_message_template,
 )
+from open_webui.utils.payload import (
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)
 from open_webui.utils.tools import get_tools
 
 app = FastAPI()

View File

@@ -6,7 +6,14 @@ from datetime import timedelta
 from pathlib import Path
 from typing import Callable, Optional
 
-from open_webui.utils.task import prompt_template
+
+def get_messages_content(messages: list[dict]) -> str:
+    return "\n".join(
+        [
+            f"{message['role'].upper()}: {get_content_from_message(message)}"
+            for message in messages
+        ]
+    )
 
 
 def get_last_user_message_item(messages: list[dict]) -> Optional[dict]:
@@ -30,7 +37,6 @@ def get_last_user_message(messages: list[dict]) -> Optional[str]:
     message = get_last_user_message_item(messages)
     if message is None:
         return None
-
     return get_content_from_message(message)
@@ -114,88 +120,6 @@ def openai_chat_completion_message_template(model: str, message: str) -> dict:
     return template
 
 
-# inplace function: form_data is modified
-def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
-    system = params.get("system", None)
-    if not system:
-        return form_data
-
-    if user:
-        template_params = {
-            "user_name": user.name,
-            "user_location": user.info.get("location") if user.info else None,
-        }
-    else:
-        template_params = {}
-    system = prompt_template(system, **template_params)
-    form_data["messages"] = add_or_update_system_message(
-        system, form_data.get("messages", [])
-    )
-    return form_data
-
-
-# inplace function: form_data is modified
-def apply_model_params_to_body(
-    params: dict, form_data: dict, mappings: dict[str, Callable]
-) -> dict:
-    if not params:
-        return form_data
-
-    for key, cast_func in mappings.items():
-        if (value := params.get(key)) is not None:
-            form_data[key] = cast_func(value)
-
-    return form_data
-
-
-# inplace function: form_data is modified
-def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
-    mappings = {
-        "temperature": float,
-        "top_p": int,
-        "max_tokens": int,
-        "frequency_penalty": int,
-        "seed": lambda x: x,
-        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
-    }
-    return apply_model_params_to_body(params, form_data, mappings)
-
-
-def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
-    opts = [
-        "temperature",
-        "top_p",
-        "seed",
-        "mirostat",
-        "mirostat_eta",
-        "mirostat_tau",
-        "num_ctx",
-        "num_batch",
-        "num_keep",
-        "repeat_last_n",
-        "tfs_z",
-        "top_k",
-        "min_p",
-        "use_mmap",
-        "use_mlock",
-        "num_thread",
-        "num_gpu",
-    ]
-    mappings = {i: lambda x: x for i in opts}
-    form_data = apply_model_params_to_body(params, form_data, mappings)
-
-    name_differences = {
-        "max_tokens": "num_predict",
-        "frequency_penalty": "repeat_penalty",
-    }
-
-    for key, value in name_differences.items():
-        if (param := params.get(key, None)) is not None:
-            form_data[value] = param
-
-    return form_data
-
-
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # an email address and force all characters
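Not part of the commit: a minimal usage sketch of the get_messages_content helper now defined near the top of open_webui.utils.misc, assuming plain string message content; the messages below are hypothetical examples.

from open_webui.utils.misc import get_messages_content

# Hypothetical chat history; content is assumed to be a plain string here.
messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]

# Joins each message as "ROLE: content", one message per line.
print(get_messages_content(messages))
# USER: What is the capital of France?
# ASSISTANT: Paris.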

View File

@@ -0,0 +1,88 @@
+from open_webui.utils.task import prompt_template
+from open_webui.utils.misc import (
+    add_or_update_system_message,
+)
+
+from typing import Callable, Optional
+
+
+# inplace function: form_data is modified
+def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
+    system = params.get("system", None)
+    if not system:
+        return form_data
+
+    if user:
+        template_params = {
+            "user_name": user.name,
+            "user_location": user.info.get("location") if user.info else None,
+        }
+    else:
+        template_params = {}
+    system = prompt_template(system, **template_params)
+    form_data["messages"] = add_or_update_system_message(
+        system, form_data.get("messages", [])
+    )
+    return form_data
+
+
+# inplace function: form_data is modified
+def apply_model_params_to_body(
+    params: dict, form_data: dict, mappings: dict[str, Callable]
+) -> dict:
+    if not params:
+        return form_data
+
+    for key, cast_func in mappings.items():
+        if (value := params.get(key)) is not None:
+            form_data[key] = cast_func(value)
+
+    return form_data
+
+
+# inplace function: form_data is modified
+def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
+    mappings = {
+        "temperature": float,
+        "top_p": int,
+        "max_tokens": int,
+        "frequency_penalty": int,
+        "seed": lambda x: x,
+        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+    }
+    return apply_model_params_to_body(params, form_data, mappings)
+
+
+def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
+    opts = [
+        "temperature",
+        "top_p",
+        "seed",
+        "mirostat",
+        "mirostat_eta",
+        "mirostat_tau",
+        "num_ctx",
+        "num_batch",
+        "num_keep",
+        "repeat_last_n",
+        "tfs_z",
+        "top_k",
+        "min_p",
+        "use_mmap",
+        "use_mlock",
+        "num_thread",
+        "num_gpu",
+    ]
+    mappings = {i: lambda x: x for i in opts}
+    form_data = apply_model_params_to_body(params, form_data, mappings)
+
+    name_differences = {
+        "max_tokens": "num_predict",
+        "frequency_penalty": "repeat_penalty",
+    }
+
+    for key, value in name_differences.items():
+        if (param := params.get(key, None)) is not None:
+            form_data[value] = param
+
+    return form_data
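Not part of the commit: a minimal usage sketch of the relocated helpers in open_webui.utils.payload, using a hypothetical params dict of per-model settings; the results shown in comments follow from the casts and key renames defined above.

from open_webui.utils.payload import (
    apply_model_params_to_body_ollama,
    apply_model_params_to_body_openai,
)

# Hypothetical per-model parameters (e.g. stored model settings).
params = {"temperature": "0.7", "max_tokens": 2048, "top_k": 40}

# OpenAI-style body: mapped keys are cast (temperature -> float, max_tokens -> int);
# keys without a mapping (top_k) are ignored.
openai_body = apply_model_params_to_body_openai(params, {"messages": []})
# {"messages": [], "temperature": 0.7, "max_tokens": 2048}

# Ollama-style body: listed options are copied as-is, and max_tokens is renamed
# to num_predict (frequency_penalty would likewise become repeat_penalty).
ollama_body = apply_model_params_to_body_ollama(params, {"messages": []})
# {"messages": [], "temperature": "0.7", "top_k": 40, "num_predict": 2048}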