Merge pull request #4295 from michaelpoluektov/refactor-tools

refactor: Refactor OpenAI API to use helper functions, silence LSP/linter warnings
This commit is contained in:
Timothy Jaeryang Baek 2024-08-04 14:17:52 +02:00 committed by GitHub
commit 91851114e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 158 additions and 238 deletions

View File

@ -1,6 +1,6 @@
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.responses import StreamingResponse, FileResponse
import requests
import aiohttp
@ -12,16 +12,12 @@ from pydantic import BaseModel
from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import (
decode_token,
get_verified_user,
get_verified_user,
get_admin_user,
)
from utils.task import prompt_template
from utils.misc import add_or_update_system_message
from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
from config import (
SRC_LOG_LEVELS,
@ -34,7 +30,7 @@ from config import (
MODEL_FILTER_LIST,
AppConfig,
)
from typing import List, Optional
from typing import List, Optional, Literal, overload
import hashlib
@ -69,8 +65,6 @@ app.state.MODELS = {}
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
await get_all_models()
else:
pass
response = await call_next(request)
return response
@ -175,7 +169,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
except Exception:
error_detail = f"External: {e}"
raise HTTPException(
@ -234,64 +228,68 @@ def merge_models_lists(model_lists):
return merged_list
async def get_all_models(raw: bool = False):
def is_openai_api_disabled():
    """Report whether OpenAI-compatible backends should be skipped.

    Returns True when the configured key list consists of a single empty
    string, or when ``ENABLE_OPENAI_API`` is falsy; False otherwise.
    """
    config = app.state.config
    configured_keys = config.OPENAI_API_KEYS
    # A lone empty string is treated as "no key configured".
    if len(configured_keys) == 1 and configured_keys[0] == "":
        return True
    return not config.ENABLE_OPENAI_API
async def get_all_models_raw() -> list:
    """Fetch the ``/models`` listing from every configured base URL.

    Returns an empty list when ``is_openai_api_disabled()`` is true.
    Otherwise the stored key list is first made the same length as the URL
    list (surplus keys dropped, missing keys filled with ""), then one
    ``fetch_url`` call per base URL is awaited through ``asyncio.gather``
    and the gathered responses are returned.
    """
    if is_openai_api_disabled():
        return []

    config = app.state.config
    url_count = len(config.OPENAI_API_BASE_URLS)
    key_count = len(config.OPENAI_API_KEYS)

    # Make sure there is exactly one API key per API base URL.
    if key_count > url_count:
        # More keys than URLs: drop the extra keys.
        config.OPENAI_API_KEYS = config.OPENAI_API_KEYS[:url_count]
    elif key_count < url_count:
        # More URLs than keys: pad with empty keys.
        config.OPENAI_API_KEYS += [""] * (url_count - key_count)

    pending = [
        fetch_url(f"{base_url}/models", config.OPENAI_API_KEYS[position])
        for position, base_url in enumerate(config.OPENAI_API_BASE_URLS)
    ]
    results = await asyncio.gather(*pending)
    log.debug(f"get_all_models:responses() {results}")
    return results
@overload
async def get_all_models(raw: Literal[True]) -> list: ...
@overload
async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
async def get_all_models(raw=False) -> dict[str, list] | list:
log.info("get_all_models()")
if is_openai_api_disabled():
return [] if raw else {"data": []}
if (
len(app.state.config.OPENAI_API_KEYS) == 1
and app.state.config.OPENAI_API_KEYS[0] == ""
) or not app.state.config.ENABLE_OPENAI_API:
models = {"data": []}
else:
# Check if API KEYS length is same than API URLS length
if len(app.state.config.OPENAI_API_KEYS) != len(
app.state.config.OPENAI_API_BASE_URLS
):
# if there are more keys than urls, remove the extra keys
if len(app.state.config.OPENAI_API_KEYS) > len(
app.state.config.OPENAI_API_BASE_URLS
):
app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
: len(app.state.config.OPENAI_API_BASE_URLS)
]
# if there are more urls than keys, add empty keys
else:
app.state.config.OPENAI_API_KEYS += [
""
for _ in range(
len(app.state.config.OPENAI_API_BASE_URLS)
- len(app.state.config.OPENAI_API_KEYS)
)
]
responses = await get_all_models_raw()
if raw:
return responses
tasks = [
fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
]
def extract_data(response):
if response and "data" in response:
return response["data"]
if isinstance(response, list):
return response
return None
responses = await asyncio.gather(*tasks)
log.debug(f"get_all_models:responses() {responses}")
models = {"data": merge_models_lists(map(extract_data, responses))}
if raw:
return responses
models = {
"data": merge_models_lists(
list(
map(
lambda response: (
response["data"]
if (response and "data" in response)
else (response if isinstance(response, list) else None)
),
responses,
)
)
)
}
log.debug(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]}
log.debug(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]}
return models
@ -299,7 +297,7 @@ async def get_all_models(raw: bool = False):
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
if url_idx == None:
if url_idx is None:
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
@ -340,7 +338,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
except Exception:
error_detail = f"External: {e}"
raise HTTPException(
@ -358,8 +356,7 @@ async def generate_chat_completion(
):
idx = 0
payload = {**form_data}
if "metadata" in payload:
del payload["metadata"]
# "metadata" is optional in the request body; passing a default keeps pop()
# from raising KeyError when the key is absent.
payload.pop("metadata", None)
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
@ -368,70 +365,9 @@ async def generate_chat_completion(
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
if model_info.params:
if (
model_info.params.get("temperature", None) is not None
and payload.get("temperature") is None
):
payload["temperature"] = float(model_info.params.get("temperature"))
if model_info.params.get("top_p", None) and payload.get("top_p") is None:
payload["top_p"] = int(model_info.params.get("top_p", None))
if (
model_info.params.get("max_tokens", None)
and payload.get("max_tokens") is None
):
payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
if (
model_info.params.get("frequency_penalty", None)
and payload.get("frequency_penalty") is None
):
payload["frequency_penalty"] = int(
model_info.params.get("frequency_penalty", None)
)
if (
model_info.params.get("seed", None) is not None
and payload.get("seed") is None
):
payload["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None) and payload.get("stop") is None:
payload["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
if payload.get("messages"):
payload["messages"] = add_or_update_system_message(
system, payload["messages"]
)
else:
pass
params = model_info.params.model_dump()
payload = apply_model_params_to_body(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, user)
model = app.state.MODELS[payload.get("model")]
idx = model["urlIdx"]
@ -444,13 +380,6 @@ async def generate_chat_completion(
"role": user.role,
}
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model
if payload.get("model") == "gpt-4-vision-preview":
if "max_tokens" not in payload:
payload["max_tokens"] = 4000
log.debug("Modified payload:", payload)
# Convert the modified body back to JSON
payload = json.dumps(payload)
@ -506,7 +435,7 @@ async def generate_chat_completion(
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
except Exception:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
finally:
@ -569,7 +498,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
except Exception:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
finally:

View File

@ -44,23 +44,26 @@ async def user_join(sid, data):
print("user-join", sid, data)
auth = data["auth"] if "auth" in data else None
if not auth or "token" not in auth:
return
if auth and "token" in auth:
data = decode_token(auth["token"])
data = decode_token(auth["token"])
if data is None or "id" not in data:
return
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
user = Users.get_user_by_id(data["id"])
if not user:
return
if user:
SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
else:
USER_POOL[user.id] = [sid]
SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
else:
USER_POOL[user.id] = [sid]
print(f"user {user.name}({user.id}) connected with session ID {sid}")
print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(set(USER_POOL))})
await sio.emit("user-count", {"count": len(set(USER_POOL))})
@sio.on("user-count")

View File

@ -22,9 +22,9 @@ from apps.webui.utils import load_function_module_by_id
from utils.misc import (
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
add_or_update_system_message,
apply_model_params_to_body,
apply_model_system_prompt_to_body,
)
from utils.task import prompt_template
from config import (
@ -269,47 +269,6 @@ def get_function_params(function_module, form_data, user, extra_params={}):
return params
# inplace function: form_data is modified
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
    """Copy configured model parameters into a request body.

    ``form_data`` is modified in place and also returned. For each supported
    key (temperature, top_p, max_tokens, frequency_penalty, seed, stop) whose
    value in ``params`` is not ``None``, the value is cast and written to
    ``form_data``, replacing any value already there. Empty or ``None``
    ``params`` leave ``form_data`` untouched.
    """
    if not params:
        return form_data

    def unescape(text: str) -> str:
        # Turn escape sequences typed literally (backslash followed by "n",
        # "t", ...) into the characters they denote. Encoding via latin-1
        # with "backslashreplace" keeps non-ASCII characters intact; going
        # through UTF-8 bytes would decode them as mojibake.
        return text.encode("latin-1", "backslashreplace").decode("unicode_escape")

    # top_p and frequency_penalty are fractional settings, so they are cast
    # with float; int() would truncate a value such as 0.9 to 0.
    casts = {
        "temperature": float,
        "top_p": float,
        "max_tokens": int,
        "frequency_penalty": float,
        "seed": lambda value: value,
        "stop": lambda values: [unescape(item) for item in values],
    }

    for key, cast in casts.items():
        if (value := params.get(key)) is not None:
            form_data[key] = cast(value)

    return form_data
# inplace function: form_data is modified
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
    """Insert the model's configured system prompt into a request body.

    ``form_data`` is modified in place and also returned. When ``params`` has
    no truthy "system" entry (or ``params`` itself is empty/``None``) the body
    is returned unchanged. Otherwise the system text is passed through
    ``prompt_template`` with the user's name and location (when a user is
    given) and the result is handed, together with the existing messages, to
    ``add_or_update_system_message``.
    """
    # Tolerate empty/None params the same way apply_model_params_to_body does.
    system = params.get("system") if params else None
    if not system:
        return form_data

    template_params = {}
    if user:
        template_params = {
            "user_name": user.name,
            "user_location": user.info.get("location") if user.info else None,
        }

    system = prompt_template(system, **template_params)
    form_data["messages"] = add_or_update_system_message(
        system, form_data.get("messages", [])
    )
    return form_data
async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)

View File

@ -1,12 +1,8 @@
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import Depends, HTTPException, status, Request
from typing import List, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
@ -14,7 +10,6 @@ from utils.utils import get_admin_user, get_verified_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
from importlib import util
import os
from pathlib import Path
@ -69,7 +64,7 @@ async def create_new_toolkit(
form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit == None:
if toolkit is None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try:
with open(toolkit_path, "w") as tool_file:
@ -98,7 +93,7 @@ async def create_new_toolkit(
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
@ -170,7 +165,7 @@ async def update_toolkit_by_id(
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
@ -210,7 +205,7 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
@ -233,7 +228,7 @@ async def get_toolkit_valves_spec_by_id(
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "Valves"):
@ -261,7 +256,7 @@ async def update_toolkit_valves_by_id(
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "Valves"):
@ -276,7 +271,7 @@ async def update_toolkit_valves_by_id(
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
@ -306,7 +301,7 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(
@ -324,7 +319,7 @@ async def get_toolkit_user_valves_spec_by_id(
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "UserValves"):
@ -348,7 +343,7 @@ async def update_toolkit_user_valves_by_id(
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
toolkit_module, _ = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "UserValves"):
@ -365,7 +360,7 @@ async def update_toolkit_user_valves_by_id(
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT(str(e)),
)
else:
raise HTTPException(

View File

@ -957,7 +957,7 @@ async def get_all_models():
custom_models = Models.get_all_models()
for custom_model in custom_models:
if custom_model.base_model_id == None:
if custom_model.base_model_id is None:
for model in models:
if (
custom_model.id == model["id"]
@ -1656,13 +1656,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
@app.get("/api/pipelines/list")
async def get_pipelines_list(user=Depends(get_admin_user)):
responses = await get_openai_models(raw=True)
responses = await get_openai_models(raw = True)
print(responses)
urlIdxs = [
idx
for idx, response in enumerate(responses)
if response != None and "pipelines" in response
if response is not None and "pipelines" in response
]
return {
@ -1723,7 +1723,7 @@ async def upload_pipeline(
res = r.json()
if "detail" in res:
detail = res["detail"]
except:
except Exception:
pass
raise HTTPException(
@ -1769,7 +1769,7 @@ async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user))
res = r.json()
if "detail" in res:
detail = res["detail"]
except:
except Exception:
pass
raise HTTPException(
@ -1811,7 +1811,7 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
res = r.json()
if "detail" in res:
detail = res["detail"]
except:
except Exception:
pass
raise HTTPException(
@ -1844,7 +1844,7 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
res = r.json()
if "detail" in res:
detail = res["detail"]
except:
except Exception:
pass
raise HTTPException(
@ -1859,7 +1859,6 @@ async def get_pipeline_valves(
pipeline_id: str,
user=Depends(get_admin_user),
):
models = await get_all_models()
r = None
try:
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
@ -1898,8 +1897,6 @@ async def get_pipeline_valves_spec(
pipeline_id: str,
user=Depends(get_admin_user),
):
models = await get_all_models()
r = None
try:
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
@ -1922,7 +1919,7 @@ async def get_pipeline_valves_spec(
res = r.json()
if "detail" in res:
detail = res["detail"]
except:
except Exception:
pass
raise HTTPException(
@ -1938,8 +1935,6 @@ async def update_pipeline_valves(
form_data: dict,
user=Depends(get_admin_user),
):
models = await get_all_models()
r = None
try:
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
@ -1967,7 +1962,7 @@ async def update_pipeline_valves(
res = r.json()
if "detail" in res:
detail = res["detail"]
except:
except Exception:
pass
raise HTTPException(
@ -2068,7 +2063,7 @@ async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
@app.get("/api/version")
async def get_app_config():
async def get_app_version():
return {
"version": VERSION,
}
@ -2091,7 +2086,7 @@ async def get_app_latest_release_version():
latest_version = data["tag_name"]
return {"current": VERSION, "latest": latest_version[1:]}
except aiohttp.ClientError as e:
except aiohttp.ClientError:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,

View File

@ -6,6 +6,8 @@ from typing import Optional, List, Tuple
import uuid
import time
from utils.task import prompt_template
def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
for message in reversed(messages):
@ -112,6 +114,47 @@ def openai_chat_completion_message_template(model: str, message: str) -> dict:
return template
# inplace function: form_data is modified
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
    """Insert the model's configured system prompt into a request body.

    ``form_data`` is modified in place and also returned. When ``params`` has
    no truthy "system" entry (or ``params`` itself is empty/``None``) the body
    is returned unchanged. Otherwise the system text is passed through
    ``prompt_template`` with the user's name and location (when a user is
    given) and the result is handed, together with the existing messages, to
    ``add_or_update_system_message``.
    """
    # Tolerate empty/None params the same way apply_model_params_to_body does.
    system = params.get("system") if params else None
    if not system:
        return form_data

    template_params = {}
    if user:
        template_params = {
            "user_name": user.name,
            "user_location": user.info.get("location") if user.info else None,
        }

    system = prompt_template(system, **template_params)
    form_data["messages"] = add_or_update_system_message(
        system, form_data.get("messages", [])
    )
    return form_data
# inplace function: form_data is modified
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
    """Copy configured model parameters into a request body.

    ``form_data`` is modified in place and also returned. For each supported
    key (temperature, top_p, max_tokens, frequency_penalty, seed, stop) whose
    value in ``params`` is not ``None``, the value is cast and written to
    ``form_data``, replacing any value already there. Empty or ``None``
    ``params`` leave ``form_data`` untouched.
    """
    if not params:
        return form_data

    def unescape(text: str) -> str:
        # Turn escape sequences typed literally (backslash followed by "n",
        # "t", ...) into the characters they denote. Encoding via latin-1
        # with "backslashreplace" keeps non-ASCII characters intact; going
        # through UTF-8 bytes would decode them as mojibake.
        return text.encode("latin-1", "backslashreplace").decode("unicode_escape")

    # top_p and frequency_penalty are fractional settings, so they are cast
    # with float; int() would truncate a value such as 0.9 to 0.
    casts = {
        "temperature": float,
        "top_p": float,
        "max_tokens": int,
        "frequency_penalty": float,
        "seed": lambda value: value,
        "stop": lambda values: [unescape(item) for item in values],
    }

    for key, cast in casts.items():
        if (value := params.get(key)) is not None:
            form_data[key] = cast(value)

    return form_data
def get_gravatar_url(email):
# Trim leading and trailing whitespace from
# an email address and force all characters

View File

@ -6,7 +6,7 @@ from typing import Optional
def prompt_template(
template: str, user_name: str = None, user_location: str = None
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
) -> str:
# Get the current date
current_date = datetime.now()
@ -83,7 +83,6 @@ def title_generation_template(
def search_query_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:
def replacement_function(match):
full_match = match.group(0)
start_length = match.group(1)

View File

@ -1,15 +1,12 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends, Request
from sqlalchemy.orm import Session
from apps.webui.models.users import Users
from pydantic import BaseModel
from typing import Union, Optional
from constants import ERROR_MESSAGES
from passlib.context import CryptContext
from datetime import datetime, timedelta
import requests
import jwt
import uuid
import logging
@ -54,7 +51,7 @@ def decode_token(token: str) -> Optional[dict]:
try:
decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
return decoded
except Exception as e:
except Exception:
return None
@ -71,7 +68,7 @@ def get_http_authorization_cred(auth_header: str):
try:
scheme, credentials = auth_header.split(" ")
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
except:
except Exception:
raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
@ -96,7 +93,7 @@ def get_current_user(
# auth by jwt token
data = decode_token(token)
if data != None and "id" in data:
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
if user is None:
raise HTTPException(