refac: apps/openai/main.py and utils

This commit is contained in:
Michael Poluektov
2024-08-03 14:24:26 +01:00
parent 774defd184
commit 12c21fac22
8 changed files with 148 additions and 230 deletions

View File

@@ -1,6 +1,6 @@
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.responses import StreamingResponse, FileResponse
import requests
import aiohttp
@@ -12,16 +12,12 @@ from pydantic import BaseModel
from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import (
decode_token,
get_verified_user,
get_verified_user,
get_admin_user,
)
from utils.task import prompt_template
from utils.misc import add_or_update_system_message
from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
from config import (
SRC_LOG_LEVELS,
@@ -69,8 +65,6 @@ app.state.MODELS = {}
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
await get_all_models()
else:
pass
response = await call_next(request)
return response
@@ -175,7 +169,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
except Exception:
error_detail = f"External: {e}"
raise HTTPException(
@@ -234,64 +228,58 @@ def merge_models_lists(model_lists):
return merged_list
async def get_all_models(raw: bool = False):
def is_openai_api_disabled():
api_keys = app.state.config.OPENAI_API_KEYS
no_keys = len(api_keys) == 1 and api_keys[0] == ""
return no_keys or not app.state.config.ENABLE_OPENAI_API
async def get_all_models_raw() -> list:
if is_openai_api_disabled():
return []
# Check if API KEYS length is same than API URLS length
num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
num_keys = len(app.state.config.OPENAI_API_KEYS)
if num_keys != num_urls:
# if there are more keys than urls, remove the extra keys
if num_keys > num_urls:
new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
app.state.config.OPENAI_API_KEYS = new_keys
# if there are more urls than keys, add empty keys
else:
app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
tasks = [
fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
]
responses = await asyncio.gather(*tasks)
log.debug(f"get_all_models:responses() {responses}")
return responses
async def get_all_models() -> dict[str, list]:
log.info("get_all_models()")
if is_openai_api_disabled():
return {"data": []}
if (
len(app.state.config.OPENAI_API_KEYS) == 1
and app.state.config.OPENAI_API_KEYS[0] == ""
) or not app.state.config.ENABLE_OPENAI_API:
models = {"data": []}
else:
# Check if API KEYS length is same than API URLS length
if len(app.state.config.OPENAI_API_KEYS) != len(
app.state.config.OPENAI_API_BASE_URLS
):
# if there are more keys than urls, remove the extra keys
if len(app.state.config.OPENAI_API_KEYS) > len(
app.state.config.OPENAI_API_BASE_URLS
):
app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
: len(app.state.config.OPENAI_API_BASE_URLS)
]
# if there are more urls than keys, add empty keys
else:
app.state.config.OPENAI_API_KEYS += [
""
for _ in range(
len(app.state.config.OPENAI_API_BASE_URLS)
- len(app.state.config.OPENAI_API_KEYS)
)
]
responses = await get_all_models_raw()
tasks = [
fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
]
def extract_data(response):
if response and "data" in response:
return response["data"]
if isinstance(response, list):
return response
return None
responses = await asyncio.gather(*tasks)
log.debug(f"get_all_models:responses() {responses}")
models = {"data": merge_models_lists(map(extract_data, responses))}
if raw:
return responses
models = {
"data": merge_models_lists(
list(
map(
lambda response: (
response["data"]
if (response and "data" in response)
else (response if isinstance(response, list) else None)
),
responses,
)
)
)
}
log.debug(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]}
log.debug(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]}
return models
@@ -299,7 +287,7 @@ async def get_all_models(raw: bool = False):
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
if url_idx == None:
if url_idx is None:
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
@@ -340,7 +328,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
except Exception:
error_detail = f"External: {e}"
raise HTTPException(
@@ -358,8 +346,7 @@ async def generate_chat_completion(
):
idx = 0
payload = {**form_data}
if "metadata" in payload:
del payload["metadata"]
payload.pop("metadata")
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
@@ -368,70 +355,9 @@ async def generate_chat_completion(
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
if model_info.params:
if (
model_info.params.get("temperature", None) is not None
and payload.get("temperature") is None
):
payload["temperature"] = float(model_info.params.get("temperature"))
if model_info.params.get("top_p", None) and payload.get("top_p") is None:
payload["top_p"] = int(model_info.params.get("top_p", None))
if (
model_info.params.get("max_tokens", None)
and payload.get("max_tokens") is None
):
payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
if (
model_info.params.get("frequency_penalty", None)
and payload.get("frequency_penalty") is None
):
payload["frequency_penalty"] = int(
model_info.params.get("frequency_penalty", None)
)
if (
model_info.params.get("seed", None) is not None
and payload.get("seed") is None
):
payload["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None) and payload.get("stop") is None:
payload["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
if payload.get("messages"):
payload["messages"] = add_or_update_system_message(
system, payload["messages"]
)
else:
pass
params = model_info.params.model_dump()
payload = apply_model_params_to_body(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, user)
model = app.state.MODELS[payload.get("model")]
idx = model["urlIdx"]
@@ -506,7 +432,7 @@ async def generate_chat_completion(
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
except Exception:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
finally:
@@ -569,7 +495,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
except Exception:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
finally:

View File

@@ -44,23 +44,26 @@ async def user_join(sid, data):
print("user-join", sid, data)
auth = data["auth"] if "auth" in data else None
if not auth or "token" not in auth:
return
if auth and "token" in auth:
data = decode_token(auth["token"])
data = decode_token(auth["token"])
if data is None or "id" not in data:
return
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
user = Users.get_user_by_id(data["id"])
if not user:
return
if user:
SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
else:
USER_POOL[user.id] = [sid]
SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
else:
USER_POOL[user.id] = [sid]
print(f"user {user.name}({user.id}) connected with session ID {sid}")
print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(set(USER_POOL))})
await sio.emit("user-count", {"count": len(set(USER_POOL))})
@sio.on("user-count")

View File

@@ -22,9 +22,9 @@ from apps.webui.utils import load_function_module_by_id
from utils.misc import (
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
add_or_update_system_message,
apply_model_params_to_body,
apply_model_system_prompt_to_body,
)
from utils.task import prompt_template
from config import (
@@ -269,47 +269,6 @@ def get_function_params(function_module, form_data, user, extra_params={}):
return params
# inplace function: form_data is modified
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
if not params:
return form_data
mappings = {
"temperature": float,
"top_p": int,
"max_tokens": int,
"frequency_penalty": int,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}
for key, cast_func in mappings.items():
if (value := params.get(key)) is not None:
form_data[key] = cast_func(value)
return form_data
# inplace function: form_data is modified
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
system = params.get("system", None)
if not system:
return form_data
if user:
template_params = {
"user_name": user.name,
"user_location": user.info.get("location") if user.info else None,
}
else:
template_params = {}
system = prompt_template(system, **template_params)
form_data["messages"] = add_or_update_system_message(
system, form_data.get("messages", [])
)
return form_data
async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)

View File

@@ -1,12 +1,8 @@
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import Depends, HTTPException, status, Request
from typing import List, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
@@ -14,7 +10,6 @@ from utils.utils import get_admin_user, get_verified_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
from importlib import util
import os
from pathlib import Path
@@ -69,7 +64,7 @@ async def create_new_toolkit(
form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit == None:
if toolkit is None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try:
with open(toolkit_path, "w") as tool_file:
@@ -98,7 +93,7 @@ async def create_new_toolkit(
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
@@ -170,7 +165,7 @@ async def update_toolkit_by_id(
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
@@ -210,7 +205,7 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
@@ -233,7 +228,7 @@ async def get_toolkit_valves_spec_by_id(
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "Valves"):
@@ -261,7 +256,7 @@ async def update_toolkit_valves_by_id(
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "Valves"):
@@ -276,7 +271,7 @@ async def update_toolkit_valves_by_id(
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
@@ -306,7 +301,7 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
@@ -324,7 +319,7 @@ async def get_toolkit_user_valves_spec_by_id(
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "UserValves"):
@@ -348,7 +343,7 @@ async def update_toolkit_user_valves_by_id(
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "UserValves"):
@@ -365,7 +360,7 @@ async def update_toolkit_user_valves_by_id(
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(