mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
refac: re-use utils.misc
This commit is contained in:
@@ -2,7 +2,7 @@ from pathlib import Path
|
||||
import hashlib
|
||||
import re
|
||||
from datetime import timedelta
|
||||
from typing import Optional, List, Tuple
|
||||
from typing import Optional, List, Tuple, Callable
|
||||
import uuid
|
||||
import time
|
||||
|
||||
@@ -135,19 +135,12 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> di
|
||||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
|
||||
def apply_model_params_to_body(
|
||||
params: dict, form_data: dict, mappings: dict[str, Callable]
|
||||
) -> dict:
|
||||
if not params:
|
||||
return form_data
|
||||
|
||||
mappings = {
|
||||
"temperature": float,
|
||||
"top_p": int,
|
||||
"max_tokens": int,
|
||||
"frequency_penalty": int,
|
||||
"seed": lambda x: x,
|
||||
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
|
||||
}
|
||||
|
||||
for key, cast_func in mappings.items():
|
||||
if (value := params.get(key)) is not None:
|
||||
form_data[key] = cast_func(value)
|
||||
@@ -155,6 +148,42 @@ def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
|
||||
return form_data
|
||||
|
||||
|
||||
OPENAI_MAPPINGS = {
|
||||
"temperature": float,
|
||||
"top_p": int,
|
||||
"max_tokens": int,
|
||||
"frequency_penalty": int,
|
||||
"seed": lambda x: x,
|
||||
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
|
||||
}
|
||||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
|
||||
return apply_model_params_to_body(params, form_data, OPENAI_MAPPINGS)
|
||||
|
||||
|
||||
def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
|
||||
opts = [
|
||||
"mirostat",
|
||||
"mirostat_eta",
|
||||
"mirostat_tau",
|
||||
"num_ctx",
|
||||
"num_batch",
|
||||
"num_keep",
|
||||
"repeat_last_n",
|
||||
"tfs_z",
|
||||
"top_k",
|
||||
"min_p",
|
||||
"use_mmap",
|
||||
"use_mlock",
|
||||
"num_thread",
|
||||
]
|
||||
mappings = {i: lambda x: x for i in opts}
|
||||
mappings = {**mappings, **OPENAI_MAPPINGS}
|
||||
return apply_model_params_to_body(params, form_data, mappings)
|
||||
|
||||
|
||||
def get_gravatar_url(email):
|
||||
# Trim leading and trailing whitespace from
|
||||
# an email address and force all characters
|
||||
|
||||
Reference in New Issue
Block a user