mirror of
https://github.com/open-webui/open-webui
synced 2024-11-07 00:59:52 +00:00
fix: pipe custom model
This commit is contained in:
parent
740b6f5c17
commit
67c2ab006d
@ -19,8 +19,13 @@ from apps.webui.routers import (
|
||||
functions,
|
||||
)
|
||||
from apps.webui.models.functions import Functions
|
||||
from apps.webui.models.models import Models
|
||||
|
||||
from apps.webui.utils import load_function_module_by_id
|
||||
|
||||
from utils.misc import stream_message_template
|
||||
from utils.task import prompt_template
|
||||
|
||||
|
||||
from config import (
|
||||
WEBUI_BUILD_HASH,
|
||||
@ -186,6 +191,77 @@ async def get_pipe_models():
|
||||
|
||||
|
||||
async def generate_function_chat_completion(form_data, user):
|
||||
model_id = form_data.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
if model_info:
|
||||
if model_info.base_model_id:
|
||||
form_data["model"] = model_info.base_model_id
|
||||
|
||||
model_info.params = model_info.params.model_dump()
|
||||
|
||||
if model_info.params:
|
||||
if model_info.params.get("temperature", None) is not None:
|
||||
form_data["temperature"] = float(model_info.params.get("temperature"))
|
||||
|
||||
if model_info.params.get("top_p", None):
|
||||
form_data["top_p"] = int(model_info.params.get("top_p", None))
|
||||
|
||||
if model_info.params.get("max_tokens", None):
|
||||
form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
|
||||
|
||||
if model_info.params.get("frequency_penalty", None):
|
||||
form_data["frequency_penalty"] = int(
|
||||
model_info.params.get("frequency_penalty", None)
|
||||
)
|
||||
|
||||
if model_info.params.get("seed", None):
|
||||
form_data["seed"] = model_info.params.get("seed", None)
|
||||
|
||||
if model_info.params.get("stop", None):
|
||||
form_data["stop"] = (
|
||||
[
|
||||
bytes(stop, "utf-8").decode("unicode_escape")
|
||||
for stop in model_info.params["stop"]
|
||||
]
|
||||
if model_info.params.get("stop", None)
|
||||
else None
|
||||
)
|
||||
|
||||
system = model_info.params.get("system", None)
|
||||
if system:
|
||||
system = prompt_template(
|
||||
system,
|
||||
**(
|
||||
{
|
||||
"user_name": user.name,
|
||||
"user_location": (
|
||||
user.info.get("location") if user.info else None
|
||||
),
|
||||
}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
# Check if the payload already has a system message
|
||||
# If not, add a system message to the payload
|
||||
if form_data.get("messages"):
|
||||
for message in form_data["messages"]:
|
||||
if message.get("role") == "system":
|
||||
message["content"] = system + message["content"]
|
||||
break
|
||||
else:
|
||||
form_data["messages"].insert(
|
||||
0,
|
||||
{
|
||||
"role": "system",
|
||||
"content": system,
|
||||
},
|
||||
)
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
async def job():
|
||||
pipe_id = form_data["model"]
|
||||
if "." in pipe_id:
|
||||
|
@ -975,12 +975,16 @@ async def get_all_models():
|
||||
model["info"] = custom_model.model_dump()
|
||||
else:
|
||||
owned_by = "openai"
|
||||
pipe = None
|
||||
|
||||
for model in models:
|
||||
if (
|
||||
custom_model.base_model_id == model["id"]
|
||||
or custom_model.base_model_id == model["id"].split(":")[0]
|
||||
):
|
||||
owned_by = model["owned_by"]
|
||||
if "pipe" in model:
|
||||
pipe = model["pipe"]
|
||||
break
|
||||
|
||||
models.append(
|
||||
@ -992,11 +996,11 @@ async def get_all_models():
|
||||
"owned_by": owned_by,
|
||||
"info": custom_model.model_dump(),
|
||||
"preset": True,
|
||||
**({"pipe": pipe} if pipe is not None else {}),
|
||||
}
|
||||
)
|
||||
|
||||
app.state.MODELS = {model["id"]: model for model in models}
|
||||
|
||||
webui_app.state.MODELS = app.state.MODELS
|
||||
|
||||
return models
|
||||
|
Loading…
Reference in New Issue
Block a user