mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
refac: apps/openai/main.py and utils
This commit is contained in:
@@ -22,9 +22,9 @@ from apps.webui.utils import load_function_module_by_id
|
||||
from utils.misc import (
|
||||
openai_chat_chunk_message_template,
|
||||
openai_chat_completion_message_template,
|
||||
add_or_update_system_message,
|
||||
apply_model_params_to_body,
|
||||
apply_model_system_prompt_to_body,
|
||||
)
|
||||
from utils.task import prompt_template
|
||||
|
||||
|
||||
from config import (
|
||||
@@ -269,47 +269,6 @@ def get_function_params(function_module, form_data, user, extra_params={}):
|
||||
return params
|
||||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
|
||||
if not params:
|
||||
return form_data
|
||||
|
||||
mappings = {
|
||||
"temperature": float,
|
||||
"top_p": int,
|
||||
"max_tokens": int,
|
||||
"frequency_penalty": int,
|
||||
"seed": lambda x: x,
|
||||
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
|
||||
}
|
||||
|
||||
for key, cast_func in mappings.items():
|
||||
if (value := params.get(key)) is not None:
|
||||
form_data[key] = cast_func(value)
|
||||
|
||||
return form_data
|
||||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
|
||||
system = params.get("system", None)
|
||||
if not system:
|
||||
return form_data
|
||||
|
||||
if user:
|
||||
template_params = {
|
||||
"user_name": user.name,
|
||||
"user_location": user.info.get("location") if user.info else None,
|
||||
}
|
||||
else:
|
||||
template_params = {}
|
||||
system = prompt_template(system, **template_params)
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
system, form_data.get("messages", [])
|
||||
)
|
||||
return form_data
|
||||
|
||||
|
||||
async def generate_function_chat_completion(form_data, user):
|
||||
model_id = form_data.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
from fastapi import Depends, FastAPI, HTTPException, status, Request
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Union, Optional
|
||||
from fastapi import Depends, HTTPException, status, Request
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
from apps.webui.models.users import Users
|
||||
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
|
||||
from apps.webui.utils import load_toolkit_module_by_id
|
||||
|
||||
@@ -14,7 +10,6 @@ from utils.utils import get_admin_user, get_verified_user
|
||||
from utils.tools import get_tools_specs
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
from importlib import util
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -69,7 +64,7 @@ async def create_new_toolkit(
|
||||
form_data.id = form_data.id.lower()
|
||||
|
||||
toolkit = Tools.get_tool_by_id(form_data.id)
|
||||
if toolkit == None:
|
||||
if toolkit is None:
|
||||
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
|
||||
try:
|
||||
with open(toolkit_path, "w") as tool_file:
|
||||
@@ -98,7 +93,7 @@ async def create_new_toolkit(
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -170,7 +165,7 @@ async def update_toolkit_by_id(
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
)
|
||||
|
||||
|
||||
@@ -210,7 +205,7 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -233,7 +228,7 @@ async def get_toolkit_valves_spec_by_id(
|
||||
if id in request.app.state.TOOLS:
|
||||
toolkit_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
|
||||
toolkit_module, _ = load_toolkit_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = toolkit_module
|
||||
|
||||
if hasattr(toolkit_module, "Valves"):
|
||||
@@ -261,7 +256,7 @@ async def update_toolkit_valves_by_id(
|
||||
if id in request.app.state.TOOLS:
|
||||
toolkit_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
|
||||
toolkit_module, _ = load_toolkit_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = toolkit_module
|
||||
|
||||
if hasattr(toolkit_module, "Valves"):
|
||||
@@ -276,7 +271,7 @@ async def update_toolkit_valves_by_id(
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -306,7 +301,7 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -324,7 +319,7 @@ async def get_toolkit_user_valves_spec_by_id(
|
||||
if id in request.app.state.TOOLS:
|
||||
toolkit_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
|
||||
toolkit_module, _ = load_toolkit_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = toolkit_module
|
||||
|
||||
if hasattr(toolkit_module, "UserValves"):
|
||||
@@ -348,7 +343,7 @@ async def update_toolkit_user_valves_by_id(
|
||||
if id in request.app.state.TOOLS:
|
||||
toolkit_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
|
||||
toolkit_module, _ = load_toolkit_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = toolkit_module
|
||||
|
||||
if hasattr(toolkit_module, "UserValves"):
|
||||
@@ -365,7 +360,7 @@ async def update_toolkit_user_valves_by_id(
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
||||
Reference in New Issue
Block a user