refac: apps/openai/main.py and utils

This commit is contained in:
Michael Poluektov
2024-08-03 14:24:26 +01:00
parent 774defd184
commit 12c21fac22
8 changed files with 148 additions and 230 deletions

View File

@@ -22,9 +22,9 @@ from apps.webui.utils import load_function_module_by_id
from utils.misc import (
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
add_or_update_system_message,
apply_model_params_to_body,
apply_model_system_prompt_to_body,
)
from utils.task import prompt_template
from config import (
@@ -269,47 +269,6 @@ def get_function_params(function_module, form_data, user, extra_params={}):
return params
# inplace function: form_data is modified
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
if not params:
return form_data
mappings = {
"temperature": float,
"top_p": int,
"max_tokens": int,
"frequency_penalty": int,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}
for key, cast_func in mappings.items():
if (value := params.get(key)) is not None:
form_data[key] = cast_func(value)
return form_data
# inplace function: form_data is modified
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
system = params.get("system", None)
if not system:
return form_data
if user:
template_params = {
"user_name": user.name,
"user_location": user.info.get("location") if user.info else None,
}
else:
template_params = {}
system = prompt_template(system, **template_params)
form_data["messages"] = add_or_update_system_message(
system, form_data.get("messages", [])
)
return form_data
async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)

View File

@@ -1,12 +1,8 @@
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import Depends, HTTPException, status, Request
from typing import List, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
@@ -14,7 +10,6 @@ from utils.utils import get_admin_user, get_verified_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
from importlib import util
import os
from pathlib import Path
@@ -69,7 +64,7 @@ async def create_new_toolkit(
form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit == None:
if toolkit is None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try:
with open(toolkit_path, "w") as tool_file:
@@ -98,7 +93,7 @@ async def create_new_toolkit(
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
@@ -170,7 +165,7 @@ async def update_toolkit_by_id(
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
@@ -210,7 +205,7 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
@@ -233,7 +228,7 @@ async def get_toolkit_valves_spec_by_id(
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "Valves"):
@@ -261,7 +256,7 @@ async def update_toolkit_valves_by_id(
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "Valves"):
@@ -276,7 +271,7 @@ async def update_toolkit_valves_by_id(
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
@@ -306,7 +301,7 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
@@ -324,7 +319,7 @@ async def get_toolkit_user_valves_spec_by_id(
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "UserValves"):
@@ -348,7 +343,7 @@ async def update_toolkit_user_valves_by_id(
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "UserValves"):
@@ -365,7 +360,7 @@ async def update_toolkit_user_valves_by_id(
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(