refac: re-use utils.misc

Michael Poluektov 2024-08-06 11:31:45 +01:00
parent 44c781f414
commit fc31267a54
4 changed files with 85 additions and 229 deletions

View File

@@ -44,7 +44,13 @@ from config import (
     UPLOAD_DIR,
     AppConfig,
 )
-from utils.misc import calculate_sha256, add_or_update_system_message
+from utils.misc import (
+    apply_model_params_to_body_ollama,
+    calculate_sha256,
+    add_or_update_system_message,
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
@@ -699,6 +705,18 @@ class GenerateChatCompletionForm(BaseModel):
     keep_alive: Optional[Union[int, str]] = None
 
 
+def get_ollama_url(url_idx: Optional[int], model: str):
+    if url_idx is None:
+        if model not in app.state.MODELS:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
+            )
+        url_idx = random.choice(app.state.MODELS[model]["urls"])
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    return url
+
+
 @app.post("/api/chat")
 @app.post("/api/chat/{url_idx}")
 async def generate_chat_completion(
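The new get_ollama_url helper centralizes the backend lookup that both chat endpoints previously duplicated. A minimal standalone sketch of the same resolution logic, with plain dicts standing in for app.state.MODELS and app.state.config.OLLAMA_BASE_URLS (the model name and URLs below are invented for illustration):

    import random
    from typing import Optional

    # Hypothetical stand-ins for app.state; not part of the commit.
    MODELS = {"llama3:latest": {"urls": [0, 1]}}
    OLLAMA_BASE_URLS = ["http://ollama-a:11434", "http://ollama-b:11434"]

    def resolve_ollama_url(url_idx: Optional[int], model: str) -> str:
        # No explicit index: pick a random backend registered for the model,
        # failing when the model is unknown (the real helper raises an
        # HTTPException with ERROR_MESSAGES.MODEL_NOT_FOUND instead).
        if url_idx is None:
            if model not in MODELS:
                raise ValueError(f"model not found: {model}")
            url_idx = random.choice(MODELS[model]["urls"])
        return OLLAMA_BASE_URLS[url_idx]

    print(resolve_ollama_url(None, "llama3:latest"))  # one of the two URLs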
@@ -706,17 +724,12 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    log.debug(
-        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
-            form_data.model_dump_json(exclude_none=True).encode()
-        )
-    )
+    log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
 
     payload = {
         **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
     }
 
-    if "metadata" in payload:
-        del payload["metadata"]
+    payload.pop("metadata")
 
     model_id = form_data.model
     model_info = Models.get_model_by_id(model_id)
@@ -731,148 +744,15 @@ async def generate_chat_completion(
             if payload.get("options") is None:
                 payload["options"] = {}
 
-            if (
-                params.get("mirostat", None)
-                and payload["options"].get("mirostat") is None
-            ):
-                payload["options"]["mirostat"] = params.get("mirostat", None)
-
-            if (
-                params.get("mirostat_eta", None)
-                and payload["options"].get("mirostat_eta") is None
-            ):
-                payload["options"]["mirostat_eta"] = params.get("mirostat_eta", None)
-
-            if (
-                params.get("mirostat_tau", None)
-                and payload["options"].get("mirostat_tau") is None
-            ):
-                payload["options"]["mirostat_tau"] = params.get("mirostat_tau", None)
-
-            if (
-                params.get("num_ctx", None)
-                and payload["options"].get("num_ctx") is None
-            ):
-                payload["options"]["num_ctx"] = params.get("num_ctx", None)
-
-            if (
-                params.get("num_batch", None)
-                and payload["options"].get("num_batch") is None
-            ):
-                payload["options"]["num_batch"] = params.get("num_batch", None)
-
-            if (
-                params.get("num_keep", None)
-                and payload["options"].get("num_keep") is None
-            ):
-                payload["options"]["num_keep"] = params.get("num_keep", None)
-
-            if (
-                params.get("repeat_last_n", None)
-                and payload["options"].get("repeat_last_n") is None
-            ):
-                payload["options"]["repeat_last_n"] = params.get("repeat_last_n", None)
-
-            if (
-                params.get("frequency_penalty", None)
-                and payload["options"].get("frequency_penalty") is None
-            ):
-                payload["options"]["repeat_penalty"] = params.get(
-                    "frequency_penalty", None
-                )
-
-            if (
-                params.get("temperature", None) is not None
-                and payload["options"].get("temperature") is None
-            ):
-                payload["options"]["temperature"] = params.get("temperature", None)
-
-            if (
-                params.get("seed", None) is not None
-                and payload["options"].get("seed") is None
-            ):
-                payload["options"]["seed"] = params.get("seed", None)
-
-            if params.get("stop", None) and payload["options"].get("stop") is None:
-                payload["options"]["stop"] = (
-                    [
-                        bytes(stop, "utf-8").decode("unicode_escape")
-                        for stop in params["stop"]
-                    ]
-                    if params.get("stop", None)
-                    else None
-                )
-
-            if params.get("tfs_z", None) and payload["options"].get("tfs_z") is None:
-                payload["options"]["tfs_z"] = params.get("tfs_z", None)
-
-            if (
-                params.get("max_tokens", None)
-                and payload["options"].get("max_tokens") is None
-            ):
-                payload["options"]["num_predict"] = params.get("max_tokens", None)
-
-            if params.get("top_k", None) and payload["options"].get("top_k") is None:
-                payload["options"]["top_k"] = params.get("top_k", None)
-
-            if params.get("top_p", None) and payload["options"].get("top_p") is None:
-                payload["options"]["top_p"] = params.get("top_p", None)
-
-            if params.get("min_p", None) and payload["options"].get("min_p") is None:
-                payload["options"]["min_p"] = params.get("min_p", None)
-
-            if (
-                params.get("use_mmap", None)
-                and payload["options"].get("use_mmap") is None
-            ):
-                payload["options"]["use_mmap"] = params.get("use_mmap", None)
-
-            if (
-                params.get("use_mlock", None)
-                and payload["options"].get("use_mlock") is None
-            ):
-                payload["options"]["use_mlock"] = params.get("use_mlock", None)
-
-            if (
-                params.get("num_thread", None)
-                and payload["options"].get("num_thread") is None
-            ):
-                payload["options"]["num_thread"] = params.get("num_thread", None)
-
-            system = params.get("system", None)
-
-            if system:
-                system = prompt_template(
-                    system,
-                    **(
-                        {
-                            "user_name": user.name,
-                            "user_location": (
-                                user.info.get("location") if user.info else None
-                            ),
-                        }
-                        if user
-                        else {}
-                    ),
-                )
-
-                if payload.get("messages"):
-                    payload["messages"] = add_or_update_system_message(
-                        system, payload["messages"]
-                    )
+            payload["options"] = apply_model_params_to_body_ollama(
+                params, payload["options"]
+            )
+            payload = apply_model_system_prompt_to_body(params, payload, user)
 
-    if url_idx is None:
-        if ":" not in payload["model"]:
-            payload["model"] = f"{payload['model']}:latest"
-
-        if payload["model"] in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    if ":" not in payload["model"]:
+        payload["model"] = f"{payload['model']}:latest"
+
+    url = get_ollama_url(url_idx, payload["model"])
 
     log.info(f"url: {url}")
     log.debug(payload)
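The roughly twenty per-option if-blocks deleted above collapse into a single apply_model_params_to_body_ollama call. A runnable sketch of what that call does to payload["options"], re-implementing the merged mapping inline for illustration (the real helper appears in the utils.misc hunk at the end of this diff; the sample values are invented):

    # Abbreviated mapping: identity for native Ollama options, casts for the
    # OpenAI-style keys (the full lists are in the utils.misc changes below).
    OLLAMA_OPTS = ["mirostat", "num_ctx"]
    mappings = {k: (lambda x: x) for k in OLLAMA_OPTS}
    mappings.update({"temperature": float})

    params = {"mirostat": 2, "num_ctx": 4096, "temperature": "0.7", "system": "hi"}
    options: dict = {}
    for key, cast_func in mappings.items():
        if (value := params.get(key)) is not None:
            options[key] = cast_func(value)

    # "system" has no entry in the mapping, so it is left alone here and
    # handled separately by apply_model_system_prompt_to_body.
    print(options)  # {'mirostat': 2, 'num_ctx': 4096, 'temperature': 0.7}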
@@ -906,83 +786,27 @@ async def generate_openai_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    form_data = OpenAIChatCompletionForm(**form_data)
-    payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
+    completion_form = OpenAIChatCompletionForm(**form_data)
+    payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
+    payload.pop("metadata")
 
-    if "metadata" in payload:
-        del payload["metadata"]
-
-    model_id = form_data.model
+    model_id = completion_form.model
     model_info = Models.get_model_by_id(model_id)
 
     if model_info:
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id
 
-        model_info.params = model_info.params.model_dump()
+        params = model_info.params.model_dump()
 
-        if model_info.params:
-            payload["temperature"] = model_info.params.get("temperature", None)
-            payload["top_p"] = model_info.params.get("top_p", None)
-            payload["max_tokens"] = model_info.params.get("max_tokens", None)
-            payload["frequency_penalty"] = model_info.params.get(
-                "frequency_penalty", None
-            )
-            payload["seed"] = model_info.params.get("seed", None)
-            payload["stop"] = (
-                [
-                    bytes(stop, "utf-8").decode("unicode_escape")
-                    for stop in model_info.params["stop"]
-                ]
-                if model_info.params.get("stop", None)
-                else None
-            )
-
-            system = model_info.params.get("system", None)
-
-            if system:
-                system = prompt_template(
-                    system,
-                    **(
-                        {
-                            "user_name": user.name,
-                            "user_location": (
-                                user.info.get("location") if user.info else None
-                            ),
-                        }
-                        if user
-                        else {}
-                    ),
-                )
-                # Check if the payload already has a system message
-                # If not, add a system message to the payload
-                if payload.get("messages"):
-                    for message in payload["messages"]:
-                        if message.get("role") == "system":
-                            message["content"] = system + message["content"]
-                            break
-                    else:
-                        payload["messages"].insert(
-                            0,
-                            {
-                                "role": "system",
-                                "content": system,
-                            },
-                        )
+        if params:
+            payload = apply_model_params_to_body_openai(params, payload)
+            payload = apply_model_system_prompt_to_body(params, payload, user)
 
-    if url_idx is None:
-        if ":" not in payload["model"]:
-            payload["model"] = f"{payload['model']}:latest"
-
-        if payload["model"] in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    if ":" not in payload["model"]:
+        payload["model"] = f"{payload['model']}:latest"
+
+    url = get_ollama_url(url_idx, payload["model"])
 
     log.info(f"url: {url}")
     return await post_streaming_url(
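Both endpoints now delegate system-prompt handling to apply_model_system_prompt_to_body, whose body is not shown in this diff. A hedged sketch of the message-merging behavior it replaces here, mirroring the deleted for/else block above (the helper's actual implementation may differ, e.g. in how it renders prompt_template with user_name and user_location):

    def apply_system_prompt_sketch(params: dict, form_data: dict) -> dict:
        # Prepend a system message, or merge into an existing one, the way
        # the deleted inline code did.
        system = params.get("system")
        if not system:
            return form_data
        messages = form_data.setdefault("messages", [])
        for message in messages:
            if message.get("role") == "system":
                message["content"] = system + message["content"]
                break
        else:
            messages.insert(0, {"role": "system", "content": system})
        return form_data

    body = {"messages": [{"role": "user", "content": "hi"}]}
    print(apply_system_prompt_sketch({"system": "Be brief. "}, body))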

View File

@ -17,7 +17,10 @@ from utils.utils import (
get_verified_user, get_verified_user,
get_admin_user, get_admin_user,
) )
from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body from utils.misc import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from config import ( from config import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
@@ -366,7 +369,7 @@ async def generate_chat_completion(
             payload["model"] = model_info.base_model_id
 
         params = model_info.params.model_dump()
-        payload = apply_model_params_to_body(params, payload)
+        payload = apply_model_params_to_body_openai(params, payload)
         payload = apply_model_system_prompt_to_body(params, payload, user)
 
     model = app.state.MODELS[payload.get("model")]

View File

@@ -22,7 +22,7 @@ from apps.webui.utils import load_function_module_by_id
 from utils.misc import (
     openai_chat_chunk_message_template,
     openai_chat_completion_message_template,
-    apply_model_params_to_body,
+    apply_model_params_to_body_openai,
     apply_model_system_prompt_to_body,
 )
@@ -289,7 +289,7 @@ async def generate_function_chat_completion(form_data, user):
             form_data["model"] = model_info.base_model_id
 
         params = model_info.params.model_dump()
-        form_data = apply_model_params_to_body(params, form_data)
+        form_data = apply_model_params_to_body_openai(params, form_data)
         form_data = apply_model_system_prompt_to_body(params, form_data, user)
 
     pipe_id = get_pipe_id(form_data)

View File

@@ -2,7 +2,7 @@ from pathlib import Path
 import hashlib
 import re
 from datetime import timedelta
-from typing import Optional, List, Tuple
+from typing import Optional, List, Tuple, Callable
 import uuid
 import time
@@ -135,19 +135,12 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
 
 
 # inplace function: form_data is modified
-def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
+def apply_model_params_to_body(
+    params: dict, form_data: dict, mappings: dict[str, Callable]
+) -> dict:
     if not params:
         return form_data
 
-    mappings = {
-        "temperature": float,
-        "top_p": int,
-        "max_tokens": int,
-        "frequency_penalty": int,
-        "seed": lambda x: x,
-        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
-    }
-
     for key, cast_func in mappings.items():
         if (value := params.get(key)) is not None:
             form_data[key] = cast_func(value)
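The refactored helper is now parameterized over a key-to-cast mapping supplied by the caller. A self-contained usage example (the function body is copied from the hunk above; the input values are invented):

    from typing import Callable

    def apply_model_params_to_body(
        params: dict, form_data: dict, mappings: dict[str, Callable]
    ) -> dict:
        # Copy each mapped, non-None param into form_data after passing it
        # through its cast function.
        if not params:
            return form_data
        for key, cast_func in mappings.items():
            if (value := params.get(key)) is not None:
                form_data[key] = cast_func(value)
        return form_data

    body = apply_model_params_to_body(
        {"temperature": "0.2", "seed": 42, "unused": 1},
        {},
        {"temperature": float, "seed": lambda x: x},
    )
    print(body)  # {'temperature': 0.2, 'seed': 42}; unmapped keys are skipped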
@@ -155,6 +148,42 @@ def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
     return form_data
 
 
+OPENAI_MAPPINGS = {
+    "temperature": float,
+    "top_p": int,
+    "max_tokens": int,
+    "frequency_penalty": int,
+    "seed": lambda x: x,
+    "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+}
+
+
+# inplace function: form_data is modified
+def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
+    return apply_model_params_to_body(params, form_data, OPENAI_MAPPINGS)
+
+
+def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
+    opts = [
+        "mirostat",
+        "mirostat_eta",
+        "mirostat_tau",
+        "num_ctx",
+        "num_batch",
+        "num_keep",
+        "repeat_last_n",
+        "tfs_z",
+        "top_k",
+        "min_p",
+        "use_mmap",
+        "use_mlock",
+        "num_thread",
+    ]
+    mappings = {i: lambda x: x for i in opts}
+    mappings = {**mappings, **OPENAI_MAPPINGS}
+    return apply_model_params_to_body(params, form_data, mappings)
+
+
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # an email address and force all characters
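The least obvious entry in OPENAI_MAPPINGS is the stop cast, which turns escaped sequences stored in model params back into real control characters before they reach the backend. The same transformation in isolation (sample input invented):

    # Equivalent to the OPENAI_MAPPINGS["stop"] lambda above.
    def unescape_stops(stops: list[str]) -> list[str]:
        return [bytes(s, "utf-8").decode("unicode_escape") for s in stops]

    print(unescape_stops(["\\n\\n", "###"]))  # ['\n\n', '###']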